Refactoring

This commit is contained in:
Yuta Hayashibe 2022-10-06 20:49:41 +09:00
parent 21d3ab6945
commit 58e374dd07
3 changed files with 21 additions and 39 deletions

View file

@ -206,7 +206,7 @@ def get_opts() -> argparse.Namespace:
def get_wshiper(*, opts) -> WhisperStreamingTranscriber:
config = WhisperConfig(
model_name=opts.model,
language=opts.language,
language=opts.language if opts.language == MULTI_LANGUAGE else None,
device=opts.device,
)

View file

@ -13,7 +13,7 @@ MULTI_LANGUAGE: Final[str] = "multi"
class WhisperConfig(BaseModel):
model_name: str
device: str
language: str
language: Optional[str]
fp16: bool = True
@root_validator

View file

@ -18,7 +18,7 @@ from whisper.decoding import DecodingOptions, DecodingResult
from whisper.tokenizer import get_tokenizer
from whisper.utils import exact_div
from whispering.schema import MULTI_LANGUAGE, Context, ParsedChunk, WhisperConfig
from whispering.schema import Context, ParsedChunk, WhisperConfig
from whispering.vad import VAD
logger = getLogger(__name__)
@ -43,7 +43,7 @@ class WhisperStreamingTranscriber:
self.model: Final[Whisper] = load_model(config.model_name, device=config.device)
self.tokenizer = get_tokenizer(
self.model.is_multilingual,
language=config.language if config.language != MULTI_LANGUAGE else None,
language=config.language,
task="transcribe",
)
self._set_dtype(config.fp16)
@ -65,41 +65,23 @@ class WhisperStreamingTranscriber:
patience: Optional[float],
best_of: Optional[int],
) -> DecodingOptions:
if self.config.language != MULTI_LANGUAGE:
return DecodingOptions(
task="transcribe",
language=self.config.language,
temperature=t,
sample_len=None,
best_of=best_of,
beam_size=beam_size,
patience=patience,
length_penalty=None,
prompt=prompt,
prefix=None,
suppress_blank=True,
suppress_tokens="-1",
without_timestamps=False,
max_initial_timestamp=1.0,
fp16=self.fp16,
)
else:
return DecodingOptions(
task="transcribe",
temperature=t,
sample_len=None,
best_of=best_of,
beam_size=beam_size,
patience=patience,
length_penalty=None,
prompt=prompt,
prefix=None,
suppress_blank=True,
suppress_tokens="-1",
without_timestamps=False,
max_initial_timestamp=1.0,
fp16=self.fp16,
)
return DecodingOptions(
task="transcribe",
language=self.config.language,
temperature=t,
sample_len=None,
best_of=best_of,
beam_size=beam_size,
patience=patience,
length_penalty=None,
prompt=prompt,
prefix=None,
suppress_blank=True,
suppress_tokens="-1",
without_timestamps=False,
max_initial_timestamp=1.0,
fp16=self.fp16,
)
def _decode_with_fallback(
self,