whispering/whispering/vad.py
#!/usr/bin/env python3
from typing import Iterator

import numpy as np
import torch
from whisper.audio import N_FRAMES, SAMPLE_RATE

from whispering.schema import SpeechSegment

class VAD:
    def __init__(
        self,
    ):
        self.vad_model, _ = torch.hub.load(
            repo_or_dir="snakers4/silero-vad",
            model="silero_vad",
        )

    def __call__(
        self,
        *,
        segment: np.ndarray,
        threshold: float = 0.5,
    ) -> Iterator[SpeechSegment]:
        # segment.shape[0] should be a multiple of N_FRAMES
        block_size: int = int(segment.shape[0] / N_FRAMES)

        # Scan the segment block by block; the extra final iteration
        # flushes a speech run that extends to the end of the segment.
        start_block_idx = None
        for idx in range(block_size + 1):
            is_speech = False
            if idx < block_size:
                start: int = N_FRAMES * idx
                end: int = N_FRAMES * (idx + 1)
                vad_prob = self.vad_model(
                    torch.from_numpy(segment[start:end]),
                    SAMPLE_RATE,
                ).item()
                is_speech = vad_prob > threshold
            if is_speech:
                # Open a new speech run unless one is already in progress
                if start_block_idx is None:
                    start_block_idx = idx
            else:
                # A non-speech block (or the end of the segment) closes the run
                if start_block_idx is not None:
                    yield SpeechSegment(
                        start_block_idx=start_block_idx,
                        end_block_idx=idx,
                        segment=segment[N_FRAMES * start_block_idx : N_FRAMES * idx],
                    )
                    start_block_idx = None
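

# A minimal usage sketch: it feeds the VAD a synthetic silent buffer whose
# length is a multiple of N_FRAMES and prints any detected speech runs.
# It assumes SpeechSegment exposes the fields it is constructed with
# (start_block_idx, end_block_idx, segment); real input would be a mono
# float32 waveform recorded at SAMPLE_RATE.
if __name__ == "__main__":
    vad = VAD()
    dummy_audio = np.zeros(N_FRAMES * 4, dtype=np.float32)
    for speech in vad(segment=dummy_audio, threshold=0.5):
        print(speech.start_block_idx, speech.end_block_idx)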