mirror of
https://github.com/shirayu/whispering.git
synced 2024-06-13 10:49:22 +00:00
Multilanguage PR
This commit is contained in:
parent
de0a864f79
commit
40354c4d97
|
@ -13,12 +13,18 @@ import torch
|
|||
from whisper import available_models
|
||||
from whisper.audio import N_FRAMES, SAMPLE_RATE
|
||||
from whisper.tokenizer import LANGUAGES, TO_LANGUAGE_CODE
|
||||
|
||||
"""
|
||||
from whispering.pbar import ProgressBar
|
||||
from whispering.schema import Context, WhisperConfig
|
||||
from whispering.serve import serve_with_websocket
|
||||
from whispering.transcriber import WhisperStreamingTranscriber
|
||||
from whispering.websocket_client import run_websocket_client
|
||||
"""
|
||||
from pbar import ProgressBar
|
||||
from schema import Context, WhisperConfig
|
||||
from serve import serve_with_websocket
|
||||
from transcriber import WhisperStreamingTranscriber
|
||||
from websocket_client import run_websocket_client
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
@ -105,7 +111,7 @@ def get_opts() -> argparse.Namespace:
|
|||
"--language",
|
||||
type=str,
|
||||
default=None,
|
||||
choices=sorted(LANGUAGES.keys())
|
||||
choices=["multilanguage"] + sorted(LANGUAGES.keys())
|
||||
+ sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]),
|
||||
)
|
||||
group_model.add_argument(
|
||||
|
|
|
@ -9,7 +9,7 @@ import numpy as np
|
|||
import websockets
|
||||
from websockets.exceptions import ConnectionClosedOK
|
||||
|
||||
from whispering.transcriber import Context, WhisperStreamingTranscriber
|
||||
from transcriber import Context, WhisperStreamingTranscriber
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
@ -18,7 +18,6 @@ async def serve_with_websocket_main(websocket):
|
|||
global g_wsp
|
||||
idx: int = 0
|
||||
ctx: Optional[Context] = None
|
||||
|
||||
while True:
|
||||
logger.debug(f"Audio #: {idx}")
|
||||
try:
|
||||
|
@ -78,7 +77,7 @@ async def serve_with_websocket(
|
|||
serve_with_websocket_main,
|
||||
host=host,
|
||||
port=port,
|
||||
max_size=999999999,
|
||||
max_size=999999999
|
||||
):
|
||||
await asyncio.Future()
|
||||
except KeyboardInterrupt:
|
||||
|
|
|
@ -18,8 +18,8 @@ from whisper.decoding import DecodingOptions, DecodingResult
|
|||
from whisper.tokenizer import get_tokenizer
|
||||
from whisper.utils import exact_div
|
||||
|
||||
from whispering.schema import Context, ParsedChunk, WhisperConfig
|
||||
from whispering.vad import VAD
|
||||
from schema import Context, ParsedChunk, WhisperConfig
|
||||
from vad import VAD
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
@ -41,11 +41,19 @@ class WhisperStreamingTranscriber:
|
|||
def __init__(self, *, config: WhisperConfig):
|
||||
self.config: Final[WhisperConfig] = config
|
||||
self.model: Final[Whisper] = load_model(config.model_name, device=config.device)
|
||||
self.tokenizer = get_tokenizer(
|
||||
self.model.is_multilingual,
|
||||
language=config.language,
|
||||
task="transcribe",
|
||||
)
|
||||
# language specified
|
||||
if config.language != "multilanguage":
|
||||
self.tokenizer = get_tokenizer(
|
||||
self.model.is_multilingual,
|
||||
language=config.language,
|
||||
task="transcribe",
|
||||
)
|
||||
# Mulilanguage transcripts
|
||||
else:
|
||||
self.tokenizer = get_tokenizer(
|
||||
self.model.is_multilingual,
|
||||
task="transcribe",
|
||||
)
|
||||
self._set_dtype(config.fp16)
|
||||
self.input_stride: Final[int] = exact_div(
|
||||
N_FRAMES, self.model.dims.n_audio_ctx
|
||||
|
@ -65,23 +73,41 @@ class WhisperStreamingTranscriber:
|
|||
patience: Optional[float],
|
||||
best_of: Optional[int],
|
||||
) -> DecodingOptions:
|
||||
return DecodingOptions(
|
||||
task="transcribe",
|
||||
language=self.config.language,
|
||||
temperature=t,
|
||||
sample_len=None,
|
||||
best_of=best_of,
|
||||
beam_size=beam_size,
|
||||
patience=patience,
|
||||
length_penalty=None,
|
||||
prompt=prompt,
|
||||
prefix=None,
|
||||
suppress_blank=True,
|
||||
suppress_tokens="-1",
|
||||
without_timestamps=False,
|
||||
max_initial_timestamp=1.0,
|
||||
fp16=self.fp16,
|
||||
)
|
||||
if self.config.language != "multilanguage":
|
||||
return DecodingOptions(
|
||||
task="transcribe",
|
||||
language=self.config.language,
|
||||
temperature=t,
|
||||
sample_len=None,
|
||||
best_of=best_of,
|
||||
beam_size=beam_size,
|
||||
patience=patience,
|
||||
length_penalty=None,
|
||||
prompt=prompt,
|
||||
prefix=None,
|
||||
suppress_blank=True,
|
||||
suppress_tokens="-1",
|
||||
without_timestamps=False,
|
||||
max_initial_timestamp=1.0,
|
||||
fp16=self.fp16,
|
||||
)
|
||||
else:
|
||||
return DecodingOptions(
|
||||
task="transcribe",
|
||||
temperature=t,
|
||||
sample_len=None,
|
||||
best_of=best_of,
|
||||
beam_size=beam_size,
|
||||
patience=patience,
|
||||
length_penalty=None,
|
||||
prompt=prompt,
|
||||
prefix=None,
|
||||
suppress_blank=True,
|
||||
suppress_tokens="-1",
|
||||
without_timestamps=False,
|
||||
max_initial_timestamp=1.0,
|
||||
fp16=self.fp16,
|
||||
)
|
||||
|
||||
def _decode_with_fallback(
|
||||
self,
|
||||
|
@ -106,7 +132,6 @@ class WhisperStreamingTranscriber:
|
|||
_decode_options,
|
||||
) # type: ignore
|
||||
assert decode_result is not None
|
||||
|
||||
needs_fallback: bool = False
|
||||
if (
|
||||
ctx.compression_ratio_threshold is not None
|
||||
|
|
|
@ -6,7 +6,7 @@ import numpy as np
|
|||
import torch
|
||||
from whisper.audio import N_FRAMES, SAMPLE_RATE
|
||||
|
||||
from whispering.schema import SpeechSegment
|
||||
from schema import SpeechSegment
|
||||
|
||||
|
||||
class VAD:
|
||||
|
|
|
@ -9,8 +9,8 @@ import sounddevice as sd
|
|||
import websockets
|
||||
from whisper.audio import N_FRAMES, SAMPLE_RATE
|
||||
|
||||
from whispering.schema import ParsedChunk
|
||||
from whispering.transcriber import Context
|
||||
from schema import ParsedChunk
|
||||
from transcriber import Context
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
|
Loading…
Reference in a new issue