mirror of
https://github.com/shirayu/whispering.git
synced 2024-06-13 02:39:23 +00:00
Fix to save the rest of mel
This commit is contained in:
parent
8f1e807ed7
commit
ca9af389bf
|
@ -4,6 +4,7 @@ from typing import Iterator, List, Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
from transformers.utils.dummy_pt_objects import OPT_PRETRAINED_MODEL_ARCHIVE_LIST
|
||||||
from whisper import Whisper, load_model
|
from whisper import Whisper, load_model
|
||||||
from whisper.audio import (
|
from whisper.audio import (
|
||||||
HOP_LENGTH,
|
HOP_LENGTH,
|
||||||
|
@ -38,6 +39,7 @@ class WhisperStreamingTranscriber:
|
||||||
) # time per output token: 0.02 (seconds)
|
) # time per output token: 0.02 (seconds)
|
||||||
|
|
||||||
self.buffer_tokens = []
|
self.buffer_tokens = []
|
||||||
|
self.buffer_mel = None
|
||||||
|
|
||||||
def _get_decoding_options(
|
def _get_decoding_options(
|
||||||
self,
|
self,
|
||||||
|
@ -202,27 +204,52 @@ class WhisperStreamingTranscriber:
|
||||||
*,
|
*,
|
||||||
segment: np.ndarray,
|
segment: np.ndarray,
|
||||||
) -> Iterator[ParsedChunk]:
|
) -> Iterator[ParsedChunk]:
|
||||||
log_spec = log_mel_spectrogram(audio=segment).unsqueeze(0)
|
new_mel = log_mel_spectrogram(audio=segment).unsqueeze(0)
|
||||||
segment = (
|
if self.buffer_mel is None:
|
||||||
pad_or_trim(log_spec, N_FRAMES)
|
mel = new_mel
|
||||||
.to(self.model.device) # type:ignore
|
else:
|
||||||
.to(self.dtype) # type:ignore
|
mel = torch.cat([self.buffer_mel, new_mel], dim=-1)
|
||||||
)
|
self.buffer_mel = None
|
||||||
results = self._decode_with_fallback(
|
|
||||||
segment=segment,
|
|
||||||
)
|
|
||||||
result = results[0]
|
|
||||||
|
|
||||||
if self.config.no_speech_threshold is not None:
|
seek: int = 0
|
||||||
if (result.no_speech_prob > self.config.no_speech_threshold) and not (
|
rest_start: Optional[int] = None
|
||||||
self.config.logprob_threshold is not None
|
while seek < mel.shape[-1]:
|
||||||
and result.avg_logprob > self.config.logprob_threshold
|
segment = (
|
||||||
|
pad_or_trim(mel[:, :, seek:], N_FRAMES)
|
||||||
|
.to(self.model.device) # type: ignore
|
||||||
|
.to(self.dtype)
|
||||||
|
)
|
||||||
|
|
||||||
|
results = self._decode_with_fallback(
|
||||||
|
segment=segment,
|
||||||
|
)
|
||||||
|
result = results[0]
|
||||||
|
|
||||||
|
if self.config.no_speech_threshold is not None:
|
||||||
|
if (result.no_speech_prob > self.config.no_speech_threshold) and not (
|
||||||
|
self.config.logprob_threshold is not None
|
||||||
|
and result.avg_logprob > self.config.logprob_threshold
|
||||||
|
):
|
||||||
|
return
|
||||||
|
|
||||||
|
segment_duration = segment.shape[-1] * HOP_LENGTH / SAMPLE_RATE
|
||||||
|
last_timestamp_position: Optional[int] = None
|
||||||
|
for v in self._deal_timestamp(
|
||||||
|
result=result, segment_duration=segment_duration
|
||||||
):
|
):
|
||||||
return
|
if isinstance(v, int):
|
||||||
segment_duration = segment.shape[-1] * HOP_LENGTH / SAMPLE_RATE
|
last_timestamp_position = v
|
||||||
for v in self._deal_timestamp(result=result, segment_duration=segment_duration):
|
else:
|
||||||
if isinstance(v, int):
|
yield v
|
||||||
# FIXME: save log_spec
|
if last_timestamp_position is None:
|
||||||
pass
|
seek += segment.shape[-1]
|
||||||
|
rest_start = None
|
||||||
else:
|
else:
|
||||||
yield v
|
seek += last_timestamp_position
|
||||||
|
rest_start = seek
|
||||||
|
|
||||||
|
if rest_start is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
self.buffer_mel = mel[:, :, rest_start:]
|
||||||
|
del mel
|
||||||
|
|
Loading…
Reference in a new issue