mirror of https://github.com/shirayu/whispering.git
synced 2024-11-25 18:31:00 +00:00
65 lines
1.9 KiB
Text
diff --git a/whisper_streaming/__init__.py b/whisper_streaming/__init__.py
index eb8bcb0..d246d29 100644
--- a/whisper_streaming/__init__.py
+++ b/whisper_streaming/__init__.py
@@ -25,6 +25,7 @@ class WhisperConfig(BaseModel):
     logprob_threshold: Optional[float] = -1.0
     no_captions_threshold: Optional[float] = 0.6
     best_of: int = 5
+    beam_size: Optional[int] = None
 
 
 class WhisperStreamingTranscriber:
@@ -42,7 +43,7 @@ class WhisperStreamingTranscriber:
         self,
         *,
         t,
-        beam_size,
+        beam_size: Optional[int],
         patience,
         best_of,
     ) -> DecodingOptions:
@@ -81,7 +82,7 @@ class WhisperStreamingTranscriber:
         _decode_options: DecodingOptions = self._get_decoding_options(
             t=t,
-            beam_size=None,
+            beam_size=self.config.beam_size,
             patience=0.0,
             best_of=None,
         )
@@ -100,7 +101,7 @@ class WhisperStreamingTranscriber:
             t=t,
             beam_size=None,
             patience=None,
-            best_of=None,
+            best_of=self.config.best_of,
         )
         retries: List[DecodingResult] = self.model.decode(
             segment[needs_fallback], _decode_options  # type: ignore
@@ -158,16 +159,25 @@ def get_opts() -> argparse.Namespace:
         default="cuda" if torch.cuda.is_available() else "cpu",
         help="device to use for PyTorch inference",
     )
+    parser.add_argument(
+        "--beam_size",
+        "-b",
+        type=int,
+        default=5,
+    )
 
     return parser.parse_args()
 
 
 def main() -> None:
     opts = get_opts()
+    if opts.beam_size <= 0:
+        opts.beam_size = None
     config = WhisperConfig(
         model_name=opts.model,
         language=opts.language,
         device=opts.device,
+        beam_size=opts.beam_size,
     )
     transcribe_from_mic(config=config)