mirror of
https://github.com/shirayu/whispering.git
synced 2024-06-13 02:39:23 +00:00
Refactoring
This commit is contained in:
parent
21d3ab6945
commit
58e374dd07
|
@ -206,7 +206,7 @@ def get_opts() -> argparse.Namespace:
|
|||
def get_wshiper(*, opts) -> WhisperStreamingTranscriber:
|
||||
config = WhisperConfig(
|
||||
model_name=opts.model,
|
||||
language=opts.language,
|
||||
language=opts.language if opts.language == MULTI_LANGUAGE else None,
|
||||
device=opts.device,
|
||||
)
|
||||
|
||||
|
|
|
@ -13,7 +13,7 @@ MULTI_LANGUAGE: Final[str] = "multi"
|
|||
class WhisperConfig(BaseModel):
|
||||
model_name: str
|
||||
device: str
|
||||
language: str
|
||||
language: Optional[str]
|
||||
fp16: bool = True
|
||||
|
||||
@root_validator
|
||||
|
|
|
@ -18,7 +18,7 @@ from whisper.decoding import DecodingOptions, DecodingResult
|
|||
from whisper.tokenizer import get_tokenizer
|
||||
from whisper.utils import exact_div
|
||||
|
||||
from whispering.schema import MULTI_LANGUAGE, Context, ParsedChunk, WhisperConfig
|
||||
from whispering.schema import Context, ParsedChunk, WhisperConfig
|
||||
from whispering.vad import VAD
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
@ -43,7 +43,7 @@ class WhisperStreamingTranscriber:
|
|||
self.model: Final[Whisper] = load_model(config.model_name, device=config.device)
|
||||
self.tokenizer = get_tokenizer(
|
||||
self.model.is_multilingual,
|
||||
language=config.language if config.language != MULTI_LANGUAGE else None,
|
||||
language=config.language,
|
||||
task="transcribe",
|
||||
)
|
||||
self._set_dtype(config.fp16)
|
||||
|
@ -65,41 +65,23 @@ class WhisperStreamingTranscriber:
|
|||
patience: Optional[float],
|
||||
best_of: Optional[int],
|
||||
) -> DecodingOptions:
|
||||
if self.config.language != MULTI_LANGUAGE:
|
||||
return DecodingOptions(
|
||||
task="transcribe",
|
||||
language=self.config.language,
|
||||
temperature=t,
|
||||
sample_len=None,
|
||||
best_of=best_of,
|
||||
beam_size=beam_size,
|
||||
patience=patience,
|
||||
length_penalty=None,
|
||||
prompt=prompt,
|
||||
prefix=None,
|
||||
suppress_blank=True,
|
||||
suppress_tokens="-1",
|
||||
without_timestamps=False,
|
||||
max_initial_timestamp=1.0,
|
||||
fp16=self.fp16,
|
||||
)
|
||||
else:
|
||||
return DecodingOptions(
|
||||
task="transcribe",
|
||||
temperature=t,
|
||||
sample_len=None,
|
||||
best_of=best_of,
|
||||
beam_size=beam_size,
|
||||
patience=patience,
|
||||
length_penalty=None,
|
||||
prompt=prompt,
|
||||
prefix=None,
|
||||
suppress_blank=True,
|
||||
suppress_tokens="-1",
|
||||
without_timestamps=False,
|
||||
max_initial_timestamp=1.0,
|
||||
fp16=self.fp16,
|
||||
)
|
||||
return DecodingOptions(
|
||||
task="transcribe",
|
||||
language=self.config.language,
|
||||
temperature=t,
|
||||
sample_len=None,
|
||||
best_of=best_of,
|
||||
beam_size=beam_size,
|
||||
patience=patience,
|
||||
length_penalty=None,
|
||||
prompt=prompt,
|
||||
prefix=None,
|
||||
suppress_blank=True,
|
||||
suppress_tokens="-1",
|
||||
without_timestamps=False,
|
||||
max_initial_timestamp=1.0,
|
||||
fp16=self.fp16,
|
||||
)
|
||||
|
||||
def _decode_with_fallback(
|
||||
self,
|
||||
|
|
Loading…
Reference in a new issue