diff --git a/README.md b/README.md index 6b046de8..02762928 100644 --- a/README.md +++ b/README.md @@ -31,7 +31,7 @@ You will find the following plugins in this repository: * `audio` - `audiofx`: Plugins to apply audio effects to a stream (such as adding - echo/reverb, or normalization). + echo/reverb, [removing noise](https://jmvalin.ca/demo/rnnoise/) or normalization). - `claxon`: A FLAC decoder based on the [Claxon](https://github.com/ruuda/claxon) library. diff --git a/audio/audiofx/Cargo.toml b/audio/audiofx/Cargo.toml index b70d7d29..601cd48f 100644 --- a/audio/audiofx/Cargo.toml +++ b/audio/audiofx/Cargo.toml @@ -16,6 +16,7 @@ byte-slice-cast = "0.3" num-traits = "0.2" lazy_static = "1.0" ebur128 = "0.1" +nnnoiseless = "0.2" [lib] name = "gstrsaudiofx" diff --git a/audio/audiofx/src/audiornnoise.rs b/audio/audiofx/src/audiornnoise.rs new file mode 100644 index 00000000..9da5cb33 --- /dev/null +++ b/audio/audiofx/src/audiornnoise.rs @@ -0,0 +1,382 @@ +// Copyright (C) 2020 Philippe Normand +// Copyright (C) 2020 Natanael Mojica +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +use byte_slice_cast::*; +use glib::subclass; +use glib::subclass::prelude::*; +use gst::prelude::*; +use gst::subclass::prelude::*; +use gst_base::subclass::base_transform::BaseTransformImplExt; +use gst_base::subclass::base_transform::GenerateOutputSuccess; +use gst_base::subclass::prelude::*; +use nnnoiseless::DenoiseState; +use std::sync::Mutex; + +lazy_static! { + static ref CAT: gst::DebugCategory = gst::DebugCategory::new( + "audiornnoise", + gst::DebugColorFlags::empty(), + Some("Rust Audio Denoise Filter"), + ); +} + +const FRAME_SIZE: usize = DenoiseState::FRAME_SIZE; + +struct ChannelDenoiser { + denoiser: Box, + frame_chunk: Box<[f32; FRAME_SIZE]>, + out_chunk: Box<[f32; FRAME_SIZE]>, +} + +struct State { + in_info: gst_audio::AudioInfo, + denoisers: Vec, + adapter: gst_base::UniqueAdapter, +} + +struct AudioRNNoise { + state: Mutex>, +} + +impl State { + // The following three functions are copied from the csound filter. + fn buffer_duration(&self, buffer_size: u64) -> gst::ClockTime { + let samples = buffer_size / self.in_info.bpf() as u64; + self.samples_to_time(samples) + } + + fn samples_to_time(&self, samples: u64) -> gst::ClockTime { + gst::ClockTime(samples.mul_div_round(gst::SECOND_VAL, self.in_info.rate() as u64)) + } + + fn get_current_pts(&self) -> gst::ClockTime { + // get the last seen pts and the amount of bytes + // since then + let (prev_pts, distance) = self.adapter.prev_pts(); + + // Use the distance to get the amount of samples + // and with it calculate the time-offset which + // can be added to the prev_pts to get the + // pts at the beginning of the adapter. + let samples = distance / self.in_info.bpf() as u64; + prev_pts + self.samples_to_time(samples) + } + + fn needs_more_data(&self) -> bool { + self.adapter.available() < (FRAME_SIZE * self.in_info.bpf() as usize) + } + + fn process(&mut self, input_plane: &[f32], output_plane: &mut [f32]) { + let channels = self.in_info.channels() as usize; + let size = FRAME_SIZE * channels; + + for (out_frame, in_frame) in output_plane.chunks_mut(size).zip(input_plane.chunks(size)) { + for (index, item) in in_frame.iter().enumerate() { + let channel_index = index % channels; + let channel_denoiser = &mut self.denoisers[channel_index]; + let pos = index / channels; + channel_denoiser.frame_chunk[pos] = *item; + } + + for i in (in_frame.len() / channels)..(size / channels) { + for c in 0..channels { + let channel_denoiser = &mut self.denoisers[c]; + channel_denoiser.frame_chunk[i] = 0.0; + } + } + + // FIXME: The first chunks coming out of the denoisers contains some + // fade-in artifacts. We might want to discard those. + for channel_denoiser in &mut self.denoisers { + channel_denoiser.denoiser.process_frame( + &mut channel_denoiser.out_chunk[..], + &channel_denoiser.frame_chunk[..], + ); + } + + for (index, item) in out_frame.iter_mut().enumerate() { + let channel_index = index % channels; + let channel_denoiser = &self.denoisers[channel_index]; + let pos = index / channels; + *item = channel_denoiser.out_chunk[pos]; + } + } + } +} + +impl AudioRNNoise { + fn drain(&self, element: &gst_base::BaseTransform) -> Result { + let mut state_lock = self.state.lock().unwrap(); + let state = state_lock.as_mut().unwrap(); + + let available = state.adapter.available(); + if available == 0 { + return Ok(gst::FlowSuccess::Ok); + } + + let mut buffer = gst::Buffer::with_size(available).map_err(|e| { + gst_error!( + CAT, + obj: element, + "Failed to allocate buffer at EOS {:?}", + e + ); + gst::FlowError::Flushing + })?; + + let duration = state.buffer_duration(available as _); + let pts = state.get_current_pts(); + + { + let ibuffer = state.adapter.take_buffer(available).unwrap(); + let in_map = ibuffer.map_readable().map_err(|_| gst::FlowError::Error)?; + let in_data = in_map.as_slice_of::().unwrap(); + + let buffer = buffer.get_mut().unwrap(); + buffer.set_duration(duration); + buffer.set_pts(pts); + + let mut out_map = buffer.map_writable().map_err(|_| gst::FlowError::Error)?; + let mut out_data = out_map.as_mut_slice_of::().unwrap(); + + state.process(in_data, &mut out_data); + } + + let srcpad = element.get_static_pad("src").unwrap(); + srcpad.push(buffer) + } + + fn generate_output( + &self, + _element: &gst_base::BaseTransform, + state: &mut State, + ) -> Result { + let available = state.adapter.available(); + let bpf = state.in_info.bpf() as usize; + let output_size = available - (available % (FRAME_SIZE * bpf)); + let duration = state.buffer_duration(output_size as _); + let pts = state.get_current_pts(); + + let mut buffer = gst::Buffer::with_size(output_size).map_err(|_| gst::FlowError::Error)?; + + { + let ibuffer = state + .adapter + .take_buffer(output_size) + .map_err(|_| gst::FlowError::Error)?; + let in_map = ibuffer.map_readable().map_err(|_| gst::FlowError::Error)?; + let in_data = in_map.as_slice_of::().unwrap(); + + let buffer = buffer.get_mut().unwrap(); + buffer.set_duration(duration); + buffer.set_pts(pts); + + let mut out_map = buffer.map_writable().map_err(|_| gst::FlowError::Error)?; + let mut out_data = out_map.as_mut_slice_of::().unwrap(); + + state.process(in_data, &mut out_data); + } + + Ok(GenerateOutputSuccess::Buffer(buffer)) + } +} + +impl ObjectSubclass for AudioRNNoise { + const NAME: &'static str = "AudioRNNoise"; + type ParentType = gst_base::BaseTransform; + type Instance = gst::subclass::ElementInstanceStruct; + type Class = subclass::simple::ClassStruct; + + glib_object_subclass!(); + + fn new() -> Self { + Self { + state: Mutex::new(None), + } + } + + fn class_init(klass: &mut subclass::simple::ClassStruct) { + klass.set_metadata( + "Audio denoise", + "Filter/Effect/Audio", + "Removes noise from an audio stream", + "Philippe Normand ", + ); + + let caps = gst::Caps::new_simple( + "audio/x-raw", + &[ + ("format", &gst_audio::AUDIO_FORMAT_F32.to_str()), + ("rate", &48000), + ("channels", &gst::IntRange::::new(1, std::i32::MAX)), + ("layout", &"interleaved"), + ], + ); + let src_pad_template = gst::PadTemplate::new( + "src", + gst::PadDirection::Src, + gst::PadPresence::Always, + &caps, + ) + .unwrap(); + klass.add_pad_template(src_pad_template); + + let sink_pad_template = gst::PadTemplate::new( + "sink", + gst::PadDirection::Sink, + gst::PadPresence::Always, + &caps, + ) + .unwrap(); + klass.add_pad_template(sink_pad_template); + + klass.configure( + gst_base::subclass::BaseTransformMode::NeverInPlace, + false, + false, + ); + } +} + +impl ObjectImpl for AudioRNNoise {} +impl ElementImpl for AudioRNNoise {} + +impl BaseTransformImpl for AudioRNNoise { + fn set_caps( + &self, + element: &gst_base::BaseTransform, + incaps: &gst::Caps, + outcaps: &gst::Caps, + ) -> Result<(), gst::LoggableError> { + // Flush previous state + if self.state.lock().unwrap().is_some() { + self.drain(element).map_err(|e| { + gst_loggable_error!(CAT, "Error flusing previous state data {:?}", e) + })?; + } + if incaps != outcaps { + return Err(gst_loggable_error!( + CAT, + "Input and output caps are not the same" + )); + } + + gst_debug!(CAT, obj: element, "Set caps to {}", incaps); + + let in_info = gst_audio::AudioInfo::from_caps(incaps) + .map_err(|e| gst_loggable_error!(CAT, "Failed to parse input caps {:?}", e))?; + + let mut denoisers = vec![]; + for _i in 0..in_info.channels() { + denoisers.push(ChannelDenoiser { + denoiser: DenoiseState::new(), + frame_chunk: Box::new([0.0; FRAME_SIZE]), + out_chunk: Box::new([0.0; FRAME_SIZE]), + }) + } + + let mut state_lock = self.state.lock().unwrap(); + *state_lock = Some(State { + in_info, + denoisers, + adapter: gst_base::UniqueAdapter::new(), + }); + + Ok(()) + } + + fn generate_output( + &self, + element: &gst_base::BaseTransform, + ) -> Result { + // Check if there are enough data in the queued buffer and adapter, + // if it is not the case, just notify the parent class to not generate + // an output + if let Some(buffer) = self.take_queued_buffer() { + if buffer.get_flags() == gst::BufferFlags::DISCONT { + self.drain(element)?; + } + + let mut state_guard = self.state.lock().unwrap(); + let state = state_guard.as_mut().ok_or_else(|| { + gst_element_error!( + element, + gst::CoreError::Negotiation, + ["Can not generate an output without State"] + ); + gst::FlowError::NotNegotiated + })?; + + state.adapter.push(buffer); + if !state.needs_more_data() { + return self.generate_output(element, state); + } + } + Ok(GenerateOutputSuccess::NoOutput) + } + + fn sink_event(&self, element: &gst_base::BaseTransform, event: gst::Event) -> bool { + use gst::EventView; + + if let EventView::Eos(_) = event.view() { + gst_debug!(CAT, obj: element, "Handling EOS"); + if self.drain(element).is_err() { + return false; + } + } + self.parent_sink_event(element, event) + } + + fn query( + &self, + element: &gst_base::BaseTransform, + direction: gst::PadDirection, + query: &mut gst::QueryRef, + ) -> bool { + if direction == gst::PadDirection::Src { + if let gst::QueryView::Latency(ref mut q) = query.view_mut() { + let sink_pad = element.get_static_pad("sink").expect("Sink pad not found"); + let mut upstream_query = gst::query::Latency::new(); + if sink_pad.peer_query(&mut upstream_query) { + let (live, mut min, mut max) = upstream_query.get_result(); + gst_debug!( + CAT, + obj: element, + "Peer latency: live {} min {} max {}", + live, + min, + max + ); + + min += gst::ClockTime::from_seconds((FRAME_SIZE / 48000) as u64); + max += gst::ClockTime::from_seconds((FRAME_SIZE / 48000) as u64); + q.set(live, min, max); + return true; + } + } + } + BaseTransformImplExt::parent_query(self, element, direction, query) + } + + fn stop(&self, _element: &gst_base::BaseTransform) -> Result<(), gst::ErrorMessage> { + // Drop state + let _ = self.state.lock().unwrap().take(); + + Ok(()) + } +} + +pub fn register(plugin: &gst::Plugin) -> Result<(), glib::BoolError> { + gst::Element::register( + Some(plugin), + "audiornnoise", + gst::Rank::None, + AudioRNNoise::get_type(), + ) +} diff --git a/audio/audiofx/src/lib.rs b/audio/audiofx/src/lib.rs index ea3e01b8..a47062c4 100644 --- a/audio/audiofx/src/lib.rs +++ b/audio/audiofx/src/lib.rs @@ -20,9 +20,11 @@ extern crate lazy_static; mod audioecho; mod audioloudnorm; +mod audiornnoise; fn plugin_init(plugin: &gst::Plugin) -> Result<(), glib::BoolError> { audioecho::register(plugin)?; + audiornnoise::register(plugin)?; audioloudnorm::register(plugin) } diff --git a/audio/audiofx/tests/audiornnoise.rs b/audio/audiofx/tests/audiornnoise.rs new file mode 100644 index 00000000..5ff28789 --- /dev/null +++ b/audio/audiofx/tests/audiornnoise.rs @@ -0,0 +1,85 @@ +// Copyright (C) 2020 Philippe Normand +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +extern crate gstreamer as gst; +extern crate gstreamer_app as gst_app; +extern crate gstreamer_audio as gst_audio; +extern crate gstreamer_check as gst_check; + +use byte_slice_cast::*; + +fn init() { + use std::sync::Once; + static INIT: Once = Once::new(); + + INIT.call_once(|| { + gst::init().unwrap(); + gstrsaudiofx::plugin_register_static().expect("Failed to register rsaudiofx plugin"); + }); +} + +#[test] +fn test_rnnoise_silence_big_buffers() { + init(); + let audio_info = gst_audio::AudioInfo::builder(gst_audio::AUDIO_FORMAT_F32, 48000, 2) + .build() + .unwrap(); + test_rnnoise(&audio_info, 4096); +} + +#[test] +fn test_rnnoise_silence_small_buffers() { + init(); + let audio_info = gst_audio::AudioInfo::builder(gst_audio::AUDIO_FORMAT_F32, 48000, 2) + .build() + .unwrap(); + test_rnnoise(&audio_info, 1024); +} + +fn test_rnnoise(audio_info: &gst_audio::AudioInfo, buffer_size: usize) { + let filter = gst::ElementFactory::make("audiornnoise", None).unwrap(); + let mut h = gst_check::Harness::with_element(&filter, Some("sink"), Some("src")); + let sink_caps = audio_info.to_caps().unwrap(); + let src_caps = sink_caps.clone(); + h.set_caps(src_caps, sink_caps); + h.play(); + + let buffer = { + let mut buffer = gst::Buffer::with_size(buffer_size * audio_info.bpf() as usize).unwrap(); + { + let format_info = audio_info.format_info(); + let buffer_mut = buffer.get_mut().unwrap(); + let mut omap = buffer_mut.map_writable().unwrap(); + let odata = omap.as_mut_slice_of::().unwrap(); + format_info.fill_silence(odata); + } + + buffer + }; + + let num_buffers = 5; + let mut total_processed = 0; + for _ in 0..num_buffers { + let result = h.push_and_pull(buffer.clone()).unwrap(); + let map = result.into_mapped_buffer_readable().unwrap(); + let output = map.as_slice().as_slice_of::().unwrap(); + + // all samples in the output buffers must value 0 + assert_eq!(output.iter().any(|sample| *sample as u16 != 0u16), false); + total_processed += output.len(); + } + h.push_event(gst::event::Eos::new()); + + let last_buffer = h.pull().unwrap(); + let map = last_buffer.into_mapped_buffer_readable().unwrap(); + let output = map.as_slice().as_slice_of::().unwrap(); + total_processed += output.len(); + // The total amount of samples pushed into the element shall be equal to the + // total amount of samples pulled from it. + assert_eq!(total_processed, num_buffers * buffer_size); +}