diff --git a/whisper_streaming/cli.py b/whisper_streaming/cli.py index 6464102..e4ad406 100644 --- a/whisper_streaming/cli.py +++ b/whisper_streaming/cli.py @@ -109,6 +109,10 @@ def get_opts() -> argparse.Namespace: type=int, help="Port number of websocker server", ) + parser.add_argument( + "--allow-padding", + action="store_true", + ) return parser.parse_args() @@ -137,6 +141,7 @@ def main() -> None: device=opts.device, beam_size=opts.beam_size, temperatures=opts.temperature, + allow_padding=opts.allow_padding, ) logger.debug(f"WhisperConfig: {config}") diff --git a/whisper_streaming/schema.py b/whisper_streaming/schema.py index f7fedf8..1bcbc5e 100644 --- a/whisper_streaming/schema.py +++ b/whisper_streaming/schema.py @@ -10,6 +10,7 @@ class WhisperConfig(BaseModel): device: str language: str + allow_padding: bool = False temperatures: List[float] fp16: bool = True compression_ratio_threshold: Optional[float] = 2.4 diff --git a/whisper_streaming/transcriber.py b/whisper_streaming/transcriber.py index c1114a5..7325712 100644 --- a/whisper_streaming/transcriber.py +++ b/whisper_streaming/transcriber.py @@ -248,7 +248,7 @@ class WhisperStreamingTranscriber: .to(self.model.device) # type: ignore .to(self.dtype) ) - if segment.shape[-1] > mel.shape[-1]: + if not self.config.allow_padding and segment.shape[-1] > mel.shape[-1]: logger.warning("Padding is not expected while speaking") logger.debug( @@ -290,7 +290,7 @@ class WhisperStreamingTranscriber: seek += last_timestamp_position * self.input_stride logger.debug(f"new seek={seek}, mel.shape: {mel.shape}") - if mel.shape[-1] - seek < N_FRAMES: + if (not self.config.allow_padding) and (mel.shape[-1] - seek < N_FRAMES): break if mel.shape[-1] - seek <= 0: