From a9452cde76d2b484e74b72c5ecd5630afebd057d Mon Sep 17 00:00:00 2001
From: AlexandraRamassamy <79447930+AlexandraRamassamy@users.noreply.github.com>
Date: Thu, 6 Oct 2022 13:53:23 +0200
Subject: [PATCH] Multi language feature (#20, Resolve #19)

Add "multi" to languages for multi language transcribing.

When --language is "multi", WhisperConfig receives language=None (the
field becomes Optional[str]); any other choice is passed through unchanged.
---
 whispering/cli.py    | 7 ++++---
 whispering/schema.py | 6 ++++--
 2 files changed, 8 insertions(+), 5 deletions(-)

diff --git a/whispering/cli.py b/whispering/cli.py
index f804f1e..548412b 100644
--- a/whispering/cli.py
+++ b/whispering/cli.py
@@ -16,7 +16,7 @@ from whisper.audio import N_FRAMES, SAMPLE_RATE
 from whisper.tokenizer import LANGUAGES, TO_LANGUAGE_CODE
 
 from whispering.pbar import ProgressBar
-from whispering.schema import Context, StdoutWriter, WhisperConfig
+from whispering.schema import MULTI_LANGUAGE, Context, StdoutWriter, WhisperConfig
 from whispering.serve import serve_with_websocket
 from whispering.transcriber import WhisperStreamingTranscriber
 from whispering.websocket_client import run_websocket_client
@@ -106,7 +106,8 @@ def get_opts() -> argparse.Namespace:
         "--language",
         type=str,
         default=None,
-        choices=sorted(LANGUAGES.keys())
+        choices=[MULTI_LANGUAGE]
+        + sorted(LANGUAGES.keys())
         + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]),
     )
     group_model.add_argument(
@@ -205,7 +206,7 @@ def get_opts() -> argparse.Namespace:
 def get_wshiper(*, opts) -> WhisperStreamingTranscriber:
     config = WhisperConfig(
         model_name=opts.model,
-        language=opts.language,
+        language=opts.language if opts.language != MULTI_LANGUAGE else None,
         device=opts.device,
     )
 
diff --git a/whispering/schema.py b/whispering/schema.py
index 46acbfb..39b2318 100644
--- a/whispering/schema.py
+++ b/whispering/schema.py
@@ -1,17 +1,19 @@
 #!/usr/bin/env python3
 
 import sys
-from typing import List, Optional
+from typing import Final, List, Optional
 
 import numpy as np
 import torch
 from pydantic import BaseModel, root_validator
 
+MULTI_LANGUAGE: Final[str] = "multi"
+
 
 class WhisperConfig(BaseModel):
     model_name: str
     device: str
-    language: str
+    language: Optional[str]
     fp16: bool = True
 
     @root_validator