mirror of
https://github.com/shirayu/whispering.git
synced 2024-11-10 18:51:08 +00:00
112 lines
3.7 KiB
Python
112 lines
3.7 KiB
Python
|
#!/usr/bin/env python3
|
||
|
|
||
|
from typing import List, Optional
|
||
|
|
||
|
import numpy as np
|
||
|
import torch
|
||
|
from whisper import Whisper, load_model
|
||
|
from whisper.audio import N_FRAMES, log_mel_spectrogram, pad_or_trim
|
||
|
from whisper.decoding import DecodingOptions, DecodingResult
|
||
|
from whisper.tokenizer import get_tokenizer
|
||
|
|
||
|
from whisper_streaming.schema import WhisperConfig
|
||
|
|
||
|
|
||
|
class WhisperStreamingTranscriber:
    """Decode individual audio segments with a Whisper model.

    The model and tokenizer are loaded once in ``__init__``. ``transcribe``
    then converts one audio segment into at most one ``DecodingResult``,
    retrying at the other configured temperatures when the first decode
    fails the configured quality thresholds.
    """

    def __init__(self, *, config: WhisperConfig):
        """Load the model and tokenizer described by ``config``.

        Args:
            config: Settings object. This class reads ``model_name``,
                ``device``, ``language``, ``temperatures``, ``beam_size``,
                ``best_of``, ``compression_ratio_threshold``,
                ``logprob_threshold`` and ``no_speech_threshold`` from it.
        """
        self.config: WhisperConfig = config
        self.model: Whisper = load_model(config.model_name, device=config.device)
        self.tokenizer = get_tokenizer(
            self.model.is_multilingual,
            language=config.language,
            task="transcribe",
        )
        # Half precision is only used off-CPU: upstream Whisper itself falls
        # back to FP32 on CPU because FP16 inference is not supported there.
        self.dtype = (
            torch.float32 if self.model.device.type == "cpu" else torch.float16
        )

    def _get_decoding_options(
        self,
        *,
        t,
        beam_size: Optional[int],
        patience: float,
        best_of: Optional[int],
    ) -> DecodingOptions:
        """Build the ``DecodingOptions`` for one decoding attempt.

        Args:
            t: Sampling temperature for this attempt.
            beam_size: Beam size, or ``None`` to decode without beam search.
            patience: Beam-search patience value, passed through unchanged.
            best_of: Number of sampled candidates, or ``None``.

        Returns:
            Options for the "transcribe" task with timestamps enabled.
        """
        return DecodingOptions(
            task="transcribe",
            # Decode in the configured language (the same one the tokenizer
            # was built with) rather than leaving it unset.
            language=self.config.language,
            temperature=t,
            sample_len=None,
            best_of=best_of,
            beam_size=beam_size,
            patience=patience,
            length_penalty=None,
            prompt=None,
            prefix=None,
            suppress_blank=True,
            suppress_tokens="-1",
            without_timestamps=False,
            max_initial_timestamp=0.0,
            # Must agree with the dtype the input tensor is cast to in
            # ``transcribe``.
            fp16=self.dtype == torch.float16,
        )

    def _decode_with_fallback(self, *, segment: torch.Tensor) -> List[DecodingResult]:
        """Decode ``segment``, retrying poor results at the later temperatures.

        The first attempt uses ``config.temperatures[0]`` with
        ``config.beam_size``. For each remaining temperature, every result
        whose compression ratio exceeds ``compression_ratio_threshold`` or
        whose average log-probability is below ``logprob_threshold`` is
        decoded again with ``config.best_of`` and no beam search.

        Args:
            segment: Batched model input as prepared by ``transcribe``.

        Returns:
            One ``DecodingResult`` per batch entry.
        """
        temperatures = self.config.temperatures
        assert len(temperatures) >= 1, "config.temperatures must not be empty"

        # NOTE(review): newer openai-whisper releases appear to reject a
        # non-None ``patience`` when ``beam_size`` is None -- confirm against
        # the pinned whisper version before upgrading.
        first_options: DecodingOptions = self._get_decoding_options(
            t=temperatures[0],
            beam_size=self.config.beam_size,
            patience=0.0,
            best_of=None,
        )
        results: List[DecodingResult] = self.model.decode(segment, first_options)  # type: ignore

        ratio_limit = self.config.compression_ratio_threshold
        logprob_limit = self.config.logprob_threshold

        for t in temperatures[1:]:
            needs_fallback = [
                (ratio_limit is not None and result.compression_ratio > ratio_limit)
                or (logprob_limit is not None and result.avg_logprob < logprob_limit)
                for result in results
            ]
            if not any(needs_fallback):
                # Nothing is left to retry, and the remaining temperatures
                # could not change that, so stop early.
                break

            retry_options: DecodingOptions = self._get_decoding_options(
                t=t,
                beam_size=None,
                patience=0.0,
                best_of=self.config.best_of,
            )
            retries: List[DecodingResult] = self.model.decode(
                segment[needs_fallback], retry_options  # type: ignore
            )
            # Write each retried result back to its original batch position.
            for retry_index, original_index in enumerate(
                np.nonzero(needs_fallback)[0]
            ):
                results[original_index] = retries[retry_index]

        return results

    def transcribe(
        self,
        *,
        segment: np.ndarray,
    ) -> Optional[DecodingResult]:
        """Transcribe one audio segment.

        Args:
            segment: Audio samples for a single segment.

        Returns:
            The decoding result, or ``None`` when the segment is treated as
            silence: ``no_speech_prob`` exceeds ``no_speech_threshold`` and
            the average log-probability does not exceed ``logprob_threshold``.
        """
        # Add a leading batch dimension so the decoder sees a batch of one.
        log_spec = log_mel_spectrogram(audio=segment).unsqueeze(0)
        mel = (
            pad_or_trim(log_spec, N_FRAMES)
            .to(self.model.device)  # type:ignore
            .to(self.dtype)  # type:ignore
        )
        results = self._decode_with_fallback(segment=mel)
        result = results[0]

        if self.config.no_speech_threshold is not None:
            looks_silent = result.no_speech_prob > self.config.no_speech_threshold
            # A sufficiently confident decode overrides the silence verdict.
            confident = (
                self.config.logprob_threshold is not None
                and result.avg_logprob > self.config.logprob_threshold
            )
            if looks_silent and not confident:
                return None

        # FIXME: work with timestamp

        return result
|