Merge pull request 'Background variant processing' (#56) from asonix/backgrounded-variants into main

Reviewed-on: https://git.asonix.dog/asonix/pict-rs/pulls/56
This commit is contained in:
asonix 2024-04-01 22:17:30 +00:00
commit dfb38c7144
13 changed files with 919 additions and 617 deletions

View file

@ -253,9 +253,27 @@ Example:
### API ### API
pict-rs offers the following endpoints: pict-rs offers the following endpoints:
- `POST /image` for uploading an image. Uploaded content must be valid multipart/form-data with an - `POST /image?{args}` for uploading an image. Uploaded content must be valid multipart/form-data with an
image array located within the `images[]` key image array located within the `images[]` key
The {args} query serves multiple purpose for image uploads. The first is to provide
request-level validations for the uploaded media. Available keys are as follows:
- max_width: maximum width, in pixels, allowed for the uploaded media
- max_height: maximum height, in pixels, allowed for the uploaded media
- max_area: maximum area, in pixels, allowed for the uploaded media
- max_frame_count: maximum number of frames permitted for animations and videos
- max_file_size: maximum size, in megabytes, allowed
- allow_image: whether to permit still images in the upload
- allow_animation: whether to permit animations in the upload
- allow_video: whether to permit video in the upload
These validations apply in addition to the validations specified in the pict-rs configuration,
so uploaded media will be rejected if any of the validations fail.
The second purpose for the {args} query is to provide preprocess steps for the uploaded image.
The format is the same as in the process.{ext} endpoint. The images uploaded with these steps
provided will be processed before saving.
This endpoint returns the following JSON structure on success with a 201 Created status This endpoint returns the following JSON structure on success with a 201 Created status
```json ```json
{ {
@ -294,7 +312,9 @@ pict-rs offers the following endpoints:
"msg": "ok" "msg": "ok"
} }
``` ```
- `POST /image/backgrounded` Upload an image, like the `/image` endpoint, but don't wait to validate and process it. - `POST /image/backgrounded?{args}` Upload an image, like the `/image` endpoint, but don't wait to validate and process it.
The {args} query is the same format is the inline image upload endpoint.
This endpoint returns the following JSON structure on success with a 202 Accepted status This endpoint returns the following JSON structure on success with a 202 Accepted status
```json ```json
{ {

View file

@ -1,172 +0,0 @@
use crate::{
details::Details,
error::{Error, UploadError},
repo::Hash,
};
use dashmap::{mapref::entry::Entry, DashMap};
use flume::{r#async::RecvFut, Receiver, Sender};
use std::{
future::Future,
path::PathBuf,
pin::Pin,
sync::Arc,
task::{Context, Poll},
};
use tracing::Span;
type OutcomeReceiver = Receiver<(Details, Arc<str>)>;
type ProcessMapKey = (Hash, PathBuf);
type ProcessMapInner = DashMap<ProcessMapKey, OutcomeReceiver>;
#[derive(Debug, Default, Clone)]
pub(crate) struct ProcessMap {
process_map: Arc<ProcessMapInner>,
}
impl ProcessMap {
pub(super) fn new() -> Self {
Self::default()
}
pub(super) async fn process<Fut>(
&self,
hash: Hash,
path: PathBuf,
fut: Fut,
) -> Result<(Details, Arc<str>), Error>
where
Fut: Future<Output = Result<(Details, Arc<str>), Error>>,
{
let key = (hash.clone(), path.clone());
let (sender, receiver) = flume::bounded(1);
let entry = self.process_map.entry(key.clone());
let (state, span) = match entry {
Entry::Vacant(vacant) => {
vacant.insert(receiver);
let span = tracing::info_span!(
"Processing image",
hash = ?hash,
path = ?path,
completed = &tracing::field::Empty,
);
metrics::counter!(crate::init_metrics::PROCESS_MAP_INSERTED).increment(1);
(CancelState::Sender { sender }, span)
}
Entry::Occupied(receiver) => {
let span = tracing::info_span!(
"Waiting for processed image",
hash = ?hash,
path = ?path,
);
let receiver = receiver.get().clone().into_recv_async();
(CancelState::Receiver { receiver }, span)
}
};
CancelSafeProcessor {
cancel_token: CancelToken {
span,
key,
state,
process_map: self.clone(),
},
fut,
}
.await
}
fn remove(&self, key: &ProcessMapKey) -> Option<OutcomeReceiver> {
self.process_map.remove(key).map(|(_, v)| v)
}
}
struct CancelToken {
span: Span,
key: ProcessMapKey,
state: CancelState,
process_map: ProcessMap,
}
enum CancelState {
Sender {
sender: Sender<(Details, Arc<str>)>,
},
Receiver {
receiver: RecvFut<'static, (Details, Arc<str>)>,
},
}
impl CancelState {
const fn is_sender(&self) -> bool {
matches!(self, Self::Sender { .. })
}
}
pin_project_lite::pin_project! {
struct CancelSafeProcessor<F> {
cancel_token: CancelToken,
#[pin]
fut: F,
}
}
impl<F> Future for CancelSafeProcessor<F>
where
F: Future<Output = Result<(Details, Arc<str>), Error>>,
{
type Output = Result<(Details, Arc<str>), Error>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.as_mut().project();
let span = &this.cancel_token.span;
let process_map = &this.cancel_token.process_map;
let state = &mut this.cancel_token.state;
let key = &this.cancel_token.key;
let fut = this.fut;
span.in_scope(|| match state {
CancelState::Sender { sender } => {
let res = std::task::ready!(fut.poll(cx));
if process_map.remove(key).is_some() {
metrics::counter!(crate::init_metrics::PROCESS_MAP_REMOVED).increment(1);
}
if let Ok(tup) = &res {
let _ = sender.try_send(tup.clone());
}
Poll::Ready(res)
}
CancelState::Receiver { ref mut receiver } => Pin::new(receiver)
.poll(cx)
.map(|res| res.map_err(|_| UploadError::Canceled.into())),
})
}
}
impl Drop for CancelToken {
fn drop(&mut self) {
if self.state.is_sender() {
let completed = self.process_map.remove(&self.key).is_none();
self.span.record("completed", completed);
if !completed {
metrics::counter!(crate::init_metrics::PROCESS_MAP_REMOVED).increment(1);
}
}
}
}

View file

@ -2,18 +2,17 @@ mod ffmpeg;
mod magick; mod magick;
use crate::{ use crate::{
concurrent_processor::ProcessMap,
details::Details, details::Details,
error::{Error, UploadError}, error::{Error, UploadError},
formats::{ImageFormat, InputProcessableFormat, InternalVideoFormat, ProcessableFormat}, formats::{ImageFormat, InputProcessableFormat, InternalVideoFormat, ProcessableFormat},
future::{WithMetrics, WithPollTimer, WithTimeout}, future::{WithMetrics, WithPollTimer, WithTimeout},
repo::{Hash, VariantAlreadyExists}, repo::{Hash, NotificationEntry, VariantAlreadyExists},
state::State, state::State,
store::Store, store::Store,
}; };
use std::{ use std::{
path::PathBuf, future::Future,
sync::Arc, sync::Arc,
time::{Duration, Instant}, time::{Duration, Instant},
}; };
@ -48,13 +47,12 @@ impl Drop for MetricsGuard {
} }
} }
#[tracing::instrument(skip(state, process_map, original_details, hash))] #[tracing::instrument(skip(state, original_details, hash))]
pub(crate) async fn generate<S: Store + 'static>( pub(crate) async fn generate<S: Store + 'static>(
state: &State<S>, state: &State<S>,
process_map: &ProcessMap,
format: InputProcessableFormat, format: InputProcessableFormat,
thumbnail_path: PathBuf, variant: String,
thumbnail_args: Vec<String>, variant_args: Vec<String>,
original_details: &Details, original_details: &Details,
hash: Hash, hash: Hash,
) -> Result<(Details, Arc<str>), Error> { ) -> Result<(Details, Arc<str>), Error> {
@ -67,25 +65,122 @@ pub(crate) async fn generate<S: Store + 'static>(
Ok((original_details.clone(), identifier)) Ok((original_details.clone(), identifier))
} else { } else {
let process_fut = process( let mut attempts = 0;
state, let tup = loop {
format, if attempts > 2 {
thumbnail_path.clone(), return Err(UploadError::ProcessTimeout.into());
thumbnail_args, }
original_details,
hash.clone(),
)
.with_poll_timer("process-future");
let (details, identifier) = process_map match state
.process(hash, thumbnail_path, process_fut) .repo
.with_poll_timer("process-map-future") .claim_variant_processing_rights(hash.clone(), variant.clone())
.with_timeout(Duration::from_secs(state.config.media.process_timeout * 4)) .await?
.with_metrics(crate::init_metrics::GENERATE_PROCESS) {
.await Ok(()) => {
.map_err(|_| UploadError::ProcessTimeout)??; // process
let process_future = process(
state,
format,
variant.clone(),
variant_args,
original_details,
hash.clone(),
)
.with_poll_timer("process-future");
Ok((details, identifier)) let res = heartbeat(state, hash.clone(), variant.clone(), process_future)
.with_poll_timer("heartbeat-future")
.with_timeout(Duration::from_secs(state.config.media.process_timeout * 4))
.with_metrics(crate::init_metrics::GENERATE_PROCESS)
.await
.map_err(|_| Error::from(UploadError::ProcessTimeout));
state
.repo
.notify_variant(hash.clone(), variant.clone())
.await?;
break res???;
}
Err(entry) => {
if let Some(tuple) = wait_timeout(
hash.clone(),
variant.clone(),
entry,
state,
Duration::from_secs(20),
)
.await?
{
break tuple;
}
attempts += 1;
}
}
};
Ok(tup)
}
}
pub(crate) async fn wait_timeout<S: Store + 'static>(
hash: Hash,
variant: String,
mut entry: NotificationEntry,
state: &State<S>,
timeout: Duration,
) -> Result<Option<(Details, Arc<str>)>, Error> {
let notified = entry.notified_timeout(timeout);
if let Some(identifier) = state
.repo
.variant_identifier(hash.clone(), variant.clone())
.await?
{
let details = crate::ensure_details_identifier(state, &identifier).await?;
return Ok(Some((details, identifier)));
}
match notified.await {
Ok(()) => tracing::debug!("notified"),
Err(_) => tracing::debug!("timeout"),
}
Ok(None)
}
async fn heartbeat<S, O>(
state: &State<S>,
hash: Hash,
variant: String,
future: impl Future<Output = O>,
) -> Result<O, Error> {
let repo = state.repo.clone();
let handle = crate::sync::abort_on_drop(crate::sync::spawn("heartbeat-task", async move {
let mut interval = tokio::time::interval(Duration::from_secs(5));
loop {
interval.tick().await;
if let Err(e) = repo.variant_heartbeat(hash.clone(), variant.clone()).await {
break Error::from(e);
}
}
}));
let future = std::pin::pin!(future);
tokio::select! {
biased;
output = future => {
Ok(output)
}
res = handle => {
Err(res.map_err(|_| UploadError::Canceled)?)
}
} }
} }
@ -93,8 +188,8 @@ pub(crate) async fn generate<S: Store + 'static>(
async fn process<S: Store + 'static>( async fn process<S: Store + 'static>(
state: &State<S>, state: &State<S>,
output_format: InputProcessableFormat, output_format: InputProcessableFormat,
thumbnail_path: PathBuf, variant: String,
thumbnail_args: Vec<String>, variant_args: Vec<String>,
original_details: &Details, original_details: &Details,
hash: Hash, hash: Hash,
) -> Result<(Details, Arc<str>), Error> { ) -> Result<(Details, Arc<str>), Error> {
@ -120,7 +215,7 @@ async fn process<S: Store + 'static>(
let stream = state.store.to_stream(&identifier, None, None).await?; let stream = state.store.to_stream(&identifier, None, None).await?;
let bytes = let bytes =
crate::magick::process_image_command(state, thumbnail_args, input_format, format, quality) crate::magick::process_image_command(state, variant_args, input_format, format, quality)
.await? .await?
.drive_with_stream(stream) .drive_with_stream(stream)
.into_bytes_stream() .into_bytes_stream()
@ -142,19 +237,21 @@ async fn process<S: Store + 'static>(
) )
.await?; .await?;
if let Err(VariantAlreadyExists) = state let identifier = if let Err(VariantAlreadyExists) = state
.repo .repo
.relate_variant_identifier( .relate_variant_identifier(hash.clone(), variant.clone(), &identifier)
hash,
thumbnail_path.to_string_lossy().to_string(),
&identifier,
)
.await? .await?
{ {
state.store.remove(&identifier).await?; state.store.remove(&identifier).await?;
} state
.repo
state.repo.relate_details(&identifier, &details).await?; .variant_identifier(hash, variant)
.await?
.ok_or(UploadError::MissingIdentifier)?
} else {
state.repo.relate_details(&identifier, &details).await?;
identifier
};
guard.disarm(); guard.disarm();

View file

@ -1,7 +1,6 @@
mod backgrounded; mod backgrounded;
mod blurhash; mod blurhash;
mod bytes_stream; mod bytes_stream;
mod concurrent_processor;
mod config; mod config;
mod details; mod details;
mod discover; mod discover;
@ -57,7 +56,6 @@ use state::State;
use std::{ use std::{
marker::PhantomData, marker::PhantomData,
path::Path, path::Path,
path::PathBuf,
rc::Rc, rc::Rc,
sync::{Arc, OnceLock}, sync::{Arc, OnceLock},
time::{Duration, SystemTime}, time::{Duration, SystemTime},
@ -71,7 +69,6 @@ use tracing_actix_web::TracingLogger;
use self::{ use self::{
backgrounded::Backgrounded, backgrounded::Backgrounded,
concurrent_processor::ProcessMap,
config::{Configuration, Operation}, config::{Configuration, Operation},
details::Details, details::Details,
either::Either, either::Either,
@ -123,6 +120,7 @@ async fn ensure_details<S: Store + 'static>(
ensure_details_identifier(state, &identifier).await ensure_details_identifier(state, &identifier).await
} }
#[tracing::instrument(skip(state))]
async fn ensure_details_identifier<S: Store + 'static>( async fn ensure_details_identifier<S: Store + 'static>(
state: &State<S>, state: &State<S>,
identifier: &Arc<str>, identifier: &Arc<str>,
@ -775,7 +773,7 @@ fn prepare_process(
config: &Configuration, config: &Configuration,
operations: Vec<(String, String)>, operations: Vec<(String, String)>,
ext: &str, ext: &str,
) -> Result<(InputProcessableFormat, PathBuf, Vec<String>), Error> { ) -> Result<(InputProcessableFormat, String, Vec<String>), Error> {
let operations = operations let operations = operations
.into_iter() .into_iter()
.filter(|(k, _)| config.media.filters.contains(&k.to_lowercase())) .filter(|(k, _)| config.media.filters.contains(&k.to_lowercase()))
@ -785,10 +783,9 @@ fn prepare_process(
.parse::<InputProcessableFormat>() .parse::<InputProcessableFormat>()
.map_err(|_| UploadError::UnsupportedProcessExtension)?; .map_err(|_| UploadError::UnsupportedProcessExtension)?;
let (thumbnail_path, thumbnail_args) = let (variant, variant_args) = self::processor::build_chain(&operations, &format.to_string())?;
self::processor::build_chain(&operations, &format.to_string())?;
Ok((format, thumbnail_path, thumbnail_args)) Ok((format, variant, variant_args))
} }
#[tracing::instrument(name = "Fetching derived details", skip(state))] #[tracing::instrument(name = "Fetching derived details", skip(state))]
@ -799,7 +796,7 @@ async fn process_details<S: Store>(
) -> Result<HttpResponse, Error> { ) -> Result<HttpResponse, Error> {
let alias = alias_from_query(source.into(), &state).await?; let alias = alias_from_query(source.into(), &state).await?;
let (_, thumbnail_path, _) = prepare_process(&state.config, operations, ext.as_str())?; let (_, variant, _) = prepare_process(&state.config, operations, ext.as_str())?;
let hash = state let hash = state
.repo .repo
@ -807,18 +804,16 @@ async fn process_details<S: Store>(
.await? .await?
.ok_or(UploadError::MissingAlias)?; .ok_or(UploadError::MissingAlias)?;
let thumbnail_string = thumbnail_path.to_string_lossy().to_string();
if !state.config.server.read_only { if !state.config.server.read_only {
state state
.repo .repo
.accessed_variant(hash.clone(), thumbnail_string.clone()) .accessed_variant(hash.clone(), variant.clone())
.await?; .await?;
} }
let identifier = state let identifier = state
.repo .repo
.variant_identifier(hash, thumbnail_string) .variant_identifier(hash, variant)
.await? .await?
.ok_or(UploadError::MissingAlias)?; .ok_or(UploadError::MissingAlias)?;
@ -848,20 +843,16 @@ async fn not_found_hash(repo: &ArcRepo) -> Result<Option<(Alias, Hash)>, Error>
} }
/// Process files /// Process files
#[tracing::instrument(name = "Serving processed image", skip(state, process_map))] #[tracing::instrument(name = "Serving processed image", skip(state))]
async fn process<S: Store + 'static>( async fn process<S: Store + 'static>(
range: Option<web::Header<Range>>, range: Option<web::Header<Range>>,
web::Query(ProcessQuery { source, operations }): web::Query<ProcessQuery>, web::Query(ProcessQuery { source, operations }): web::Query<ProcessQuery>,
ext: web::Path<String>, ext: web::Path<String>,
state: web::Data<State<S>>, state: web::Data<State<S>>,
process_map: web::Data<ProcessMap>,
) -> Result<HttpResponse, Error> { ) -> Result<HttpResponse, Error> {
let alias = proxy_alias_from_query(source.into(), &state).await?; let alias = proxy_alias_from_query(source.into(), &state).await?;
let (format, thumbnail_path, thumbnail_args) = let (format, variant, variant_args) = prepare_process(&state.config, operations, ext.as_str())?;
prepare_process(&state.config, operations, ext.as_str())?;
let path_string = thumbnail_path.to_string_lossy().to_string();
let (hash, alias, not_found) = if let Some(hash) = state.repo.hash(&alias).await? { let (hash, alias, not_found) = if let Some(hash) = state.repo.hash(&alias).await? {
(hash, alias, false) (hash, alias, false)
@ -876,13 +867,13 @@ async fn process<S: Store + 'static>(
if !state.config.server.read_only { if !state.config.server.read_only {
state state
.repo .repo
.accessed_variant(hash.clone(), path_string.clone()) .accessed_variant(hash.clone(), variant.clone())
.await?; .await?;
} }
let identifier_opt = state let identifier_opt = state
.repo .repo
.variant_identifier(hash.clone(), path_string) .variant_identifier(hash.clone(), variant.clone())
.await?; .await?;
let (details, identifier) = if let Some(identifier) = identifier_opt { let (details, identifier) = if let Some(identifier) = identifier_opt {
@ -894,18 +885,34 @@ async fn process<S: Store + 'static>(
return Err(UploadError::ReadOnly.into()); return Err(UploadError::ReadOnly.into());
} }
let original_details = ensure_details(&state, &alias).await?; queue_generate(&state.repo, format, alias, variant.clone(), variant_args).await?;
generate::generate( let mut attempts = 0;
&state, loop {
&process_map, if attempts > 6 {
format, return Err(UploadError::ProcessTimeout.into());
thumbnail_path, }
thumbnail_args,
&original_details, let entry = state
hash, .repo
) .variant_waiter(hash.clone(), variant.clone())
.await? .await?;
let opt = generate::wait_timeout(
hash.clone(),
variant.clone(),
entry,
&state,
Duration::from_secs(5),
)
.await?;
if let Some(tuple) = opt {
break tuple;
}
attempts += 1;
}
}; };
if let Some(public_url) = state.store.public_url(&identifier) { if let Some(public_url) = state.store.public_url(&identifier) {
@ -936,9 +943,8 @@ async fn process_head<S: Store + 'static>(
} }
}; };
let (_, thumbnail_path, _) = prepare_process(&state.config, operations, ext.as_str())?; let (_, variant, _) = prepare_process(&state.config, operations, ext.as_str())?;
let path_string = thumbnail_path.to_string_lossy().to_string();
let Some(hash) = state.repo.hash(&alias).await? else { let Some(hash) = state.repo.hash(&alias).await? else {
// Invalid alias // Invalid alias
return Ok(HttpResponse::NotFound().finish()); return Ok(HttpResponse::NotFound().finish());
@ -947,14 +953,11 @@ async fn process_head<S: Store + 'static>(
if !state.config.server.read_only { if !state.config.server.read_only {
state state
.repo .repo
.accessed_variant(hash.clone(), path_string.clone()) .accessed_variant(hash.clone(), variant.clone())
.await?; .await?;
} }
let identifier_opt = state let identifier_opt = state.repo.variant_identifier(hash.clone(), variant).await?;
.repo
.variant_identifier(hash.clone(), path_string)
.await?;
if let Some(identifier) = identifier_opt { if let Some(identifier) = identifier_opt {
let details = ensure_details_identifier(&state, &identifier).await?; let details = ensure_details_identifier(&state, &identifier).await?;
@ -973,7 +976,7 @@ async fn process_head<S: Store + 'static>(
/// Process files /// Process files
#[tracing::instrument(name = "Spawning image process", skip(state))] #[tracing::instrument(name = "Spawning image process", skip(state))]
async fn process_backgrounded<S: Store>( async fn process_backgrounded<S: Store + 'static>(
web::Query(ProcessQuery { source, operations }): web::Query<ProcessQuery>, web::Query(ProcessQuery { source, operations }): web::Query<ProcessQuery>,
ext: web::Path<String>, ext: web::Path<String>,
state: web::Data<State<S>>, state: web::Data<State<S>>,
@ -990,10 +993,9 @@ async fn process_backgrounded<S: Store>(
} }
}; };
let (target_format, process_path, process_args) = let (target_format, variant, variant_args) =
prepare_process(&state.config, operations, ext.as_str())?; prepare_process(&state.config, operations, ext.as_str())?;
let path_string = process_path.to_string_lossy().to_string();
let Some(hash) = state.repo.hash(&source).await? else { let Some(hash) = state.repo.hash(&source).await? else {
// Invalid alias // Invalid alias
return Ok(HttpResponse::BadRequest().finish()); return Ok(HttpResponse::BadRequest().finish());
@ -1001,7 +1003,7 @@ async fn process_backgrounded<S: Store>(
let identifier_opt = state let identifier_opt = state
.repo .repo
.variant_identifier(hash.clone(), path_string) .variant_identifier(hash.clone(), variant.clone())
.await?; .await?;
if identifier_opt.is_some() { if identifier_opt.is_some() {
@ -1012,14 +1014,7 @@ async fn process_backgrounded<S: Store>(
return Err(UploadError::ReadOnly.into()); return Err(UploadError::ReadOnly.into());
} }
queue_generate( queue_generate(&state.repo, target_format, source, variant, variant_args).await?;
&state.repo,
target_format,
source,
process_path,
process_args,
)
.await?;
Ok(HttpResponse::Accepted().finish()) Ok(HttpResponse::Accepted().finish())
} }
@ -1591,14 +1586,12 @@ fn json_config() -> web::JsonConfig {
fn configure_endpoints<S: Store + 'static, F: Fn(&mut web::ServiceConfig)>( fn configure_endpoints<S: Store + 'static, F: Fn(&mut web::ServiceConfig)>(
config: &mut web::ServiceConfig, config: &mut web::ServiceConfig,
state: State<S>, state: State<S>,
process_map: ProcessMap,
extra_config: F, extra_config: F,
) { ) {
config config
.app_data(query_config()) .app_data(query_config())
.app_data(json_config()) .app_data(json_config())
.app_data(web::Data::new(state.clone())) .app_data(web::Data::new(state.clone()))
.app_data(web::Data::new(process_map.clone()))
.route("/healthz", web::get().to(healthz::<S>)) .route("/healthz", web::get().to(healthz::<S>))
.service( .service(
web::scope("/image") web::scope("/image")
@ -1706,12 +1699,12 @@ fn spawn_cleanup<S>(state: State<S>) {
}); });
} }
fn spawn_workers<S>(state: State<S>, process_map: ProcessMap) fn spawn_workers<S>(state: State<S>)
where where
S: Store + 'static, S: Store + 'static,
{ {
crate::sync::spawn("cleanup-worker", queue::process_cleanup(state.clone())); crate::sync::spawn("cleanup-worker", queue::process_cleanup(state.clone()));
crate::sync::spawn("process-worker", queue::process_images(state, process_map)); crate::sync::spawn("process-worker", queue::process_images(state));
} }
fn watch_keys(tls: Tls, sender: ChannelSender) -> DropHandle<()> { fn watch_keys(tls: Tls, sender: ChannelSender) -> DropHandle<()> {
@ -1737,8 +1730,6 @@ async fn launch<
state: State<S>, state: State<S>,
extra_config: F, extra_config: F,
) -> color_eyre::Result<()> { ) -> color_eyre::Result<()> {
let process_map = ProcessMap::new();
let address = state.config.server.address; let address = state.config.server.address;
let tls = Tls::from_config(&state.config); let tls = Tls::from_config(&state.config);
@ -1748,18 +1739,15 @@ async fn launch<
let server = HttpServer::new(move || { let server = HttpServer::new(move || {
let extra_config = extra_config.clone(); let extra_config = extra_config.clone();
let state = state.clone(); let state = state.clone();
let process_map = process_map.clone();
spawn_workers(state.clone(), process_map.clone()); spawn_workers(state.clone());
App::new() App::new()
.wrap(TracingLogger::default()) .wrap(TracingLogger::default())
.wrap(Deadline) .wrap(Deadline)
.wrap(Metrics) .wrap(Metrics)
.wrap(Payload::new()) .wrap(Payload::new())
.configure(move |sc| { .configure(move |sc| configure_endpoints(sc, state.clone(), extra_config))
configure_endpoints(sc, state.clone(), process_map.clone(), extra_config)
})
}); });
if let Some(tls) = tls { if let Some(tls) = tls {

View file

@ -91,7 +91,7 @@ impl ResizeKind {
pub(crate) fn build_chain( pub(crate) fn build_chain(
args: &[(String, String)], args: &[(String, String)],
ext: &str, ext: &str,
) -> Result<(PathBuf, Vec<String>), Error> { ) -> Result<(String, Vec<String>), Error> {
fn parse<P: Processor>(key: &str, value: &str) -> Result<Option<P>, Error> { fn parse<P: Processor>(key: &str, value: &str) -> Result<Option<P>, Error> {
if key == P::NAME { if key == P::NAME {
return Ok(Some(P::parse(key, value).ok_or(UploadError::ParsePath)?)); return Ok(Some(P::parse(key, value).ok_or(UploadError::ParsePath)?));
@ -122,7 +122,7 @@ pub(crate) fn build_chain(
path.push(ext); path.push(ext);
Ok((path, args)) Ok((path.to_string_lossy().to_string(), args))
} }
impl Processor for Identity { impl Processor for Identity {

View file

@ -1,5 +1,4 @@
use crate::{ use crate::{
concurrent_processor::ProcessMap,
error::{Error, UploadError}, error::{Error, UploadError},
formats::InputProcessableFormat, formats::InputProcessableFormat,
future::{LocalBoxFuture, WithPollTimer}, future::{LocalBoxFuture, WithPollTimer},
@ -12,7 +11,6 @@ use crate::{
use std::{ use std::{
ops::Deref, ops::Deref,
path::PathBuf,
sync::Arc, sync::Arc,
time::{Duration, Instant}, time::{Duration, Instant},
}; };
@ -63,7 +61,7 @@ enum Process {
Generate { Generate {
target_format: InputProcessableFormat, target_format: InputProcessableFormat,
source: Serde<Alias>, source: Serde<Alias>,
process_path: PathBuf, process_path: String,
process_args: Vec<String>, process_args: Vec<String>,
}, },
} }
@ -178,13 +176,13 @@ pub(crate) async fn queue_generate(
repo: &ArcRepo, repo: &ArcRepo,
target_format: InputProcessableFormat, target_format: InputProcessableFormat,
source: Alias, source: Alias,
process_path: PathBuf, variant: String,
process_args: Vec<String>, process_args: Vec<String>,
) -> Result<(), Error> { ) -> Result<(), Error> {
let job = serde_json::to_value(Process::Generate { let job = serde_json::to_value(Process::Generate {
target_format, target_format,
source: Serde::new(source), source: Serde::new(source),
process_path, process_path: variant,
process_args, process_args,
}) })
.map_err(UploadError::PushJob)?; .map_err(UploadError::PushJob)?;
@ -196,8 +194,8 @@ pub(crate) async fn process_cleanup<S: Store + 'static>(state: State<S>) {
process_jobs(state, CLEANUP_QUEUE, cleanup::perform).await process_jobs(state, CLEANUP_QUEUE, cleanup::perform).await
} }
pub(crate) async fn process_images<S: Store + 'static>(state: State<S>, process_map: ProcessMap) { pub(crate) async fn process_images<S: Store + 'static>(state: State<S>) {
process_image_jobs(state, process_map, PROCESS_QUEUE, process::perform).await process_jobs(state, PROCESS_QUEUE, process::perform).await
} }
struct MetricsGuard { struct MetricsGuard {
@ -357,7 +355,7 @@ where
let (job_id, job) = state let (job_id, job) = state
.repo .repo
.pop(queue, worker_id) .pop(queue, worker_id)
.with_poll_timer("pop-cleanup") .with_poll_timer("pop-job")
.await?; .await?;
let guard = MetricsGuard::guard(worker_id, queue); let guard = MetricsGuard::guard(worker_id, queue);
@ -369,99 +367,13 @@ where
job_id, job_id,
(callback)(state, job), (callback)(state, job),
) )
.with_poll_timer("cleanup-job-and-heartbeat") .with_poll_timer("job-and-heartbeat")
.await;
state
.repo
.complete_job(queue, worker_id, job_id, job_result(&res))
.with_poll_timer("cleanup-job-complete")
.await?;
res?;
guard.disarm();
Ok(()) as Result<(), Error>
}
.instrument(tracing::info_span!("tick", %queue, %worker_id))
.await?;
}
}
async fn process_image_jobs<S, F>(
state: State<S>,
process_map: ProcessMap,
queue: &'static str,
callback: F,
) where
S: Store,
for<'a> F: Fn(&'a State<S>, &'a ProcessMap, serde_json::Value) -> JobFuture<'a> + Copy,
{
let worker_id = uuid::Uuid::new_v4();
loop {
tracing::trace!("process_image_jobs: looping");
crate::sync::cooperate().await;
let res = image_job_loop(&state, &process_map, worker_id, queue, callback)
.with_poll_timer("image-job-loop")
.await;
if let Err(e) = res {
tracing::warn!("Error processing jobs: {}", format!("{e}"));
tracing::warn!("{}", format!("{e:?}"));
if e.is_disconnected() {
tokio::time::sleep(Duration::from_secs(10)).await;
}
continue;
}
break;
}
}
async fn image_job_loop<S, F>(
state: &State<S>,
process_map: &ProcessMap,
worker_id: uuid::Uuid,
queue: &'static str,
callback: F,
) -> Result<(), Error>
where
S: Store,
for<'a> F: Fn(&'a State<S>, &'a ProcessMap, serde_json::Value) -> JobFuture<'a> + Copy,
{
loop {
tracing::trace!("image_job_loop: looping");
crate::sync::cooperate().await;
async {
let (job_id, job) = state
.repo
.pop(queue, worker_id)
.with_poll_timer("pop-process")
.await?;
let guard = MetricsGuard::guard(worker_id, queue);
let res = heartbeat(
&state.repo,
queue,
worker_id,
job_id,
(callback)(state, process_map, job),
)
.with_poll_timer("process-job-and-heartbeat")
.await; .await;
state state
.repo .repo
.complete_job(queue, worker_id, job_id, job_result(&res)) .complete_job(queue, worker_id, job_id, job_result(&res))
.with_poll_timer("job-complete")
.await?; .await?;
res?; res?;

View file

@ -2,7 +2,6 @@ use time::Instant;
use tracing::{Instrument, Span}; use tracing::{Instrument, Span};
use crate::{ use crate::{
concurrent_processor::ProcessMap,
error::{Error, UploadError}, error::{Error, UploadError},
formats::InputProcessableFormat, formats::InputProcessableFormat,
future::WithPollTimer, future::WithPollTimer,
@ -14,15 +13,11 @@ use crate::{
store::Store, store::Store,
UploadQuery, UploadQuery,
}; };
use std::{path::PathBuf, sync::Arc}; use std::sync::Arc;
use super::{JobContext, JobFuture, JobResult}; use super::{JobContext, JobFuture, JobResult};
pub(super) fn perform<'a, S>( pub(super) fn perform<S>(state: &State<S>, job: serde_json::Value) -> JobFuture<'_>
state: &'a State<S>,
process_map: &'a ProcessMap,
job: serde_json::Value,
) -> JobFuture<'a>
where where
S: Store + 'static, S: Store + 'static,
{ {
@ -58,7 +53,6 @@ where
} => { } => {
generate( generate(
state, state,
process_map,
target_format, target_format,
Serde::into_inner(source), Serde::into_inner(source),
process_path, process_path,
@ -178,13 +172,12 @@ where
Ok(()) Ok(())
} }
#[tracing::instrument(skip(state, process_map, process_path, process_args))] #[tracing::instrument(skip(state, variant, process_args))]
async fn generate<S: Store + 'static>( async fn generate<S: Store + 'static>(
state: &State<S>, state: &State<S>,
process_map: &ProcessMap,
target_format: InputProcessableFormat, target_format: InputProcessableFormat,
source: Alias, source: Alias,
process_path: PathBuf, variant: String,
process_args: Vec<String>, process_args: Vec<String>,
) -> JobResult { ) -> JobResult {
let hash = state let hash = state
@ -195,10 +188,9 @@ async fn generate<S: Store + 'static>(
.ok_or(UploadError::MissingAlias) .ok_or(UploadError::MissingAlias)
.abort()?; .abort()?;
let path_string = process_path.to_string_lossy().to_string();
let identifier_opt = state let identifier_opt = state
.repo .repo
.variant_identifier(hash.clone(), path_string) .variant_identifier(hash.clone(), variant.clone())
.await .await
.retry()?; .retry()?;
@ -211,9 +203,8 @@ async fn generate<S: Store + 'static>(
crate::generate::generate( crate::generate::generate(
state, state,
process_map,
target_format, target_format,
process_path, variant,
process_args, process_args,
&original_details, &original_details,
hash, hash,

View file

@ -3,6 +3,7 @@ mod delete_token;
mod hash; mod hash;
mod metrics; mod metrics;
mod migrate; mod migrate;
mod notification_map;
use crate::{ use crate::{
config, config,
@ -23,6 +24,7 @@ pub(crate) use alias::Alias;
pub(crate) use delete_token::DeleteToken; pub(crate) use delete_token::DeleteToken;
pub(crate) use hash::Hash; pub(crate) use hash::Hash;
pub(crate) use migrate::{migrate_04, migrate_repo}; pub(crate) use migrate::{migrate_04, migrate_repo};
pub(crate) use notification_map::NotificationEntry;
pub(crate) type ArcRepo = Arc<dyn FullRepo>; pub(crate) type ArcRepo = Arc<dyn FullRepo>;
@ -103,6 +105,7 @@ pub(crate) trait FullRepo:
+ AliasRepo + AliasRepo
+ QueueRepo + QueueRepo
+ HashRepo + HashRepo
+ VariantRepo
+ StoreMigrationRepo + StoreMigrationRepo
+ AliasAccessRepo + AliasAccessRepo
+ VariantAccessRepo + VariantAccessRepo
@ -653,20 +656,6 @@ pub(crate) trait HashRepo: BaseRepo {
async fn identifier(&self, hash: Hash) -> Result<Option<Arc<str>>, RepoError>; async fn identifier(&self, hash: Hash) -> Result<Option<Arc<str>>, RepoError>;
async fn relate_variant_identifier(
&self,
hash: Hash,
variant: String,
identifier: &Arc<str>,
) -> Result<Result<(), VariantAlreadyExists>, RepoError>;
async fn variant_identifier(
&self,
hash: Hash,
variant: String,
) -> Result<Option<Arc<str>>, RepoError>;
async fn variants(&self, hash: Hash) -> Result<Vec<(String, Arc<str>)>, RepoError>;
async fn remove_variant(&self, hash: Hash, variant: String) -> Result<(), RepoError>;
async fn relate_blurhash(&self, hash: Hash, blurhash: Arc<str>) -> Result<(), RepoError>; async fn relate_blurhash(&self, hash: Hash, blurhash: Arc<str>) -> Result<(), RepoError>;
async fn blurhash(&self, hash: Hash) -> Result<Option<Arc<str>>, RepoError>; async fn blurhash(&self, hash: Hash) -> Result<Option<Arc<str>>, RepoError>;
@ -726,6 +715,96 @@ where
T::identifier(self, hash).await T::identifier(self, hash).await
} }
async fn relate_blurhash(&self, hash: Hash, blurhash: Arc<str>) -> Result<(), RepoError> {
T::relate_blurhash(self, hash, blurhash).await
}
async fn blurhash(&self, hash: Hash) -> Result<Option<Arc<str>>, RepoError> {
T::blurhash(self, hash).await
}
async fn relate_motion_identifier(
&self,
hash: Hash,
identifier: &Arc<str>,
) -> Result<(), RepoError> {
T::relate_motion_identifier(self, hash, identifier).await
}
async fn motion_identifier(&self, hash: Hash) -> Result<Option<Arc<str>>, RepoError> {
T::motion_identifier(self, hash).await
}
async fn cleanup_hash(&self, hash: Hash) -> Result<(), RepoError> {
T::cleanup_hash(self, hash).await
}
}
#[async_trait::async_trait(?Send)]
pub(crate) trait VariantRepo: BaseRepo {
async fn claim_variant_processing_rights(
&self,
hash: Hash,
variant: String,
) -> Result<Result<(), NotificationEntry>, RepoError>;
async fn variant_waiter(
&self,
hash: Hash,
variant: String,
) -> Result<NotificationEntry, RepoError>;
async fn variant_heartbeat(&self, hash: Hash, variant: String) -> Result<(), RepoError>;
async fn notify_variant(&self, hash: Hash, variant: String) -> Result<(), RepoError>;
async fn relate_variant_identifier(
&self,
hash: Hash,
variant: String,
identifier: &Arc<str>,
) -> Result<Result<(), VariantAlreadyExists>, RepoError>;
async fn variant_identifier(
&self,
hash: Hash,
variant: String,
) -> Result<Option<Arc<str>>, RepoError>;
async fn variants(&self, hash: Hash) -> Result<Vec<(String, Arc<str>)>, RepoError>;
async fn remove_variant(&self, hash: Hash, variant: String) -> Result<(), RepoError>;
}
#[async_trait::async_trait(?Send)]
impl<T> VariantRepo for Arc<T>
where
T: VariantRepo,
{
async fn claim_variant_processing_rights(
&self,
hash: Hash,
variant: String,
) -> Result<Result<(), NotificationEntry>, RepoError> {
T::claim_variant_processing_rights(self, hash, variant).await
}
async fn variant_waiter(
&self,
hash: Hash,
variant: String,
) -> Result<NotificationEntry, RepoError> {
T::variant_waiter(self, hash, variant).await
}
async fn variant_heartbeat(&self, hash: Hash, variant: String) -> Result<(), RepoError> {
T::variant_heartbeat(self, hash, variant).await
}
async fn notify_variant(&self, hash: Hash, variant: String) -> Result<(), RepoError> {
T::notify_variant(self, hash, variant).await
}
async fn relate_variant_identifier( async fn relate_variant_identifier(
&self, &self,
hash: Hash, hash: Hash,
@ -750,30 +829,6 @@ where
async fn remove_variant(&self, hash: Hash, variant: String) -> Result<(), RepoError> { async fn remove_variant(&self, hash: Hash, variant: String) -> Result<(), RepoError> {
T::remove_variant(self, hash, variant).await T::remove_variant(self, hash, variant).await
} }
async fn relate_blurhash(&self, hash: Hash, blurhash: Arc<str>) -> Result<(), RepoError> {
T::relate_blurhash(self, hash, blurhash).await
}
async fn blurhash(&self, hash: Hash) -> Result<Option<Arc<str>>, RepoError> {
T::blurhash(self, hash).await
}
async fn relate_motion_identifier(
&self,
hash: Hash,
identifier: &Arc<str>,
) -> Result<(), RepoError> {
T::relate_motion_identifier(self, hash, identifier).await
}
async fn motion_identifier(&self, hash: Hash) -> Result<Option<Arc<str>>, RepoError> {
T::motion_identifier(self, hash).await
}
async fn cleanup_hash(&self, hash: Hash) -> Result<(), RepoError> {
T::cleanup_hash(self, hash).await
}
} }
#[async_trait::async_trait(?Send)] #[async_trait::async_trait(?Send)]

View file

@ -0,0 +1,94 @@
use dashmap::DashMap;
use std::{
future::Future,
sync::{
atomic::{AtomicBool, Ordering},
Arc, Weak,
},
time::Duration,
};
use tokio::sync::Notify;
use crate::future::WithTimeout;
type Map = Arc<DashMap<Arc<str>, Weak<NotificationEntryInner>>>;
#[derive(Clone)]
pub(super) struct NotificationMap {
map: Map,
}
pub(crate) struct NotificationEntry {
inner: Arc<NotificationEntryInner>,
}
struct NotificationEntryInner {
key: Arc<str>,
map: Map,
notify: Notify,
armed: AtomicBool,
}
impl NotificationMap {
pub(super) fn new() -> Self {
Self {
map: Arc::new(DashMap::new()),
}
}
pub(super) fn register_interest(&self, key: Arc<str>) -> NotificationEntry {
let new_entry = Arc::new(NotificationEntryInner {
key: key.clone(),
map: self.map.clone(),
notify: crate::sync::bare_notify(),
armed: AtomicBool::new(false),
});
let mut key_entry = self
.map
.entry(key)
.or_insert_with(|| Arc::downgrade(&new_entry));
let upgraded_entry = key_entry.value().upgrade();
let inner = if let Some(entry) = upgraded_entry {
entry
} else {
*key_entry.value_mut() = Arc::downgrade(&new_entry);
new_entry
};
inner.armed.store(true, Ordering::Release);
NotificationEntry { inner }
}
pub(super) fn notify(&self, key: &str) {
if let Some(notifier) = self.map.get(key).and_then(|v| v.upgrade()) {
notifier.notify.notify_waiters();
}
}
}
impl NotificationEntry {
pub(crate) fn notified_timeout(
&mut self,
duration: Duration,
) -> impl Future<Output = Result<(), tokio::time::error::Elapsed>> + '_ {
self.inner.notify.notified().with_timeout(duration)
}
}
impl Default for NotificationMap {
fn default() -> Self {
Self::new()
}
}
impl Drop for NotificationEntryInner {
fn drop(&mut self) {
if self.armed.load(Ordering::Acquire) {
self.map.remove(&self.key);
}
}
}

View file

@ -4,6 +4,7 @@ mod schema;
use std::{ use std::{
collections::{BTreeSet, VecDeque}, collections::{BTreeSet, VecDeque},
future::Future,
path::PathBuf, path::PathBuf,
sync::{ sync::{
atomic::{AtomicU64, Ordering}, atomic::{AtomicU64, Ordering},
@ -43,10 +44,11 @@ use self::job_status::JobStatus;
use super::{ use super::{
metrics::{PopMetricsGuard, PushMetricsGuard, WaitMetricsGuard}, metrics::{PopMetricsGuard, PushMetricsGuard, WaitMetricsGuard},
notification_map::{NotificationEntry, NotificationMap},
Alias, AliasAccessRepo, AliasAlreadyExists, AliasRepo, BaseRepo, DeleteToken, DetailsRepo, Alias, AliasAccessRepo, AliasAlreadyExists, AliasRepo, BaseRepo, DeleteToken, DetailsRepo,
FullRepo, Hash, HashAlreadyExists, HashPage, HashRepo, JobId, JobResult, OrderedHash, FullRepo, Hash, HashAlreadyExists, HashPage, HashRepo, JobId, JobResult, OrderedHash,
ProxyRepo, QueueRepo, RepoError, SettingsRepo, StoreMigrationRepo, UploadId, UploadRepo, ProxyRepo, QueueRepo, RepoError, SettingsRepo, StoreMigrationRepo, UploadId, UploadRepo,
UploadResult, VariantAccessRepo, VariantAlreadyExists, UploadResult, VariantAccessRepo, VariantAlreadyExists, VariantRepo,
}; };
#[derive(Clone)] #[derive(Clone)]
@ -62,6 +64,7 @@ struct Inner {
notifier_pool: Pool<AsyncPgConnection>, notifier_pool: Pool<AsyncPgConnection>,
queue_notifications: DashMap<String, Arc<Notify>>, queue_notifications: DashMap<String, Arc<Notify>>,
upload_notifications: DashMap<UploadId, Weak<Notify>>, upload_notifications: DashMap<UploadId, Weak<Notify>>,
keyed_notifications: NotificationMap,
} }
struct UploadInterest { struct UploadInterest {
@ -81,6 +84,10 @@ struct UploadNotifierState<'a> {
inner: &'a Inner, inner: &'a Inner,
} }
struct KeyedNotifierState<'a> {
inner: &'a Inner,
}
#[derive(Debug, thiserror::Error)] #[derive(Debug, thiserror::Error)]
pub(crate) enum ConnectPostgresError { pub(crate) enum ConnectPostgresError {
#[error("Failed to connect to postgres for migrations")] #[error("Failed to connect to postgres for migrations")]
@ -102,7 +109,7 @@ pub(crate) enum PostgresError {
Pool(#[source] RunError), Pool(#[source] RunError),
#[error("Error in database")] #[error("Error in database")]
Diesel(#[source] diesel::result::Error), Diesel(#[from] diesel::result::Error),
#[error("Error deserializing hex value")] #[error("Error deserializing hex value")]
Hex(#[source] hex::FromHexError), Hex(#[source] hex::FromHexError),
@ -331,6 +338,7 @@ impl PostgresRepo {
notifier_pool, notifier_pool,
queue_notifications: DashMap::new(), queue_notifications: DashMap::new(),
upload_notifications: DashMap::new(), upload_notifications: DashMap::new(),
keyed_notifications: NotificationMap::new(),
}); });
let handle = crate::sync::abort_on_drop(crate::sync::spawn_sendable( let handle = crate::sync::abort_on_drop(crate::sync::spawn_sendable(
@ -363,8 +371,97 @@ impl PostgresRepo {
.with_poll_timer("postgres-get-notifier-connection") .with_poll_timer("postgres-get-notifier-connection")
.await .await
} }
async fn insert_keyed_notifier(
&self,
input_key: &str,
) -> Result<Result<(), AlreadyInserted>, PostgresError> {
use schema::keyed_notifications::dsl::*;
let mut conn = self.get_connection().await?;
let timestamp = to_primitive(time::OffsetDateTime::now_utc());
diesel::delete(keyed_notifications)
.filter(heartbeat.le(timestamp.saturating_sub(time::Duration::minutes(2))))
.execute(&mut conn)
.with_timeout(Duration::from_secs(5))
.await
.map_err(|_| PostgresError::DbTimeout)?
.map_err(PostgresError::Diesel)?;
let res = diesel::insert_into(keyed_notifications)
.values(key.eq(input_key))
.execute(&mut conn)
.with_timeout(Duration::from_secs(5))
.await
.map_err(|_| PostgresError::DbTimeout)?;
match res {
Ok(_) => Ok(Ok(())),
Err(diesel::result::Error::DatabaseError(
diesel::result::DatabaseErrorKind::UniqueViolation,
_,
)) => Ok(Err(AlreadyInserted)),
Err(e) => Err(PostgresError::Diesel(e)),
}
}
async fn keyed_notifier_heartbeat(&self, input_key: &str) -> Result<(), PostgresError> {
use schema::keyed_notifications::dsl::*;
let mut conn = self.get_connection().await?;
let timestamp = to_primitive(time::OffsetDateTime::now_utc());
diesel::update(keyed_notifications)
.filter(key.eq(input_key))
.set(heartbeat.eq(timestamp))
.execute(&mut conn)
.with_timeout(Duration::from_secs(5))
.await
.map_err(|_| PostgresError::DbTimeout)?
.map_err(PostgresError::Diesel)?;
Ok(())
}
fn listen_on_key(&self, key: Arc<str>) -> NotificationEntry {
self.inner.keyed_notifications.register_interest(key)
}
async fn register_interest(&self) -> Result<(), PostgresError> {
let mut notifier_conn = self.get_notifier_connection().await?;
diesel::sql_query("LISTEN keyed_notification_channel;")
.execute(&mut notifier_conn)
.with_timeout(Duration::from_secs(5))
.await
.map_err(|_| PostgresError::DbTimeout)?
.map_err(PostgresError::Diesel)?;
Ok(())
}
async fn clear_keyed_notifier(&self, input_key: String) -> Result<(), PostgresError> {
use schema::keyed_notifications::dsl::*;
let mut conn = self.get_connection().await?;
diesel::delete(keyed_notifications)
.filter(key.eq(input_key))
.execute(&mut conn)
.with_timeout(Duration::from_secs(5))
.await
.map_err(|_| PostgresError::DbTimeout)?
.map_err(PostgresError::Diesel)?;
Ok(())
}
} }
struct AlreadyInserted;
struct GetConnectionMetricsGuard { struct GetConnectionMetricsGuard {
start: Instant, start: Instant,
armed: bool, armed: bool,
@ -437,13 +534,15 @@ impl Inner {
} }
impl UploadInterest { impl UploadInterest {
async fn notified_timeout(&self, timeout: Duration) -> Result<(), tokio::time::error::Elapsed> { fn notified_timeout(
&self,
timeout: Duration,
) -> impl Future<Output = Result<(), tokio::time::error::Elapsed>> + '_ {
self.interest self.interest
.as_ref() .as_ref()
.expect("interest exists") .expect("interest exists")
.notified() .notified()
.with_timeout(timeout) .with_timeout(timeout)
.await
} }
} }
@ -511,6 +610,12 @@ impl<'a> UploadNotifierState<'a> {
} }
} }
impl<'a> KeyedNotifierState<'a> {
fn handle(&self, key: &str) {
self.inner.keyed_notifications.notify(key);
}
}
type BoxFuture<'a, T> = std::pin::Pin<Box<dyn std::future::Future<Output = T> + Send + 'a>>; type BoxFuture<'a, T> = std::pin::Pin<Box<dyn std::future::Future<Output = T> + Send + 'a>>;
type ConfigFn = type ConfigFn =
Box<dyn Fn(&str) -> BoxFuture<'_, ConnectionResult<AsyncPgConnection>> + Send + Sync + 'static>; Box<dyn Fn(&str) -> BoxFuture<'_, ConnectionResult<AsyncPgConnection>> + Send + Sync + 'static>;
@ -529,6 +634,8 @@ async fn delegate_notifications(
let upload_notifier_state = UploadNotifierState { inner: &inner }; let upload_notifier_state = UploadNotifierState { inner: &inner };
let keyed_notifier_state = KeyedNotifierState { inner: &inner };
while let Ok(notification) = receiver.recv_async().await { while let Ok(notification) = receiver.recv_async().await {
tracing::trace!("delegate_notifications: looping"); tracing::trace!("delegate_notifications: looping");
metrics::counter!(crate::init_metrics::POSTGRES_NOTIFICATION).increment(1); metrics::counter!(crate::init_metrics::POSTGRES_NOTIFICATION).increment(1);
@ -542,6 +649,10 @@ async fn delegate_notifications(
// new upload finished // new upload finished
upload_notifier_state.handle(notification.payload()); upload_notifier_state.handle(notification.payload());
} }
"keyed_notification_channel" => {
// new keyed notification
keyed_notifier_state.handle(notification.payload());
}
channel => { channel => {
tracing::info!( tracing::info!(
"Unhandled postgres notification: {channel}: {}", "Unhandled postgres notification: {channel}: {}",
@ -863,110 +974,6 @@ impl HashRepo for PostgresRepo {
Ok(opt.map(Arc::from)) Ok(opt.map(Arc::from))
} }
#[tracing::instrument(level = "debug", skip(self))]
async fn relate_variant_identifier(
&self,
input_hash: Hash,
input_variant: String,
input_identifier: &Arc<str>,
) -> Result<Result<(), VariantAlreadyExists>, RepoError> {
use schema::variants::dsl::*;
let mut conn = self.get_connection().await?;
let res = diesel::insert_into(variants)
.values((
hash.eq(&input_hash),
variant.eq(&input_variant),
identifier.eq(input_identifier.as_ref()),
))
.execute(&mut conn)
.with_metrics(crate::init_metrics::POSTGRES_VARIANTS_RELATE_VARIANT_IDENTIFIER)
.with_timeout(Duration::from_secs(5))
.await
.map_err(|_| PostgresError::DbTimeout)?;
match res {
Ok(_) => Ok(Ok(())),
Err(diesel::result::Error::DatabaseError(
diesel::result::DatabaseErrorKind::UniqueViolation,
_,
)) => Ok(Err(VariantAlreadyExists)),
Err(e) => Err(PostgresError::Diesel(e).into()),
}
}
#[tracing::instrument(level = "debug", skip(self))]
async fn variant_identifier(
&self,
input_hash: Hash,
input_variant: String,
) -> Result<Option<Arc<str>>, RepoError> {
use schema::variants::dsl::*;
let mut conn = self.get_connection().await?;
let opt = variants
.select(identifier)
.filter(hash.eq(&input_hash))
.filter(variant.eq(&input_variant))
.get_result::<String>(&mut conn)
.with_metrics(crate::init_metrics::POSTGRES_VARIANTS_IDENTIFIER)
.with_timeout(Duration::from_secs(5))
.await
.map_err(|_| PostgresError::DbTimeout)?
.optional()
.map_err(PostgresError::Diesel)?
.map(Arc::from);
Ok(opt)
}
#[tracing::instrument(level = "debug", skip(self))]
async fn variants(&self, input_hash: Hash) -> Result<Vec<(String, Arc<str>)>, RepoError> {
use schema::variants::dsl::*;
let mut conn = self.get_connection().await?;
let vec = variants
.select((variant, identifier))
.filter(hash.eq(&input_hash))
.get_results::<(String, String)>(&mut conn)
.with_metrics(crate::init_metrics::POSTGRES_VARIANTS_FOR_HASH)
.with_timeout(Duration::from_secs(5))
.await
.map_err(|_| PostgresError::DbTimeout)?
.map_err(PostgresError::Diesel)?
.into_iter()
.map(|(s, i)| (s, Arc::from(i)))
.collect();
Ok(vec)
}
#[tracing::instrument(level = "debug", skip(self))]
async fn remove_variant(
&self,
input_hash: Hash,
input_variant: String,
) -> Result<(), RepoError> {
use schema::variants::dsl::*;
let mut conn = self.get_connection().await?;
diesel::delete(variants)
.filter(hash.eq(&input_hash))
.filter(variant.eq(&input_variant))
.execute(&mut conn)
.with_metrics(crate::init_metrics::POSTGRES_VARIANTS_REMOVE)
.with_timeout(Duration::from_secs(5))
.await
.map_err(|_| PostgresError::DbTimeout)?
.map_err(PostgresError::Diesel)?;
Ok(())
}
#[tracing::instrument(level = "debug", skip(self))] #[tracing::instrument(level = "debug", skip(self))]
async fn relate_blurhash( async fn relate_blurhash(
&self, &self,
@ -1083,6 +1090,167 @@ impl HashRepo for PostgresRepo {
} }
} }
#[async_trait::async_trait(?Send)]
impl VariantRepo for PostgresRepo {
#[tracing::instrument(level = "debug", skip(self))]
async fn claim_variant_processing_rights(
&self,
hash: Hash,
variant: String,
) -> Result<Result<(), NotificationEntry>, RepoError> {
let key = Arc::from(format!("{}{variant}", hash.to_base64()));
let entry = self.listen_on_key(Arc::clone(&key));
self.register_interest().await?;
if self
.variant_identifier(hash.clone(), variant.clone())
.await?
.is_some()
{
return Ok(Err(entry));
}
match self.insert_keyed_notifier(&key).await? {
Ok(()) => Ok(Ok(())),
Err(AlreadyInserted) => Ok(Err(entry)),
}
}
async fn variant_waiter(
&self,
hash: Hash,
variant: String,
) -> Result<NotificationEntry, RepoError> {
let key = Arc::from(format!("{}{variant}", hash.to_base64()));
let entry = self.listen_on_key(key);
self.register_interest().await?;
Ok(entry)
}
#[tracing::instrument(level = "debug", skip(self))]
async fn variant_heartbeat(&self, hash: Hash, variant: String) -> Result<(), RepoError> {
let key = format!("{}{variant}", hash.to_base64());
self.keyed_notifier_heartbeat(&key)
.await
.map_err(Into::into)
}
#[tracing::instrument(level = "trace", skip(self))]
async fn notify_variant(&self, hash: Hash, variant: String) -> Result<(), RepoError> {
let key = format!("{}{variant}", hash.to_base64());
self.clear_keyed_notifier(key).await.map_err(Into::into)
}
#[tracing::instrument(level = "debug", skip(self))]
async fn relate_variant_identifier(
&self,
input_hash: Hash,
input_variant: String,
input_identifier: &Arc<str>,
) -> Result<Result<(), VariantAlreadyExists>, RepoError> {
use schema::variants::dsl::*;
let mut conn = self.get_connection().await?;
let res = diesel::insert_into(variants)
.values((
hash.eq(&input_hash),
variant.eq(&input_variant),
identifier.eq(input_identifier.to_string()),
))
.execute(&mut conn)
.with_metrics(crate::init_metrics::POSTGRES_VARIANTS_RELATE_VARIANT_IDENTIFIER)
.with_timeout(Duration::from_secs(5))
.await
.map_err(|_| PostgresError::DbTimeout)?;
match res {
Ok(_) => Ok(Ok(())),
Err(diesel::result::Error::DatabaseError(
diesel::result::DatabaseErrorKind::UniqueViolation,
_,
)) => Ok(Err(VariantAlreadyExists)),
Err(e) => Err(PostgresError::Diesel(e).into()),
}
}
#[tracing::instrument(level = "debug", skip(self))]
async fn variant_identifier(
&self,
input_hash: Hash,
input_variant: String,
) -> Result<Option<Arc<str>>, RepoError> {
use schema::variants::dsl::*;
let mut conn = self.get_connection().await?;
let opt = variants
.select(identifier)
.filter(hash.eq(&input_hash))
.filter(variant.eq(&input_variant))
.get_result::<String>(&mut conn)
.with_metrics(crate::init_metrics::POSTGRES_VARIANTS_IDENTIFIER)
.with_timeout(Duration::from_secs(5))
.await
.map_err(|_| PostgresError::DbTimeout)?
.optional()
.map_err(PostgresError::Diesel)?
.map(Arc::from);
Ok(opt)
}
#[tracing::instrument(level = "debug", skip(self))]
async fn variants(&self, input_hash: Hash) -> Result<Vec<(String, Arc<str>)>, RepoError> {
use schema::variants::dsl::*;
let mut conn = self.get_connection().await?;
let vec = variants
.select((variant, identifier))
.filter(hash.eq(&input_hash))
.get_results::<(String, String)>(&mut conn)
.with_metrics(crate::init_metrics::POSTGRES_VARIANTS_FOR_HASH)
.with_timeout(Duration::from_secs(5))
.await
.map_err(|_| PostgresError::DbTimeout)?
.map_err(PostgresError::Diesel)?
.into_iter()
.map(|(s, i)| (s, Arc::from(i)))
.collect();
Ok(vec)
}
#[tracing::instrument(level = "debug", skip(self))]
async fn remove_variant(
&self,
input_hash: Hash,
input_variant: String,
) -> Result<(), RepoError> {
use schema::variants::dsl::*;
let mut conn = self.get_connection().await?;
diesel::delete(variants)
.filter(hash.eq(&input_hash))
.filter(variant.eq(&input_variant))
.execute(&mut conn)
.with_metrics(crate::init_metrics::POSTGRES_VARIANTS_REMOVE)
.with_timeout(Duration::from_secs(5))
.await
.map_err(|_| PostgresError::DbTimeout)?
.map_err(PostgresError::Diesel)?;
Ok(())
}
}
#[async_trait::async_trait(?Send)] #[async_trait::async_trait(?Send)]
impl AliasRepo for PostgresRepo { impl AliasRepo for PostgresRepo {
#[tracing::instrument(level = "debug", skip(self))] #[tracing::instrument(level = "debug", skip(self))]
@ -1279,16 +1447,22 @@ impl DetailsRepo for PostgresRepo {
let value = let value =
serde_json::to_value(&input_details.inner).map_err(PostgresError::SerializeDetails)?; serde_json::to_value(&input_details.inner).map_err(PostgresError::SerializeDetails)?;
diesel::insert_into(details) let res = diesel::insert_into(details)
.values((identifier.eq(input_identifier.as_ref()), json.eq(&value))) .values((identifier.eq(input_identifier.as_ref()), json.eq(&value)))
.execute(&mut conn) .execute(&mut conn)
.with_metrics(crate::init_metrics::POSTGRES_DETAILS_RELATE) .with_metrics(crate::init_metrics::POSTGRES_DETAILS_RELATE)
.with_timeout(Duration::from_secs(5)) .with_timeout(Duration::from_secs(5))
.await .await
.map_err(|_| PostgresError::DbTimeout)? .map_err(|_| PostgresError::DbTimeout)?;
.map_err(PostgresError::Diesel)?;
Ok(()) match res {
Ok(_)
| Err(diesel::result::Error::DatabaseError(
diesel::result::DatabaseErrorKind::UniqueViolation,
_,
)) => Ok(()),
Err(e) => Err(PostgresError::Diesel(e).into()),
}
} }
#[tracing::instrument(level = "debug", skip(self))] #[tracing::instrument(level = "debug", skip(self))]

View file

@ -0,0 +1,50 @@
use barrel::backend::Pg;
use barrel::functions::AutogenFunction;
use barrel::{types, Migration};
pub(crate) fn migration() -> String {
let mut m = Migration::new();
m.create_table("keyed_notifications", |t| {
t.add_column(
"key",
types::text().primary(true).unique(true).nullable(false),
);
t.add_column(
"heartbeat",
types::datetime()
.nullable(false)
.default(AutogenFunction::CurrentTimestamp),
);
t.add_index(
"keyed_notifications_heartbeat_index",
types::index(["heartbeat"]),
);
});
m.inject_custom(
r#"
CREATE OR REPLACE FUNCTION keyed_notify()
RETURNS trigger AS
$$
BEGIN
PERFORM pg_notify('keyed_notification_channel', OLD.key);
RETURN NEW;
END;
$$ LANGUAGE plpgsql;
"#
.trim(),
);
m.inject_custom(
r#"
CREATE TRIGGER keyed_notification_removed
AFTER DELETE
ON keyed_notifications
FOR EACH ROW
EXECUTE PROCEDURE keyed_notify();
"#,
);
m.make::<Pg>().to_string()
}

View file

@ -48,6 +48,13 @@ diesel::table! {
} }
} }
diesel::table! {
keyed_notifications (key) {
key -> Text,
heartbeat -> Timestamp,
}
}
diesel::table! { diesel::table! {
proxies (url) { proxies (url) {
url -> Text, url -> Text,
@ -109,6 +116,7 @@ diesel::allow_tables_to_appear_in_same_query!(
details, details,
hashes, hashes,
job_queue, job_queue,
keyed_notifications,
proxies, proxies,
refinery_schema_history, refinery_schema_history,
settings, settings,

View file

@ -5,6 +5,7 @@ use crate::{
serde_str::Serde, serde_str::Serde,
stream::{from_iterator, LocalBoxStream}, stream::{from_iterator, LocalBoxStream},
}; };
use dashmap::DashMap;
use sled::{transaction::TransactionError, Db, IVec, Transactional, Tree}; use sled::{transaction::TransactionError, Db, IVec, Transactional, Tree};
use std::{ use std::{
collections::HashMap, collections::HashMap,
@ -22,10 +23,11 @@ use uuid::Uuid;
use super::{ use super::{
hash::Hash, hash::Hash,
metrics::{PopMetricsGuard, PushMetricsGuard, WaitMetricsGuard}, metrics::{PopMetricsGuard, PushMetricsGuard, WaitMetricsGuard},
notification_map::{NotificationEntry, NotificationMap},
Alias, AliasAccessRepo, AliasAlreadyExists, AliasRepo, BaseRepo, DeleteToken, Details, Alias, AliasAccessRepo, AliasAlreadyExists, AliasRepo, BaseRepo, DeleteToken, Details,
DetailsRepo, FullRepo, HashAlreadyExists, HashPage, HashRepo, JobId, JobResult, OrderedHash, DetailsRepo, FullRepo, HashAlreadyExists, HashPage, HashRepo, JobId, JobResult, OrderedHash,
ProxyRepo, QueueRepo, RepoError, SettingsRepo, StoreMigrationRepo, UploadId, UploadRepo, ProxyRepo, QueueRepo, RepoError, SettingsRepo, StoreMigrationRepo, UploadId, UploadRepo,
UploadResult, VariantAccessRepo, VariantAlreadyExists, UploadResult, VariantAccessRepo, VariantAlreadyExists, VariantRepo,
}; };
macro_rules! b { macro_rules! b {
@ -113,6 +115,8 @@ pub(crate) struct SledRepo {
migration_identifiers: Tree, migration_identifiers: Tree,
cache_capacity: u64, cache_capacity: u64,
export_path: PathBuf, export_path: PathBuf,
variant_process_map: DashMap<(Hash, String), time::OffsetDateTime>,
notifications: NotificationMap,
db: Db, db: Db,
} }
@ -156,6 +160,8 @@ impl SledRepo {
migration_identifiers: db.open_tree("pict-rs-migration-identifiers-tree")?, migration_identifiers: db.open_tree("pict-rs-migration-identifiers-tree")?,
cache_capacity, cache_capacity,
export_path, export_path,
variant_process_map: DashMap::new(),
notifications: NotificationMap::new(),
db, db,
}) })
} }
@ -1331,88 +1337,6 @@ impl HashRepo for SledRepo {
Ok(opt.map(try_into_arc_str).transpose()?) Ok(opt.map(try_into_arc_str).transpose()?)
} }
#[tracing::instrument(level = "trace", skip(self))]
async fn relate_variant_identifier(
&self,
hash: Hash,
variant: String,
identifier: &Arc<str>,
) -> Result<Result<(), VariantAlreadyExists>, RepoError> {
let hash = hash.to_bytes();
let key = variant_key(&hash, &variant);
let value = identifier.clone();
let hash_variant_identifiers = self.hash_variant_identifiers.clone();
crate::sync::spawn_blocking("sled-io", move || {
hash_variant_identifiers
.compare_and_swap(key, Option::<&[u8]>::None, Some(value.as_bytes()))
.map(|res| res.map_err(|_| VariantAlreadyExists))
})
.await
.map_err(|_| RepoError::Canceled)?
.map_err(SledError::from)
.map_err(RepoError::from)
}
#[tracing::instrument(level = "trace", skip(self))]
async fn variant_identifier(
&self,
hash: Hash,
variant: String,
) -> Result<Option<Arc<str>>, RepoError> {
let hash = hash.to_bytes();
let key = variant_key(&hash, &variant);
let opt = b!(
self.hash_variant_identifiers,
hash_variant_identifiers.get(key)
);
Ok(opt.map(try_into_arc_str).transpose()?)
}
#[tracing::instrument(level = "debug", skip(self))]
async fn variants(&self, hash: Hash) -> Result<Vec<(String, Arc<str>)>, RepoError> {
let hash = hash.to_ivec();
let vec = b!(
self.hash_variant_identifiers,
Ok(hash_variant_identifiers
.scan_prefix(hash.clone())
.filter_map(|res| res.ok())
.filter_map(|(key, ivec)| {
let identifier = try_into_arc_str(ivec).ok();
let variant = variant_from_key(&hash, &key);
if variant.is_none() {
tracing::warn!("Skipping a variant: {}", String::from_utf8_lossy(&key));
}
Some((variant?, identifier?))
})
.collect::<Vec<_>>()) as Result<Vec<_>, SledError>
);
Ok(vec)
}
#[tracing::instrument(level = "trace", skip(self))]
async fn remove_variant(&self, hash: Hash, variant: String) -> Result<(), RepoError> {
let hash = hash.to_bytes();
let key = variant_key(&hash, &variant);
b!(
self.hash_variant_identifiers,
hash_variant_identifiers.remove(key)
);
Ok(())
}
#[tracing::instrument(level = "trace", skip(self))] #[tracing::instrument(level = "trace", skip(self))]
async fn relate_blurhash(&self, hash: Hash, blurhash: Arc<str>) -> Result<(), RepoError> { async fn relate_blurhash(&self, hash: Hash, blurhash: Arc<str>) -> Result<(), RepoError> {
b!( b!(
@ -1528,6 +1452,167 @@ impl HashRepo for SledRepo {
} }
} }
#[async_trait::async_trait(?Send)]
impl VariantRepo for SledRepo {
#[tracing::instrument(level = "trace", skip(self))]
async fn claim_variant_processing_rights(
&self,
hash: Hash,
variant: String,
) -> Result<Result<(), NotificationEntry>, RepoError> {
let key = (hash.clone(), variant.clone());
let now = time::OffsetDateTime::now_utc();
let entry = self
.notifications
.register_interest(Arc::from(format!("{}{variant}", hash.to_base64())));
match self.variant_process_map.entry(key.clone()) {
dashmap::mapref::entry::Entry::Occupied(mut occupied_entry) => {
if occupied_entry
.get()
.saturating_add(time::Duration::minutes(2))
> now
{
return Ok(Err(entry));
}
occupied_entry.insert(now);
}
dashmap::mapref::entry::Entry::Vacant(vacant_entry) => {
vacant_entry.insert(now);
}
}
if self.variant_identifier(hash, variant).await?.is_some() {
self.variant_process_map.remove(&key);
return Ok(Err(entry));
}
Ok(Ok(()))
}
async fn variant_waiter(
&self,
hash: Hash,
variant: String,
) -> Result<NotificationEntry, RepoError> {
let entry = self
.notifications
.register_interest(Arc::from(format!("{}{variant}", hash.to_base64())));
Ok(entry)
}
#[tracing::instrument(level = "trace", skip(self))]
async fn variant_heartbeat(&self, hash: Hash, variant: String) -> Result<(), RepoError> {
let key = (hash, variant);
let now = time::OffsetDateTime::now_utc();
if let dashmap::mapref::entry::Entry::Occupied(mut occupied_entry) =
self.variant_process_map.entry(key)
{
occupied_entry.insert(now);
}
Ok(())
}
#[tracing::instrument(level = "trace", skip(self))]
async fn notify_variant(&self, hash: Hash, variant: String) -> Result<(), RepoError> {
let key = (hash.clone(), variant.clone());
self.variant_process_map.remove(&key);
let key = format!("{}{variant}", hash.to_base64());
self.notifications.notify(&key);
Ok(())
}
#[tracing::instrument(level = "trace", skip(self))]
async fn relate_variant_identifier(
&self,
hash: Hash,
variant: String,
identifier: &Arc<str>,
) -> Result<Result<(), VariantAlreadyExists>, RepoError> {
let hash = hash.to_bytes();
let key = variant_key(&hash, &variant);
let value = identifier.clone();
let hash_variant_identifiers = self.hash_variant_identifiers.clone();
let out = crate::sync::spawn_blocking("sled-io", move || {
hash_variant_identifiers
.compare_and_swap(key, Option::<&[u8]>::None, Some(value.as_bytes()))
.map(|res| res.map_err(|_| VariantAlreadyExists))
})
.await
.map_err(|_| RepoError::Canceled)?
.map_err(SledError::from)
.map_err(RepoError::from)?;
Ok(out)
}
#[tracing::instrument(level = "trace", skip(self))]
async fn variant_identifier(
&self,
hash: Hash,
variant: String,
) -> Result<Option<Arc<str>>, RepoError> {
let hash = hash.to_bytes();
let key = variant_key(&hash, &variant);
let opt = b!(
self.hash_variant_identifiers,
hash_variant_identifiers.get(key)
);
Ok(opt.map(try_into_arc_str).transpose()?)
}
#[tracing::instrument(level = "debug", skip(self))]
async fn variants(&self, hash: Hash) -> Result<Vec<(String, Arc<str>)>, RepoError> {
let hash = hash.to_ivec();
let vec = b!(
self.hash_variant_identifiers,
Ok(hash_variant_identifiers
.scan_prefix(hash.clone())
.filter_map(|res| res.ok())
.filter_map(|(key, ivec)| {
let identifier = try_into_arc_str(ivec).ok();
let variant = variant_from_key(&hash, &key);
if variant.is_none() {
tracing::warn!("Skipping a variant: {}", String::from_utf8_lossy(&key));
}
Some((variant?, identifier?))
})
.collect::<Vec<_>>()) as Result<Vec<_>, SledError>
);
Ok(vec)
}
#[tracing::instrument(level = "trace", skip(self))]
async fn remove_variant(&self, hash: Hash, variant: String) -> Result<(), RepoError> {
let hash = hash.to_bytes();
let key = variant_key(&hash, &variant);
b!(
self.hash_variant_identifiers,
hash_variant_identifiers.remove(key)
);
Ok(())
}
}
fn hash_alias_key(hash: &IVec, alias: &IVec) -> Vec<u8> { fn hash_alias_key(hash: &IVec, alias: &IVec) -> Vec<u8> {
let mut v = hash.to_vec(); let mut v = hash.to_vec();
v.extend_from_slice(alias); v.extend_from_slice(alias);