Multilanguage PR

This commit is contained in:
Alexandra Ramassamy 2022-10-04 12:16:51 +02:00
parent de0a864f79
commit 40354c4d97
5 changed files with 63 additions and 33 deletions

View file

@ -13,12 +13,18 @@ import torch
from whisper import available_models
from whisper.audio import N_FRAMES, SAMPLE_RATE
from whisper.tokenizer import LANGUAGES, TO_LANGUAGE_CODE
"""
from whispering.pbar import ProgressBar
from whispering.schema import Context, WhisperConfig
from whispering.serve import serve_with_websocket
from whispering.transcriber import WhisperStreamingTranscriber
from whispering.websocket_client import run_websocket_client
"""
from pbar import ProgressBar
from schema import Context, WhisperConfig
from serve import serve_with_websocket
from transcriber import WhisperStreamingTranscriber
from websocket_client import run_websocket_client
logger = getLogger(__name__)
@ -105,7 +111,7 @@ def get_opts() -> argparse.Namespace:
"--language",
type=str,
default=None,
choices=sorted(LANGUAGES.keys())
choices=["multilanguage"] + sorted(LANGUAGES.keys())
+ sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]),
)
group_model.add_argument(

View file

@ -9,7 +9,7 @@ import numpy as np
import websockets
from websockets.exceptions import ConnectionClosedOK
from whispering.transcriber import Context, WhisperStreamingTranscriber
from transcriber import Context, WhisperStreamingTranscriber
logger = getLogger(__name__)
@ -18,7 +18,6 @@ async def serve_with_websocket_main(websocket):
global g_wsp
idx: int = 0
ctx: Optional[Context] = None
while True:
logger.debug(f"Audio #: {idx}")
try:
@ -78,7 +77,7 @@ async def serve_with_websocket(
serve_with_websocket_main,
host=host,
port=port,
max_size=999999999,
max_size=999999999
):
await asyncio.Future()
except KeyboardInterrupt:

View file

@ -18,8 +18,8 @@ from whisper.decoding import DecodingOptions, DecodingResult
from whisper.tokenizer import get_tokenizer
from whisper.utils import exact_div
from whispering.schema import Context, ParsedChunk, WhisperConfig
from whispering.vad import VAD
from schema import Context, ParsedChunk, WhisperConfig
from vad import VAD
logger = getLogger(__name__)
@ -41,11 +41,19 @@ class WhisperStreamingTranscriber:
def __init__(self, *, config: WhisperConfig):
self.config: Final[WhisperConfig] = config
self.model: Final[Whisper] = load_model(config.model_name, device=config.device)
self.tokenizer = get_tokenizer(
self.model.is_multilingual,
language=config.language,
task="transcribe",
)
# Specific language: pass config.language through to the tokenizer
if config.language != "multilanguage":
self.tokenizer = get_tokenizer(
self.model.is_multilingual,
language=config.language,
task="transcribe",
)
# Multilanguage transcripts: build the tokenizer without a fixed language
else:
self.tokenizer = get_tokenizer(
self.model.is_multilingual,
task="transcribe",
)
self._set_dtype(config.fp16)
self.input_stride: Final[int] = exact_div(
N_FRAMES, self.model.dims.n_audio_ctx
@ -65,23 +73,41 @@ class WhisperStreamingTranscriber:
patience: Optional[float],
best_of: Optional[int],
) -> DecodingOptions:
return DecodingOptions(
task="transcribe",
language=self.config.language,
temperature=t,
sample_len=None,
best_of=best_of,
beam_size=beam_size,
patience=patience,
length_penalty=None,
prompt=prompt,
prefix=None,
suppress_blank=True,
suppress_tokens="-1",
without_timestamps=False,
max_initial_timestamp=1.0,
fp16=self.fp16,
)
if self.config.language != "multilanguage":
return DecodingOptions(
task="transcribe",
language=self.config.language,
temperature=t,
sample_len=None,
best_of=best_of,
beam_size=beam_size,
patience=patience,
length_penalty=None,
prompt=prompt,
prefix=None,
suppress_blank=True,
suppress_tokens="-1",
without_timestamps=False,
max_initial_timestamp=1.0,
fp16=self.fp16,
)
else:
return DecodingOptions(
task="transcribe",
temperature=t,
sample_len=None,
best_of=best_of,
beam_size=beam_size,
patience=patience,
length_penalty=None,
prompt=prompt,
prefix=None,
suppress_blank=True,
suppress_tokens="-1",
without_timestamps=False,
max_initial_timestamp=1.0,
fp16=self.fp16,
)
def _decode_with_fallback(
self,
@ -106,7 +132,6 @@ class WhisperStreamingTranscriber:
_decode_options,
) # type: ignore
assert decode_result is not None
needs_fallback: bool = False
if (
ctx.compression_ratio_threshold is not None

View file

@ -6,7 +6,7 @@ import numpy as np
import torch
from whisper.audio import N_FRAMES, SAMPLE_RATE
from whispering.schema import SpeechSegment
from schema import SpeechSegment
class VAD:

View file

@ -9,8 +9,8 @@ import sounddevice as sd
import websockets
from whisper.audio import N_FRAMES, SAMPLE_RATE
from whispering.schema import ParsedChunk
from whispering.transcriber import Context
from schema import ParsedChunk
from transcriber import Context
logger = getLogger(__name__)