diff --git a/README.md b/README.md index 042e410..cd646a5 100644 --- a/README.md +++ b/README.md @@ -34,6 +34,7 @@ whispering --language en --model tiny - ``--no-progress`` disables the progress message - ``-t`` sets temperatures to decode. You can set several like ``-t 0.0 -t 0.1 -t 0.5``, but too many temperatures exhaust decoding time - ``--debug`` outputs logs for debug +- ``--no-vad`` disables VAD (Voice Activity Detection). This forces Whisper to analyze audio even during periods without voice activity ### Parse interval diff --git a/whispering/cli.py b/whispering/cli.py index fcbf9e4..6e9eca4 100644 --- a/whispering/cli.py +++ b/whispering/cli.py @@ -155,6 +155,10 @@ def get_opts() -> argparse.Namespace: "--no-progress", action="store_true", ) + parser.add_argument( + "--no-vad", + action="store_true", + ) opts = parser.parse_args() if opts.beam_size <= 0: @@ -187,6 +191,7 @@ def get_context(*, opts) -> Context: beam_size=opts.beam_size, temperatures=opts.temperature, allow_padding=opts.allow_padding, + vad=not opts.no_vad, ) logger.debug(f"Context: {ctx}") return ctx diff --git a/whispering/schema.py b/whispering/schema.py index d31d4df..14ef78a 100644 --- a/whispering/schema.py +++ b/whispering/schema.py @@ -27,6 +27,7 @@ class Context(BaseModel, arbitrary_types_allowed=True): timestamp: float = 0.0 buffer_tokens: List[torch.Tensor] = [] buffer_mel: Optional[torch.Tensor] = None + vad: bool = True temperatures: List[float] allow_padding: bool = False diff --git a/whispering/transcriber.py b/whispering/transcriber.py index a390b44..be3039c 100644 --- a/whispering/transcriber.py +++ b/whispering/transcriber.py @@ -233,17 +233,19 @@ class WhisperStreamingTranscriber: ctx: Context, ) -> Iterator[ParsedChunk]: logger.debug(f"{len(audio)}") - x = [ - v - for v in self.vad( - audio=audio, - total_block_number=1, - ) - ] - if len(x) == 0: # No speech - logger.debug("No speech") - ctx.timestamp += len(audio) / N_FRAMES * self.duration_pre_one_mel - return + + if ctx.vad: + x = 
[ + v + for v in self.vad( + audio=audio, + total_block_number=1, + ) + ] + if len(x) == 0: # No speech + logger.debug("No speech") + ctx.timestamp += len(audio) / N_FRAMES * self.duration_pre_one_mel + return new_mel = log_mel_spectrogram(audio=audio) logger.debug(f"Incoming new_mel.shape: {new_mel.shape}")