whispering/whispering/vad.py

72 lines
2 KiB
Python
Raw Permalink Normal View History

2022-10-01 14:21:58 +00:00
#!/usr/bin/env python3
from logging import getLogger
2022-10-02 11:30:51 +00:00
from typing import Iterator, Optional
2022-10-01 14:21:58 +00:00
2022-10-02 10:47:17 +00:00
import numpy as np
2022-10-01 14:21:58 +00:00
import torch
from whisper.audio import N_FRAMES, SAMPLE_RATE
from whispering.schema import SpeechSegment
logger = getLogger(__name__)
2022-10-01 14:21:58 +00:00
class VAD:
    """Voice-activity detector that groups consecutive speech blocks.

    The detector wraps the ``silero_vad`` model fetched through
    ``torch.hub.load`` and, when called, scores fixed-size blocks of an audio
    array and yields one ``SpeechSegment`` per run of adjacent blocks whose
    score exceeds a threshold.
    """

    def __init__(
        self,
    ):
        # torch.hub.load returns a pair here; only the first element (the
        # model) is kept, the second one is not used by this class.
        self.vad_model, _ = torch.hub.load(
            repo_or_dir="snakers4/silero-vad",
            model="silero_vad",
        )

    def __call__(
        self,
        *,
        audio: np.ndarray,
        threshold: float,
        total_block_number: Optional[int] = None,
    ) -> Iterator[SpeechSegment]:
        """Yield a ``SpeechSegment`` for every run of consecutive speech blocks.

        ``audio`` is cut along its first axis into ``total_block_number``
        blocks of equal size.  Each block is scored by the VAD model and is
        treated as speech when the score is strictly greater than
        ``threshold``.  Adjacent speech blocks are merged into one segment.

        Args:
            audio: Samples to analyse.  When ``total_block_number`` is omitted
                the length should be a multiple of ``N_FRAMES``.
            threshold: Score above which a block counts as speech.
            total_block_number: Number of blocks to cut ``audio`` into.
                Defaults to ``audio.shape[0] // N_FRAMES``.

        Yields:
            Segments whose ``start_block_idx`` is the first speech block of a
            run, whose ``end_block_idx`` is the index one past the last speech
            block of that run, and whose ``audio`` is the matching slice of
            the input.  Nothing is yielded when there is no block to analyse.
        """
        sample_count: int = audio.shape[0]
        if total_block_number is None:
            total_block_number = sample_count // N_FRAMES
        if total_block_number <= 0:
            # No block can be formed (e.g. audio shorter than N_FRAMES);
            # stop here instead of dividing by zero below.
            return
        block_unit: int = sample_count // total_block_number

        def make_segment(
            *,
            start_block_idx: int,
            end_block_idx: int,
        ) -> SpeechSegment:
            # Slice with block_unit (not N_FRAMES) so the returned audio is
            # exactly the span of blocks that were scored, even when the
            # caller chose a custom total_block_number.
            return SpeechSegment(
                start_block_idx=start_block_idx,
                end_block_idx=end_block_idx,
                audio=audio[
                    block_unit * start_block_idx : block_unit * end_block_idx
                ],
            )

        # Index of the first block of the speech run currently being
        # collected, or None while no run is open.
        start_block_idx: Optional[int] = None
        for idx in range(total_block_number):
            start: int = block_unit * idx
            end: int = start + block_unit

            vad_prob = self.vad_model(
                torch.from_numpy(audio[start:end]),
                SAMPLE_RATE,
            ).item()
            logger.debug("VAD: %s (threshold=%s)", vad_prob, threshold)

            if vad_prob > threshold:
                if start_block_idx is None:
                    start_block_idx = idx
            elif start_block_idx is not None:
                # First non-speech block after a run: close the run.
                yield make_segment(
                    start_block_idx=start_block_idx,
                    end_block_idx=idx,
                )
                start_block_idx = None

        # A run that reaches the last block is closed at the end of the audio.
        if start_block_idx is not None:
            yield make_segment(
                start_block_idx=start_block_idx,
                end_block_idx=total_block_number,
            )