This commit is contained in:
Yuta Hayashibe 2022-10-02 20:30:51 +09:00
parent 847eee5819
commit 08798f117a
2 changed files with 23 additions and 8 deletions

View file

@ -7,6 +7,7 @@ import numpy as np
import torch
from whisper import Whisper, load_model
from whisper.audio import (
CHUNK_LENGTH,
HOP_LENGTH,
N_FRAMES,
SAMPLE_RATE,
@ -52,6 +53,7 @@ class WhisperStreamingTranscriber:
self.time_precision: Final[float] = (
self.input_stride * HOP_LENGTH / SAMPLE_RATE
) # time per output token: 0.02 (seconds)
self.duration_pre_one_mel: Final[float] = CHUNK_LENGTH / HOP_LENGTH
self.vad = VAD()
def _get_decoding_options(
@ -230,8 +232,18 @@ class WhisperStreamingTranscriber:
audio: np.ndarray,
ctx: Context,
) -> Iterator[ParsedChunk]:
for speech_segment in self.vad(audio=audio):
logger.debug(f"{speech_segment}")
logger.debug(f"{len(audio)}")
x = [
v
for v in self.vad(
audio=audio,
total_block_number=1,
)
]
if len(x) == 0: # No speech
logger.debug("No speech")
ctx.timestamp += len(audio) / N_FRAMES * self.duration_pre_one_mel
return
new_mel = log_mel_spectrogram(audio=audio)
logger.debug(f"Incoming new_mel.shape: {new_mel.shape}")

View file

@ -1,6 +1,6 @@
#!/usr/bin/env python3
from typing import Iterator
from typing import Iterator, Optional
import numpy as np
import torch
@ -23,6 +23,7 @@ class VAD:
*,
audio: np.ndarray,
thredhold: float = 0.5,
total_block_number: Optional[int] = None,
) -> Iterator[SpeechSegment]:
# audio.shape[0] should be a multiple of N_FRAMES
@ -37,12 +38,14 @@ class VAD:
audio=audio[N_FRAMES * start_block_idx : N_FRAMES * idx],
)
block_size: int = int(audio.shape[0] / N_FRAMES)
if total_block_number is None:
total_block_number = int(audio.shape[0] / N_FRAMES)
block_unit: int = audio.shape[0] // total_block_number
start_block_idx = None
for idx in range(block_size):
start: int = N_FRAMES * idx
end: int = N_FRAMES * (idx + 1)
for idx in range(total_block_number):
start: int = block_unit * idx
end: int = block_unit * (idx + 1)
vad_prob = self.vad_model(
torch.from_numpy(audio[start:end]),
SAMPLE_RATE,
@ -60,5 +63,5 @@ class VAD:
if start_block_idx is not None:
yield my_ret(
start_block_idx=start_block_idx,
idx=block_size,
idx=total_block_number,
)