Merge branch 'vad'

Yuta Hayashibe 2022-10-02 20:38:51 +09:00
commit 1e7166e378
8 changed files with 145 additions and 11 deletions

README.md

@@ -34,6 +34,7 @@ whispering --language en --model tiny
 - ``--no-progress`` disables the progress message
 - ``-t`` sets temperatures to decode. You can set several like ``-t 0.0 -t 0.1 -t 0.5``, but too many temperatures exhaust decoding time
 - ``--debug`` outputs logs for debug
+- ``--no-vad`` disables VAD (Voice Activity Detection). This forces whisper to analyze periods without voice activity as well
 
 ### Parse interval
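
For example, combining the options documented above, ``whispering --language en --model tiny --no-vad`` would transcribe every captured audio block, even blocks in which no speech is detected.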

poetry.lock (generated)

@@ -378,6 +378,17 @@ python-versions = ">=3.7.0"
 [package.dependencies]
 typing-extensions = "*"
 
+[[package]]
+name = "torchaudio"
+version = "0.12.1"
+description = "An audio package for PyTorch"
+category = "main"
+optional = false
+python-versions = "*"
+
+[package.dependencies]
+torch = "1.12.1"
+
 [[package]]
 name = "tqdm"
 version = "4.64.1"
@@ -514,7 +525,7 @@ resolved_reference = "0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f"
 [metadata]
 lock-version = "1.1"
 python-versions = ">=3.8,<3.11"
-content-hash = "f5395ffab6ce7d95246143218e948308d6614929f375489eb2b94a863e15fcc4"
+content-hash = "ab527970383bc2245dee005627d0695812601115a36e15a5ef9e66d1185791bf"
 
 [metadata.files]
 black = [
@@ -964,6 +975,27 @@ torch = [
     {file = "torch-1.12.1-cp39-none-macosx_10_9_x86_64.whl", hash = "sha256:bfec2843daa654f04fda23ba823af03e7b6f7650a873cdb726752d0e3718dada"},
     {file = "torch-1.12.1-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:69fe2cae7c39ccadd65a123793d30e0db881f1c1927945519c5c17323131437e"},
 ]
+torchaudio = [
+    {file = "torchaudio-0.12.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:dc138bee06b2305442fc132171f2a01d5f42509eaa21bdf87c3d26a6f4a09fdd"},
+    {file = "torchaudio-0.12.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:1d81f71837d5d5be651e85ca9fa9377ecb4513b0129ddfb025540e1c2406d3e6"},
+    {file = "torchaudio-0.12.1-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:c2f46ad1332d4eb4c5bc2259bad22f7693d1e81cdcf2ab04242bf428d78f161f"},
+    {file = "torchaudio-0.12.1-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:21741a277d31f75215a09c1590170055b65c2eceda6aa5a263676745bd97172e"},
+    {file = "torchaudio-0.12.1-cp310-cp310-win_amd64.whl", hash = "sha256:f0cc2d4ab4288d5115fab554a49bed6251469dc1548c961655556ec48a3c320e"},
+    {file = "torchaudio-0.12.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:23dbcf37af2f41d491c0337ca94501ec7ef588adb1766e1eb28033fac549bbd9"},
+    {file = "torchaudio-0.12.1-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:e82c48b05d941d64cc67a18d13f8e76ba7e852fe9f187b47d3abfbebd1f05195"},
+    {file = "torchaudio-0.12.1-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:142da7f0f05517b32cb54ed6f37997f741ad1bd283474898b680b0dfed7ff926"},
+    {file = "torchaudio-0.12.1-cp37-cp37m-win_amd64.whl", hash = "sha256:1c839ceb2035c3ea3458e274e9a1afb65f5fa41678e76c3378b218eb23956579"},
+    {file = "torchaudio-0.12.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:a4c8c15b1e810a93bb77b27fa49159bea2253b593ef94039946ec49aef51764f"},
+    {file = "torchaudio-0.12.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:83c08b71a6dc8e23c1d7b00780abb9e4c29528e47a6e644fe3dee7ac2263821e"},
+    {file = "torchaudio-0.12.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:186dcaa00b60e441f9c489c00966ecdd7412c2a4592058107f8c3a888cbbf337"},
+    {file = "torchaudio-0.12.1-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:2937756874050cb3249395d7814dacab2c296ce3e5ae3e63397aa4fc902db885"},
+    {file = "torchaudio-0.12.1-cp38-cp38-win_amd64.whl", hash = "sha256:ba00c62bae021b8e5a3d38f04788e489e6f8d9eb16620d8c1e81b1e9d4bf1284"},
+    {file = "torchaudio-0.12.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:08f92bc53682d3bad8606dedb70a49e5a0f7cf9306c9173f074dbba97785442e"},
+    {file = "torchaudio-0.12.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:2fc5a2bc8e8aad475bc519f3c82b9649e14b5c657487ffa712cf7c514143e9d7"},
+    {file = "torchaudio-0.12.1-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:075dba92c8c885ef1bc882e24a0ffdcce29a73f4d2377c75d1fa1c76702b37e3"},
+    {file = "torchaudio-0.12.1-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:a2bc09eee50fb5adc3e40c66bb63d525344bb8359f65d9c600d53ea6212207e6"},
+    {file = "torchaudio-0.12.1-cp39-cp39-win_amd64.whl", hash = "sha256:5b06c72da8ea8f8cd3075d7f97e2866b473aceaca08ef871895cd5fafde078bf"},
+]
 tqdm = [
     {file = "tqdm-4.64.1-py2.py3-none-any.whl", hash = "sha256:6fee160d6ffcd1b1c68c65f14c829c22832bc401726335ce92c52d395944a6a1"},
     {file = "tqdm-4.64.1.tar.gz", hash = "sha256:5f4f682a004951c1b450bc753c710e9280c5746ce6ffedee253ddbcbf54cf1e4"},

pyproject.toml

@@ -13,6 +13,7 @@ sounddevice = "^0.4.5"
 pydantic = "^1.10.2"
 websockets = "^10.3"
 tqdm = "*"
+torchaudio = "^0.12.1"
 
 [tool.poetry.group.dev.dependencies]

whispering/cli.py

@@ -48,16 +48,16 @@ def transcribe_from_mic(
     ):
         idx: int = 0
         while True:
-            logger.debug(f"Segment #: {idx}, The rest of queue: {q.qsize()}")
+            logger.debug(f"Audio #: {idx}, The rest of queue: {q.qsize()}")
             if no_progress:
-                segment = q.get()
+                audio = q.get()
             else:
                 pbar_thread = ProgressBar(
                     num_block=num_block,  # TODO: set more accurate value
                 )
                 try:
-                    segment = q.get()
+                    audio = q.get()
                 except KeyboardInterrupt:
                     pbar_thread.kill()
                     return
@@ -68,7 +68,7 @@ def transcribe_from_mic(
                 sys.stderr.write("Analyzing")
                 sys.stderr.flush()
-            for chunk in wsp.transcribe(segment=segment, ctx=ctx):
+            for chunk in wsp.transcribe(audio=audio, ctx=ctx):
                 if not no_progress:
                     sys.stderr.write("\r")
                     sys.stderr.flush()
@@ -155,6 +155,10 @@ def get_opts() -> argparse.Namespace:
         "--no-progress",
         action="store_true",
     )
+    parser.add_argument(
+        "--no-vad",
+        action="store_true",
+    )
 
     opts = parser.parse_args()
     if opts.beam_size <= 0:
@@ -187,6 +191,7 @@ def get_context(*, opts) -> Context:
         beam_size=opts.beam_size,
         temperatures=opts.temperature,
         allow_padding=opts.allow_padding,
+        vad=not opts.no_vad,
     )
     logger.debug(f"Context: {ctx}")
     return ctx

whispering/schema.py

@@ -2,6 +2,7 @@
 from typing import List, Optional
 
+import numpy as np
 import torch
 from pydantic import BaseModel, root_validator
@@ -26,6 +27,7 @@ class Context(BaseModel, arbitrary_types_allowed=True):
     timestamp: float = 0.0
     buffer_tokens: List[torch.Tensor] = []
     buffer_mel: Optional[torch.Tensor] = None
+    vad: bool = True
 
     temperatures: List[float]
     allow_padding: bool = False
@@ -50,3 +52,9 @@ class ParsedChunk(BaseModel):
     avg_logprob: float
     compression_ratio: float
     no_speech_prob: float
+
+
+class SpeechSegment(BaseModel, arbitrary_types_allowed=True):
+    start_block_idx: int
+    end_block_idx: int
+    audio: np.ndarray
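
The np.ndarray field is why SpeechSegment (like Context) is declared with arbitrary_types_allowed=True; pydantic has no built-in validator for NumPy arrays. A minimal construction sketch, with an illustrative one-second buffer of silence:

import numpy as np

from whispering.schema import SpeechSegment

block = np.zeros(16000, dtype=np.float32)  # placeholder: 1 second of silence at 16 kHz
seg = SpeechSegment(start_block_idx=0, end_block_idx=1, audio=block)
print(seg.start_block_idx, seg.end_block_idx, seg.audio.shape)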

whispering/serve.py

@@ -21,7 +21,7 @@ async def serve_with_websocket_main(websocket):
     )
     while True:
-        logger.debug(f"Segment #: {idx}")
+        logger.debug(f"Audio #: {idx}")
         try:
             message = await websocket.recv()
         except ConnectionClosedOK:
@@ -32,9 +32,9 @@ async def serve_with_websocket_main(websocket):
             continue
         logger.debug(f"Message size: {len(message)}")
-        segment = np.frombuffer(message, dtype=np.float32)
+        audio = np.frombuffer(message, dtype=np.float32)
         for chunk in g_wsp.transcribe(
-            segment=segment,  # type: ignore
+            audio=audio,  # type: ignore
             ctx=ctx,
         ):
             await websocket.send(chunk.json())
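
Each binary message is reinterpreted as raw float32 samples at whisper's 16 kHz sample rate, so any client that can ship such a buffer over a websocket can drive the server. A minimal client sketch, assuming the websockets package already listed in the dependencies; the URL, port, and audio contents below are placeholders, not values taken from the project:

import asyncio

import numpy as np
import websockets


async def send_audio(uri: str, audio: np.ndarray) -> None:
    async with websockets.connect(uri) as ws:
        # The server runs np.frombuffer(message, dtype=np.float32) on each binary message
        await ws.send(audio.astype(np.float32).tobytes())
        async for reply in ws:
            # Replies are ParsedChunk objects serialized as JSON (sent only when speech is found)
            print(reply)


if __name__ == "__main__":
    noise = (0.01 * np.random.default_rng(0).standard_normal(16000 * 5)).astype(np.float32)
    asyncio.run(send_audio("ws://127.0.0.1:8000", noise))  # placeholder URL and port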

whispering/transcriber.py

@@ -3,9 +3,11 @@
 from logging import getLogger
 from typing import Final, Iterator, Optional, Union
 
+import numpy as np
 import torch
 from whisper import Whisper, load_model
 from whisper.audio import (
+    CHUNK_LENGTH,
     HOP_LENGTH,
     N_FRAMES,
     SAMPLE_RATE,
@@ -17,6 +19,7 @@ from whisper.tokenizer import get_tokenizer
 from whisper.utils import exact_div
 
 from whispering.schema import Context, ParsedChunk, WhisperConfig
+from whispering.vad import VAD
 
 logger = getLogger(__name__)
@@ -50,6 +53,8 @@ class WhisperStreamingTranscriber:
         self.time_precision: Final[float] = (
             self.input_stride * HOP_LENGTH / SAMPLE_RATE
         )  # time per output token: 0.02 (seconds)
+        self.duration_pre_one_mel: Final[float] = CHUNK_LENGTH / HOP_LENGTH
+        self.vad = VAD()
 
     def _get_decoding_options(
         self,
@@ -224,10 +229,25 @@
     def transcribe(
         self,
         *,
-        segment: torch.Tensor,
+        audio: np.ndarray,
         ctx: Context,
     ) -> Iterator[ParsedChunk]:
-        new_mel = log_mel_spectrogram(audio=segment)
+        logger.debug(f"{len(audio)}")
+        if ctx.vad:
+            x = [
+                v
+                for v in self.vad(
+                    audio=audio,
+                    total_block_number=1,
+                )
+            ]
+            if len(x) == 0:  # No speech
+                logger.debug("No speech")
+                ctx.timestamp += len(audio) / N_FRAMES * self.duration_pre_one_mel
+                return
+
+        new_mel = log_mel_spectrogram(audio=audio)
         logger.debug(f"Incoming new_mel.shape: {new_mel.shape}")
         if ctx.buffer_mel is None:
             mel = new_mel
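
When the VAD finds no speech, ctx.timestamp is advanced by the wall-clock duration of the discarded audio. With whisper's audio constants (SAMPLE_RATE = 16000, HOP_LENGTH = 160, CHUNK_LENGTH = 30, N_FRAMES = 3000), len(audio) / N_FRAMES * duration_pre_one_mel reduces to len(audio) / SAMPLE_RATE seconds; a quick check of that arithmetic:

SAMPLE_RATE = 16_000  # samples per second (whisper.audio)
HOP_LENGTH = 160      # samples per mel frame
CHUNK_LENGTH = 30     # seconds per whisper window
N_FRAMES = 3_000      # mel frames per window

duration_pre_one_mel = CHUNK_LENGTH / HOP_LENGTH  # 0.1875

n_samples = N_FRAMES * HOP_LENGTH  # one full 30-second window: 480_000 samples
advance = n_samples / N_FRAMES * duration_pre_one_mel
assert advance == n_samples / SAMPLE_RATE == 30.0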
@@ -239,7 +259,7 @@
         seek: int = 0
         while seek < mel.shape[-1]:
-            segment = (
+            segment: torch.Tensor = (
                 pad_or_trim(mel[:, seek:], N_FRAMES)
                 .to(self.model.device)  # type: ignore
                 .to(self.dtype)
whispering/vad.py (new file)

@@ -0,0 +1,67 @@
#!/usr/bin/env python3

from typing import Iterator, Optional

import numpy as np
import torch
from whisper.audio import N_FRAMES, SAMPLE_RATE

from whispering.schema import SpeechSegment


class VAD:
    def __init__(
        self,
    ):
        # Silero VAD model, fetched via torch.hub on first use
        self.vad_model, _ = torch.hub.load(
            repo_or_dir="snakers4/silero-vad",
            model="silero_vad",
        )

    def __call__(
        self,
        *,
        audio: np.ndarray,
        threshold: float = 0.5,
        total_block_number: Optional[int] = None,
    ) -> Iterator[SpeechSegment]:
        # audio length should be a multiple of N_FRAMES samples

        def my_ret(
            *,
            start_block_idx: int,
            idx: int,
        ) -> SpeechSegment:
            return SpeechSegment(
                start_block_idx=start_block_idx,
                end_block_idx=idx,
                audio=audio[N_FRAMES * start_block_idx : N_FRAMES * idx],
            )

        if total_block_number is None:
            total_block_number = int(audio.shape[0] / N_FRAMES)
        block_unit: int = audio.shape[0] // total_block_number

        start_block_idx = None
        for idx in range(total_block_number):
            start: int = block_unit * idx
            end: int = block_unit * (idx + 1)
            vad_prob = self.vad_model(
                torch.from_numpy(audio[start:end]),
                SAMPLE_RATE,
            ).item()
            if vad_prob > threshold:
                # Speech: open a segment if one is not already open
                if start_block_idx is None:
                    start_block_idx = idx
            else:
                # Silence: close and yield any open segment
                if start_block_idx is not None:
                    yield my_ret(
                        start_block_idx=start_block_idx,
                        idx=idx,
                    )
                    start_block_idx = None
        if start_block_idx is not None:
            yield my_ret(
                start_block_idx=start_block_idx,
                idx=total_block_number,
            )
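
A rough usage sketch of the class above, mirroring how WhisperStreamingTranscriber calls it with total_block_number=1; the weights are downloaded through torch.hub on the first call, and the input here is a placeholder noise buffer, so the loop may yield nothing:

import numpy as np

from whisper.audio import N_FRAMES
from whispering.vad import VAD

vad = VAD()  # loads snakers4/silero-vad via torch.hub

# Placeholder block of 16 kHz mono float32 audio; real input would come from
# the microphone callback or a websocket message.
audio = (0.01 * np.random.default_rng(0).standard_normal(N_FRAMES)).astype(np.float32)

for seg in vad(audio=audio, total_block_number=1):
    print(seg.start_block_idx, seg.end_block_idx, len(seg.audio))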