mirror of
https://github.com/shirayu/whispering.git
synced 2025-02-16 10:35:16 +00:00
Fix to save the rest of mel
This commit is contained in:
parent
8f1e807ed7
commit
ca9af389bf
1 changed file with 48 additions and 21 deletions
|
@@ -4,6 +4,7 @@ from typing import Iterator, List, Optional, Union
|
|||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers.utils.dummy_pt_objects import OPT_PRETRAINED_MODEL_ARCHIVE_LIST
|
||||
from whisper import Whisper, load_model
|
||||
from whisper.audio import (
|
||||
HOP_LENGTH,
|
||||
|
@@ -38,6 +39,7 @@ class WhisperStreamingTranscriber:
|
|||
) # time per output token: 0.02 (seconds)
|
||||
|
||||
self.buffer_tokens = []
|
||||
self.buffer_mel = None
|
||||
|
||||
def _get_decoding_options(
|
||||
self,
|
||||
|
@@ -202,27 +204,52 @@ class WhisperStreamingTranscriber:
|
|||
*,
|
||||
segment: np.ndarray,
|
||||
) -> Iterator[ParsedChunk]:
|
||||
log_spec = log_mel_spectrogram(audio=segment).unsqueeze(0)
|
||||
segment = (
|
||||
pad_or_trim(log_spec, N_FRAMES)
|
||||
.to(self.model.device) # type:ignore
|
||||
.to(self.dtype) # type:ignore
|
||||
)
|
||||
results = self._decode_with_fallback(
|
||||
segment=segment,
|
||||
)
|
||||
result = results[0]
|
||||
new_mel = log_mel_spectrogram(audio=segment).unsqueeze(0)
|
||||
if self.buffer_mel is None:
|
||||
mel = new_mel
|
||||
else:
|
||||
mel = torch.cat([self.buffer_mel, new_mel], dim=-1)
|
||||
self.buffer_mel = None
|
||||
|
||||
if self.config.no_speech_threshold is not None:
|
||||
if (result.no_speech_prob > self.config.no_speech_threshold) and not (
|
||||
self.config.logprob_threshold is not None
|
||||
and result.avg_logprob > self.config.logprob_threshold
|
||||
seek: int = 0
|
||||
rest_start: Optional[int] = None
|
||||
while seek < mel.shape[-1]:
|
||||
segment = (
|
||||
pad_or_trim(mel[:, :, seek:], N_FRAMES)
|
||||
.to(self.model.device) # type: ignore
|
||||
.to(self.dtype)
|
||||
)
|
||||
|
||||
results = self._decode_with_fallback(
|
||||
segment=segment,
|
||||
)
|
||||
result = results[0]
|
||||
|
||||
if self.config.no_speech_threshold is not None:
|
||||
if (result.no_speech_prob > self.config.no_speech_threshold) and not (
|
||||
self.config.logprob_threshold is not None
|
||||
and result.avg_logprob > self.config.logprob_threshold
|
||||
):
|
||||
return
|
||||
|
||||
segment_duration = segment.shape[-1] * HOP_LENGTH / SAMPLE_RATE
|
||||
last_timestamp_position: Optional[int] = None
|
||||
for v in self._deal_timestamp(
|
||||
result=result, segment_duration=segment_duration
|
||||
):
|
||||
return
|
||||
segment_duration = segment.shape[-1] * HOP_LENGTH / SAMPLE_RATE
|
||||
for v in self._deal_timestamp(result=result, segment_duration=segment_duration):
|
||||
if isinstance(v, int):
|
||||
# FIXME: save log_spec
|
||||
pass
|
||||
if isinstance(v, int):
|
||||
last_timestamp_position = v
|
||||
else:
|
||||
yield v
|
||||
if last_timestamp_position is None:
|
||||
seek += segment.shape[-1]
|
||||
rest_start = None
|
||||
else:
|
||||
yield v
|
||||
seek += last_timestamp_position
|
||||
rest_start = seek
|
||||
|
||||
if rest_start is None:
|
||||
return
|
||||
|
||||
self.buffer_mel = mel[:, :, rest_start:]
|
||||
del mel
|
||||
|
|
Loading…
Reference in a new issue