diff --git a/whisper_streaming/__init__.py b/whisper_streaming/__init__.py
index eb8bcb0..d246d29 100644
--- a/whisper_streaming/__init__.py
+++ b/whisper_streaming/__init__.py
@@ -25,6 +25,7 @@ class WhisperConfig(BaseModel):
     logprob_threshold: Optional[float] = -1.0
     no_captions_threshold: Optional[float] = 0.6
     best_of: int = 5
+    beam_size: Optional[int] = None
 
 
 class WhisperStreamingTranscriber:
@@ -42,7 +43,7 @@ class WhisperStreamingTranscriber:
         self,
         *,
         t,
-        beam_size,
+        beam_size: Optional[int],
         patience,
         best_of,
     ) -> DecodingOptions:
@@ -81,7 +82,7 @@ class WhisperStreamingTranscriber:
 
         _decode_options: DecodingOptions = self._get_decoding_options(
             t=t,
-            beam_size=None,
+            beam_size=self.config.beam_size,
             patience=0.0,
             best_of=None,
         )
@@ -100,7 +101,7 @@ class WhisperStreamingTranscriber:
                 t=t,
                 beam_size=None,
                 patience=None,
-                best_of=None,
+                best_of=self.config.best_of,
             )
             retries: List[DecodingResult] = self.model.decode(
                 segment[needs_fallback], _decode_options  # type: ignore
             )
@@ -158,16 +159,25 @@ def get_opts() -> argparse.Namespace:
         default="cuda" if torch.cuda.is_available() else "cpu",
         help="device to use for PyTorch inference",
     )
+    parser.add_argument(
+        "--beam_size",
+        "-b",
+        type=int,
+        default=5,
+    )
     return parser.parse_args()
 
 
 def main() -> None:
     opts = get_opts()
+    if opts.beam_size <= 0:
+        opts.beam_size = None
     config = WhisperConfig(
         model_name=opts.model,
        language=opts.language,
         device=opts.device,
+        beam_size=opts.beam_size,
     )
     transcribe_from_mic(config=config)
 