diff --git a/whisper_streaming/cli.py b/whisper_streaming/cli.py index d0fca4c..632caef 100644 --- a/whisper_streaming/cli.py +++ b/whisper_streaming/cli.py @@ -23,6 +23,7 @@ def transcribe_from_mic( sd_device: Optional[Union[int, str]], num_block: int, ) -> None: + logger.debug(f"WhisperConfig: {config}") wsp = WhisperStreamingTranscriber(config=config) q = queue.Queue() @@ -83,6 +84,14 @@ def get_opts() -> argparse.Namespace: default=20, help="Number of operation unit. Larger values can improve accuracy but consume more memory.", ) + parser.add_argument( + "--temperature", + "-t", + type=float, + action="append", + default=[], + ) + parser.add_argument( "--mic", ) @@ -103,6 +112,9 @@ def main() -> None: if opts.beam_size <= 0: opts.beam_size = None + if len(opts.temperature) == 0: + opts.temperature = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0] + try: opts.mic = int(opts.mic) except Exception: @@ -113,6 +125,7 @@ def main() -> None: language=opts.language, device=opts.device, beam_size=opts.beam_size, + temperatures=opts.temperature, ) transcribe_from_mic( config=config, diff --git a/whisper_streaming/schema.py b/whisper_streaming/schema.py index 3dc1b1c..f7fedf8 100644 --- a/whisper_streaming/schema.py +++ b/whisper_streaming/schema.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 -from typing import List, Optional, Tuple +from typing import List, Optional from pydantic import BaseModel @@ -10,8 +10,8 @@ class WhisperConfig(BaseModel): device: str language: str + temperatures: List[float] fp16: bool = True - temperatures: Tuple[float, ...] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0) compression_ratio_threshold: Optional[float] = 2.4 logprob_threshold: Optional[float] = -1.0 no_captions_threshold: Optional[float] = 0.6