whispering/12
2022-09-23 19:20:29 +09:00

65 lines
1.9 KiB
Text

diff --git a/whisper_streaming/__init__.py b/whisper_streaming/__init__.py
index eb8bcb0..d246d29 100644
--- a/whisper_streaming/__init__.py
+++ b/whisper_streaming/__init__.py
@@ -25,6 +25,7 @@ class WhisperConfig(BaseModel):
logprob_threshold: Optional[float] = -1.0
no_captions_threshold: Optional[float] = 0.6
best_of: int = 5
+ beam_size: Optional[int] = None
class WhisperStreamingTranscriber:
@@ -42,7 +43,7 @@ class WhisperStreamingTranscriber:
self,
*,
t,
- beam_size,
+ beam_size: Optional[int],
patience,
best_of,
) -> DecodingOptions:
@@ -81,7 +82,7 @@ class WhisperStreamingTranscriber:
_decode_options: DecodingOptions = self._get_decoding_options(
t=t,
- beam_size=None,
+ beam_size=self.config.beam_size,
patience=0.0,
best_of=None,
)
@@ -100,7 +101,7 @@ class WhisperStreamingTranscriber:
t=t,
beam_size=None,
patience=None,
- best_of=None,
+ best_of=self.config.best_of,
)
retries: List[DecodingResult] = self.model.decode(
segment[needs_fallback], _decode_options # type: ignore
@@ -158,16 +159,25 @@ def get_opts() -> argparse.Namespace:
default="cuda" if torch.cuda.is_available() else "cpu",
help="device to use for PyTorch inference",
)
+ parser.add_argument(
+ "--beam_size",
+ "-b",
+ type=int,
+ default=5,
+ )
return parser.parse_args()
def main() -> None:
opts = get_opts()
+ if opts.beam_size <= 0:
+ opts.beam_size = None
config = WhisperConfig(
model_name=opts.model,
language=opts.language,
device=opts.device,
+ beam_size=opts.beam_size,
)
transcribe_from_mic(config=config)