Fix to save the rest of mel

This commit is contained in:
Yuta Hayashibe 2022-09-23 21:50:26 +09:00
parent 8f1e807ed7
commit ca9af389bf

View file

@ -4,6 +4,7 @@ from typing import Iterator, List, Optional, Union
import numpy as np import numpy as np
import torch import torch
from transformers.utils.dummy_pt_objects import OPT_PRETRAINED_MODEL_ARCHIVE_LIST
from whisper import Whisper, load_model from whisper import Whisper, load_model
from whisper.audio import ( from whisper.audio import (
HOP_LENGTH, HOP_LENGTH,
@ -38,6 +39,7 @@ class WhisperStreamingTranscriber:
) # time per output token: 0.02 (seconds) ) # time per output token: 0.02 (seconds)
self.buffer_tokens = [] self.buffer_tokens = []
self.buffer_mel = None
def _get_decoding_options( def _get_decoding_options(
self, self,
@ -202,27 +204,52 @@ class WhisperStreamingTranscriber:
*, *,
segment: np.ndarray, segment: np.ndarray,
) -> Iterator[ParsedChunk]: ) -> Iterator[ParsedChunk]:
log_spec = log_mel_spectrogram(audio=segment).unsqueeze(0) new_mel = log_mel_spectrogram(audio=segment).unsqueeze(0)
segment = ( if self.buffer_mel is None:
pad_or_trim(log_spec, N_FRAMES) mel = new_mel
.to(self.model.device) # type:ignore else:
.to(self.dtype) # type:ignore mel = torch.cat([self.buffer_mel, new_mel], dim=-1)
) self.buffer_mel = None
results = self._decode_with_fallback(
segment=segment,
)
result = results[0]
if self.config.no_speech_threshold is not None: seek: int = 0
if (result.no_speech_prob > self.config.no_speech_threshold) and not ( rest_start: Optional[int] = None
self.config.logprob_threshold is not None while seek < mel.shape[-1]:
and result.avg_logprob > self.config.logprob_threshold segment = (
pad_or_trim(mel[:, :, seek:], N_FRAMES)
.to(self.model.device) # type: ignore
.to(self.dtype)
)
results = self._decode_with_fallback(
segment=segment,
)
result = results[0]
if self.config.no_speech_threshold is not None:
if (result.no_speech_prob > self.config.no_speech_threshold) and not (
self.config.logprob_threshold is not None
and result.avg_logprob > self.config.logprob_threshold
):
return
segment_duration = segment.shape[-1] * HOP_LENGTH / SAMPLE_RATE
last_timestamp_position: Optional[int] = None
for v in self._deal_timestamp(
result=result, segment_duration=segment_duration
): ):
return if isinstance(v, int):
segment_duration = segment.shape[-1] * HOP_LENGTH / SAMPLE_RATE last_timestamp_position = v
for v in self._deal_timestamp(result=result, segment_duration=segment_duration): else:
if isinstance(v, int): yield v
# FIXME: save log_spec if last_timestamp_position is None:
pass seek += segment.shape[-1]
rest_start = None
else: else:
yield v seek += last_timestamp_position
rest_start = seek
if rest_start is None:
return
self.buffer_mel = mel[:, :, rest_start:]
del mel