Fix to save the rest of mel

This commit is contained in:
Yuta Hayashibe 2022-09-23 21:50:26 +09:00
parent 8f1e807ed7
commit ca9af389bf

View file

@ -4,6 +4,7 @@ from typing import Iterator, List, Optional, Union
import numpy as np
import torch
from transformers.utils.dummy_pt_objects import OPT_PRETRAINED_MODEL_ARCHIVE_LIST
from whisper import Whisper, load_model
from whisper.audio import (
HOP_LENGTH,
@ -38,6 +39,7 @@ class WhisperStreamingTranscriber:
) # time per output token: 0.02 (seconds)
self.buffer_tokens = []
self.buffer_mel = None
def _get_decoding_options(
self,
@ -202,27 +204,52 @@ class WhisperStreamingTranscriber:
*,
segment: np.ndarray,
) -> Iterator[ParsedChunk]:
log_spec = log_mel_spectrogram(audio=segment).unsqueeze(0)
segment = (
pad_or_trim(log_spec, N_FRAMES)
.to(self.model.device) # type:ignore
.to(self.dtype) # type:ignore
)
results = self._decode_with_fallback(
segment=segment,
)
result = results[0]
new_mel = log_mel_spectrogram(audio=segment).unsqueeze(0)
if self.buffer_mel is None:
mel = new_mel
else:
mel = torch.cat([self.buffer_mel, new_mel], dim=-1)
self.buffer_mel = None
if self.config.no_speech_threshold is not None:
if (result.no_speech_prob > self.config.no_speech_threshold) and not (
self.config.logprob_threshold is not None
and result.avg_logprob > self.config.logprob_threshold
seek: int = 0
rest_start: Optional[int] = None
while seek < mel.shape[-1]:
segment = (
pad_or_trim(mel[:, :, seek:], N_FRAMES)
.to(self.model.device) # type: ignore
.to(self.dtype)
)
results = self._decode_with_fallback(
segment=segment,
)
result = results[0]
if self.config.no_speech_threshold is not None:
if (result.no_speech_prob > self.config.no_speech_threshold) and not (
self.config.logprob_threshold is not None
and result.avg_logprob > self.config.logprob_threshold
):
return
segment_duration = segment.shape[-1] * HOP_LENGTH / SAMPLE_RATE
last_timestamp_position: Optional[int] = None
for v in self._deal_timestamp(
result=result, segment_duration=segment_duration
):
return
segment_duration = segment.shape[-1] * HOP_LENGTH / SAMPLE_RATE
for v in self._deal_timestamp(result=result, segment_duration=segment_duration):
if isinstance(v, int):
# FIXME: save log_spec
pass
if isinstance(v, int):
last_timestamp_position = v
else:
yield v
if last_timestamp_position is None:
seek += segment.shape[-1]
rest_start = None
else:
yield v
seek += last_timestamp_position
rest_start = seek
if rest_start is None:
return
self.buffer_mel = mel[:, :, rest_start:]
del mel