whispering/whispering/transcriber.py

#!/usr/bin/env python3
from logging import getLogger
from typing import Final, Iterator, Optional, Union

import numpy as np
import torch
from whisper import Whisper, load_model
from whisper.audio import (
    CHUNK_LENGTH,
    HOP_LENGTH,
    N_FRAMES,
    SAMPLE_RATE,
    log_mel_spectrogram,
    pad_or_trim,
)
from whisper.decoding import DecodingOptions, DecodingResult
from whisper.tokenizer import get_tokenizer
from whisper.utils import exact_div

from whispering.schema import Context, ParsedChunk, WhisperConfig
from whispering.vad import VAD

logger = getLogger(__name__)

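
# WhisperStreamingTranscriber wraps a Whisper model for streaming use: each call to
# transcribe() takes one block of audio, optionally gates it with VAD, appends its
# log-mel spectrogram to the frames buffered in the Context, decodes fixed-size
# windows with temperature fallback, and yields ParsedChunk results while carrying
# undecoded frames over to the next call.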
class WhisperStreamingTranscriber:
    def _set_dtype(self, fp16: bool):
        self.fp16 = fp16
        self.dtype = torch.float16 if fp16 else torch.float32
        if self.model.device == torch.device("cpu"):
            if torch.cuda.is_available():
                logger.info("Performing inference on CPU though CUDA is available")
            if self.dtype == torch.float16:
                logger.info("Using FP32 because FP16 is not supported on CPU")
                self.dtype = torch.float32
        if self.dtype == torch.float32:
            self.fp16 = False

    def __init__(self, *, config: WhisperConfig):
        self.config: Final[WhisperConfig] = config
        self.model: Final[Whisper] = load_model(config.model_name, device=config.device)
        self.tokenizer = get_tokenizer(
            self.model.is_multilingual,
            language=config.language,
            task="transcribe",
        )
        self._set_dtype(config.fp16)
        self.input_stride: Final[int] = exact_div(
            N_FRAMES, self.model.dims.n_audio_ctx
        )  # mel frames per output token: 2
        self.time_precision: Final[float] = (
            self.input_stride * HOP_LENGTH / SAMPLE_RATE
        )  # time per output token: 0.02 (seconds)
        self.duration_pre_one_mel: Final[float] = CHUNK_LENGTH / HOP_LENGTH
        self.vad = VAD()

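    # Builds the whisper DecodingOptions for a single decoding attempt; temperature
    # and prompt (tokens buffered from earlier segments) change per call, while the
    # remaining options stay fixed for this transcriber.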
    def _get_decoding_options(
        self,
        *,
        t,
        prompt,
        beam_size: Optional[int],
        patience: Optional[float],
        best_of: Optional[int],
    ) -> DecodingOptions:
        return DecodingOptions(
            task="transcribe",
            language=self.config.language,
            temperature=t,
            sample_len=None,
            best_of=best_of,
            beam_size=beam_size,
            patience=patience,
            length_penalty=None,
            prompt=prompt,
            prefix=None,
            suppress_blank=True,
            suppress_tokens="-1",
            without_timestamps=False,
            max_initial_timestamp=1.0,
            fp16=self.fp16,
        )

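    # Temperature fallback: try each temperature in ctx.temperatures in order and
    # retry with the next one when the result looks too repetitive (compression
    # ratio above threshold) or too unlikely (average log probability below
    # threshold), in the spirit of openai-whisper's transcribe() fallback.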
    def _decode_with_fallback(
        self,
        *,
        segment: torch.Tensor,
        ctx: Context,
    ) -> DecodingResult:
        assert len(ctx.temperatures) >= 1
        decode_result: Optional[DecodingResult] = None

        for t in ctx.temperatures:
            _decode_options: DecodingOptions = self._get_decoding_options(
                t=t,
                prompt=ctx.buffer_tokens,
                beam_size=ctx.beam_size if t <= 0 else None,
                patience=ctx.patience if t <= 0 else None,
                best_of=ctx.best_of if t < 0 else None,
            )
            logger.debug(f"DecodeOptions: {_decode_options}")
            decode_result = self.model.decode(
                segment,
                _decode_options,
            )  # type: ignore
            assert decode_result is not None

            needs_fallback: bool = False
            if (
                ctx.compression_ratio_threshold is not None
                and decode_result.compression_ratio > ctx.compression_ratio_threshold
            ):
                needs_fallback = True  # too repetitive
            if (
                ctx.logprob_threshold is not None
                and decode_result.avg_logprob < ctx.logprob_threshold
            ):
                needs_fallback = True  # average log probability is too low
            if not needs_fallback:
                break
        assert isinstance(decode_result, DecodingResult)
        return decode_result

    def _get_chunk(
        self,
        *,
        start: float,
        end: float,
        text_tokens: torch.Tensor,
        result: DecodingResult,
    ) -> Optional[ParsedChunk]:
        text = self.tokenizer.decode(
            [token for token in text_tokens if token < self.tokenizer.eot]  # type: ignore
        )
        if len(text.strip()) == 0:  # skip empty text output
            return
        return ParsedChunk(
            start=start,
            end=end,
            text=text,
            tokens=result.tokens,
            temperature=result.temperature,
            avg_logprob=result.avg_logprob,
            compression_ratio=result.compression_ratio,
            no_speech_prob=result.no_speech_prob,
        )

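    # Splits a decoding result on timestamp tokens and yields one ParsedChunk per
    # segment between consecutive timestamp tokens; when an int is yielded, it is
    # the position of the last timestamp, which the caller uses to advance its seek
    # offset into the mel spectrogram.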
    def _deal_timestamp(
        self,
        *,
        result,
        segment_duration,
        ctx: Context,
    ) -> Iterator[Union[ParsedChunk, int]]:
        tokens = torch.tensor(result.tokens)
        timestamp_tokens: torch.Tensor = tokens.ge(self.tokenizer.timestamp_begin)
        consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0].add_(
            1
        )
        if (
            len(consecutive) > 0
        ):  # if the output contains two consecutive timestamp tokens
            logger.debug(f"Length of consecutive: {len(consecutive)}")
            last_slice = 0
            for current_slice in consecutive:
                sliced_tokens = tokens[last_slice:current_slice]
                logger.debug(f" last_slice={last_slice}, current_slice={current_slice}")
                start_timestamp_position = (
                    sliced_tokens[0].item() - self.tokenizer.timestamp_begin
                )
                end_timestamp_position = (
                    sliced_tokens[-1].item() - self.tokenizer.timestamp_begin
                )
                chunk = self._get_chunk(
                    start=ctx.timestamp
                    + start_timestamp_position * self.time_precision,
                    end=ctx.timestamp + end_timestamp_position * self.time_precision,
                    text_tokens=sliced_tokens[1:-1],
                    result=result,
                )
                if chunk is not None:
                    yield chunk
                last_slice = current_slice
            last_timestamp_position0: int = (
                tokens[last_slice - 1].item()
                - self.tokenizer.timestamp_begin  # type:ignore
            )
            ctx.buffer_tokens.extend(tokens[: last_slice + 1].tolist())
            ctx.timestamp += last_timestamp_position0 * self.time_precision
            yield last_timestamp_position0
        else:
            duration = segment_duration
            timestamps = tokens[timestamp_tokens.nonzero().flatten()]
            logger.debug(f"Length of consecutive: 0, timestamps: {timestamps}")
            if (
                len(timestamps) > 0
                and timestamps[-1].item() != self.tokenizer.timestamp_begin
            ):
                # no consecutive timestamps but it has a timestamp; use the last one.
                # single timestamp at the end means no speech after the last timestamp.
                last_timestamp_position = (
                    timestamps[-1].item() - self.tokenizer.timestamp_begin
                )
                duration = last_timestamp_position * self.time_precision
            logger.debug(f"segment_duration: {segment_duration}, Duration: {duration}")
            chunk = self._get_chunk(
                start=ctx.timestamp,
                end=ctx.timestamp + duration,
                text_tokens=tokens,
                result=result,
            )
            if chunk is not None:
                yield chunk
            ctx.timestamp += duration

        if result.temperature > ctx.buffer_threshold:
            # do not feed the prompt tokens if a high temperature was used
            del ctx.buffer_tokens
            ctx.buffer_tokens = []
        logger.debug(f"Length of buffer: {len(ctx.buffer_tokens)}")

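    # Streaming entry point, called once per incoming audio block: (1) optional VAD
    # gating with a no-speech skip counter, (2) log-mel extraction and concatenation
    # with frames buffered in ctx.buffer_mel, (3) a seek loop that decodes
    # N_FRAMES-sized windows and yields ParsedChunk objects, (4) buffering of any
    # leftover mel frames for the next call.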
    def transcribe(
        self,
        *,
        audio: np.ndarray,
        ctx: Context,
    ) -> Iterator[ParsedChunk]:
        logger.debug(f"{len(audio)}")
        force_padding: bool = False
        if ctx.vad_threshold > 0.0:
            x = [
                v
                for v in self.vad(
                    audio=audio,
                    total_block_number=1,
                    threshold=ctx.vad_threshold,
                )
            ]
            if len(x) == 0:  # No speech
                logger.debug("No speech")
                ctx.timestamp += len(audio) / N_FRAMES * self.duration_pre_one_mel
                if ctx.nosoeech_skip_count is not None:
                    ctx.nosoeech_skip_count += 1
                if (
                    ctx.nosoeech_skip_count is None
                    or ctx.nosoeech_skip_count <= ctx.max_nospeech_skip
                ):
                    logger.debug(
                        f"nosoeech_skip_count: {ctx.nosoeech_skip_count} (<= {ctx.max_nospeech_skip})"
                    )
                    return
                ctx.nosoeech_skip_count = None
                force_padding = True

        new_mel = log_mel_spectrogram(audio=audio)
        logger.debug(f"Incoming new_mel.shape: {new_mel.shape}")
        if ctx.buffer_mel is None:
            mel = new_mel
        else:
            logger.debug(f"buffer_mel.shape: {ctx.buffer_mel.shape}")
            mel = torch.cat([ctx.buffer_mel, new_mel], dim=-1)
            ctx.buffer_mel = None
        logger.debug(f"mel.shape: {mel.shape}")

        seek: int = 0
        while seek < mel.shape[-1]:
            logger.debug(f"seek: {seek}")
            if mel.shape[-1] - seek <= 0:
                logger.debug(f"No more seek: mel.shape={mel.shape}, seek={seek}")
                break
            if mel.shape[-1] - seek < ctx.mel_frame_min_num:
                logger.debug(
                    f"mel.shape ({mel.shape[-1]}) - seek ({seek}) < ctx.mel_frame_min_num ({ctx.mel_frame_min_num})"
                )
                if force_padding:
                    logger.debug("Padding")
                else:
                    logger.debug("No padding")
                    break
            segment: torch.Tensor = (
                pad_or_trim(mel[:, seek:], N_FRAMES)
                .to(self.model.device)  # type: ignore
                .to(self.dtype)
            )
            logger.debug(
                f"seek={seek}, timestamp={ctx.timestamp}, "
                f"mel.shape: {mel.shape}, segment.shape: {segment.shape}"
            )
            result = self._decode_with_fallback(
                segment=segment,
                ctx=ctx,
            )
            logger.debug(
                f"Result: temperature={result.temperature:.2f}, no_speech_prob={result.no_speech_prob:.2f}, "
                f"avg_logprob={result.avg_logprob:.2f}"
            )

            if ctx.no_speech_threshold is not None:
                if (result.no_speech_prob > ctx.no_speech_threshold) and not (
                    ctx.logprob_threshold is not None
                    and result.avg_logprob > ctx.logprob_threshold
                ):
                    seek += segment.shape[-1]
                    logger.debug(
                        f"Skip: {segment.shape[-1]}, new seek={seek}, mel.shape: {mel.shape}"
                    )
                    continue

            segment_duration = segment.shape[-1] * HOP_LENGTH / SAMPLE_RATE
            last_timestamp_position: Optional[int] = None
            for v in self._deal_timestamp(
                result=result,
                segment_duration=segment_duration,
                ctx=ctx,
            ):
                if isinstance(v, int):
                    last_timestamp_position = v
                else:
                    yield v
            if last_timestamp_position is None:
                seek += segment.shape[-1]
            else:
                seek += last_timestamp_position * self.input_stride
            logger.debug(f"new seek={seek}, mel.shape: {mel.shape}")

        if mel.shape[-1] - seek <= 0:
            ctx.buffer_mel = None
            ctx.nosoeech_skip_count = None
            logger.debug(f"ctx.buffer_mel is None ({mel.shape}, {seek})")
            return
        ctx.buffer_mel = mel[:, seek:]
        assert ctx.buffer_mel is not None
        logger.debug(f"ctx.buffer_mel.shape: {ctx.buffer_mel.shape}")
        del mel
        if ctx.nosoeech_skip_count is None:
            ctx.nosoeech_skip_count = 0  # start count
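

# Usage sketch (illustrative only; the values below are assumptions). WhisperConfig
# and Context are defined in whispering.schema; only the fields that this module
# actually reads are shown, and the real schema may require additional fields or
# different values.
if __name__ == "__main__":
    config = WhisperConfig(
        model_name="base",  # assumed model name
        language="en",  # assumed language code
        device="cpu",  # assumed device
        fp16=False,
    )
    transcriber = WhisperStreamingTranscriber(config=config)

    ctx = Context(
        temperatures=[0.0, 0.2, 0.4, 0.6, 0.8, 1.0],  # assumed fallback schedule
        vad_threshold=0.5,  # assumed VAD threshold
    )

    # One second of silence stands in for a real audio block; SAMPLE_RATE comes
    # from whisper.audio.
    audio_block = np.zeros(SAMPLE_RATE, dtype=np.float32)
    for parsed_chunk in transcriber.transcribe(audio=audio_block, ctx=ctx):
        print(parsed_chunk.start, parsed_chunk.end, parsed_chunk.text)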