Mirror of https://github.com/shirayu/whispering.git

Commit 1e7166e378: Merge branch 'vad'
8 changed files with 145 additions and 11 deletions
README.md
@@ -34,6 +34,7 @@ whispering --language en --model tiny
 - ``--no-progress`` disables the progress message
 - ``-t`` sets temperatures to decode. You can set several like ``-t 0.0 -t 0.1 -t 0.5``, but too many temperatures exhaust decoding time
 - ``--debug`` outputs logs for debug
+- ``--no-vad`` disables VAD (Voice Activity Detection). This forces whisper to analyze periods with no voice activity
 
 ### Parse interval
 
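As a side note on the repeatable ``-t`` option documented above: a minimal sketch of how several ``-t`` values arrive in code. The real option wiring lives in cli.py and is not part of this diff, so the argparse details here are illustrative only.

import argparse

parser = argparse.ArgumentParser()
# Repeatable temperature option: each -t appends one float to a list.
parser.add_argument("-t", "--temperature", type=float, action="append", default=[])
opts = parser.parse_args(["-t", "0.0", "-t", "0.1", "-t", "0.5"])
print(opts.temperature)  # [0.0, 0.1, 0.5]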
poetry.lock (generated; 34 lines changed)
@@ -378,6 +378,17 @@ python-versions = ">=3.7.0"
 [package.dependencies]
 typing-extensions = "*"
 
+[[package]]
+name = "torchaudio"
+version = "0.12.1"
+description = "An audio package for PyTorch"
+category = "main"
+optional = false
+python-versions = "*"
+
+[package.dependencies]
+torch = "1.12.1"
+
 [[package]]
 name = "tqdm"
 version = "4.64.1"
@@ -514,7 +525,7 @@ resolved_reference = "0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f"
 [metadata]
 lock-version = "1.1"
 python-versions = ">=3.8,<3.11"
-content-hash = "f5395ffab6ce7d95246143218e948308d6614929f375489eb2b94a863e15fcc4"
+content-hash = "ab527970383bc2245dee005627d0695812601115a36e15a5ef9e66d1185791bf"
 
 [metadata.files]
 black = [
@@ -964,6 +975,27 @@ torch = [
     {file = "torch-1.12.1-cp39-none-macosx_10_9_x86_64.whl", hash = "sha256:bfec2843daa654f04fda23ba823af03e7b6f7650a873cdb726752d0e3718dada"},
     {file = "torch-1.12.1-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:69fe2cae7c39ccadd65a123793d30e0db881f1c1927945519c5c17323131437e"},
 ]
+torchaudio = [
+    {file = "torchaudio-0.12.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:dc138bee06b2305442fc132171f2a01d5f42509eaa21bdf87c3d26a6f4a09fdd"},
+    {file = "torchaudio-0.12.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:1d81f71837d5d5be651e85ca9fa9377ecb4513b0129ddfb025540e1c2406d3e6"},
+    {file = "torchaudio-0.12.1-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:c2f46ad1332d4eb4c5bc2259bad22f7693d1e81cdcf2ab04242bf428d78f161f"},
+    {file = "torchaudio-0.12.1-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:21741a277d31f75215a09c1590170055b65c2eceda6aa5a263676745bd97172e"},
+    {file = "torchaudio-0.12.1-cp310-cp310-win_amd64.whl", hash = "sha256:f0cc2d4ab4288d5115fab554a49bed6251469dc1548c961655556ec48a3c320e"},
+    {file = "torchaudio-0.12.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:23dbcf37af2f41d491c0337ca94501ec7ef588adb1766e1eb28033fac549bbd9"},
+    {file = "torchaudio-0.12.1-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:e82c48b05d941d64cc67a18d13f8e76ba7e852fe9f187b47d3abfbebd1f05195"},
+    {file = "torchaudio-0.12.1-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:142da7f0f05517b32cb54ed6f37997f741ad1bd283474898b680b0dfed7ff926"},
+    {file = "torchaudio-0.12.1-cp37-cp37m-win_amd64.whl", hash = "sha256:1c839ceb2035c3ea3458e274e9a1afb65f5fa41678e76c3378b218eb23956579"},
+    {file = "torchaudio-0.12.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:a4c8c15b1e810a93bb77b27fa49159bea2253b593ef94039946ec49aef51764f"},
+    {file = "torchaudio-0.12.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:83c08b71a6dc8e23c1d7b00780abb9e4c29528e47a6e644fe3dee7ac2263821e"},
+    {file = "torchaudio-0.12.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:186dcaa00b60e441f9c489c00966ecdd7412c2a4592058107f8c3a888cbbf337"},
+    {file = "torchaudio-0.12.1-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:2937756874050cb3249395d7814dacab2c296ce3e5ae3e63397aa4fc902db885"},
+    {file = "torchaudio-0.12.1-cp38-cp38-win_amd64.whl", hash = "sha256:ba00c62bae021b8e5a3d38f04788e489e6f8d9eb16620d8c1e81b1e9d4bf1284"},
+    {file = "torchaudio-0.12.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:08f92bc53682d3bad8606dedb70a49e5a0f7cf9306c9173f074dbba97785442e"},
+    {file = "torchaudio-0.12.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:2fc5a2bc8e8aad475bc519f3c82b9649e14b5c657487ffa712cf7c514143e9d7"},
+    {file = "torchaudio-0.12.1-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:075dba92c8c885ef1bc882e24a0ffdcce29a73f4d2377c75d1fa1c76702b37e3"},
+    {file = "torchaudio-0.12.1-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:a2bc09eee50fb5adc3e40c66bb63d525344bb8359f65d9c600d53ea6212207e6"},
+    {file = "torchaudio-0.12.1-cp39-cp39-win_amd64.whl", hash = "sha256:5b06c72da8ea8f8cd3075d7f97e2866b473aceaca08ef871895cd5fafde078bf"},
+]
 tqdm = [
     {file = "tqdm-4.64.1-py2.py3-none-any.whl", hash = "sha256:6fee160d6ffcd1b1c68c65f14c829c22832bc401726335ce92c52d395944a6a1"},
     {file = "tqdm-4.64.1.tar.gz", hash = "sha256:5f4f682a004951c1b450bc753c710e9280c5746ce6ffedee253ddbcbf54cf1e4"},
pyproject.toml
@@ -13,6 +13,7 @@ sounddevice = "^0.4.5"
 pydantic = "^1.10.2"
 websockets = "^10.3"
 tqdm = "*"
+torchaudio = "^0.12.1"
 
 
 [tool.poetry.group.dev.dependencies]
whispering/cli.py
@@ -48,16 +48,16 @@ def transcribe_from_mic(
 ):
     idx: int = 0
     while True:
-        logger.debug(f"Segment #: {idx}, The rest of queue: {q.qsize()}")
+        logger.debug(f"Audio #: {idx}, The rest of queue: {q.qsize()}")
 
         if no_progress:
-            segment = q.get()
+            audio = q.get()
         else:
             pbar_thread = ProgressBar(
                 num_block=num_block,  # TODO: set more accurate value
             )
             try:
-                segment = q.get()
+                audio = q.get()
             except KeyboardInterrupt:
                 pbar_thread.kill()
                 return
@@ -68,7 +68,7 @@ def transcribe_from_mic(
             sys.stderr.write("Analyzing")
             sys.stderr.flush()
 
-        for chunk in wsp.transcribe(segment=segment, ctx=ctx):
+        for chunk in wsp.transcribe(audio=audio, ctx=ctx):
             if not no_progress:
                 sys.stderr.write("\r")
                 sys.stderr.flush()
@@ -155,6 +155,10 @@ def get_opts() -> argparse.Namespace:
         "--no-progress",
         action="store_true",
     )
+    parser.add_argument(
+        "--no-vad",
+        action="store_true",
+    )
     opts = parser.parse_args()
 
     if opts.beam_size <= 0:
@@ -187,6 +191,7 @@ def get_context(*, opts) -> Context:
         beam_size=opts.beam_size,
         temperatures=opts.temperature,
         allow_padding=opts.allow_padding,
+        vad=not opts.no_vad,
     )
     logger.debug(f"Context: {ctx}")
     return ctx
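Taken together, the two cli.py hunks above plumb the new flag straight into the pydantic Context. A minimal self-contained sketch of that pattern; only Context.vad and --no-vad come from this commit, the rest is illustrative scaffolding:

import argparse

from pydantic import BaseModel


class Context(BaseModel):
    # Trimmed to the field this commit adds; the real Context also carries
    # beam_size, temperatures, buffers, etc.
    vad: bool = True


def get_opts() -> argparse.Namespace:
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--no-vad",
        action="store_true",
    )
    return parser.parse_args()


if __name__ == "__main__":
    opts = get_opts()
    ctx = Context(vad=not opts.no_vad)  # VAD stays on unless --no-vad is given
    print(ctx)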
whispering/schema.py
@@ -2,6 +2,7 @@
 
 from typing import List, Optional
 
+import numpy as np
 import torch
 from pydantic import BaseModel, root_validator
 
@@ -26,6 +27,7 @@ class Context(BaseModel, arbitrary_types_allowed=True):
     timestamp: float = 0.0
     buffer_tokens: List[torch.Tensor] = []
     buffer_mel: Optional[torch.Tensor] = None
+    vad: bool = True
 
     temperatures: List[float]
     allow_padding: bool = False
@@ -50,3 +52,9 @@ class ParsedChunk(BaseModel):
     avg_logprob: float
     compression_ratio: float
     no_speech_prob: float
+
+
+class SpeechSegment(BaseModel, arbitrary_types_allowed=True):
+    start_block_idx: int
+    end_block_idx: int
+    audio: np.ndarray
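SpeechSegment is a plain pydantic container; arbitrary_types_allowed lets the raw numpy buffer ride along unvalidated. A minimal sketch of constructing one (the values are illustrative; 3000 is whisper's N_FRAMES):

import numpy as np
from pydantic import BaseModel


class SpeechSegment(BaseModel, arbitrary_types_allowed=True):
    start_block_idx: int
    end_block_idx: int
    audio: np.ndarray


seg = SpeechSegment(
    start_block_idx=0,
    end_block_idx=2,
    audio=np.zeros(2 * 3000, dtype=np.float32),  # two blocks of silence
)
print(seg.start_block_idx, seg.end_block_idx, seg.audio.shape)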
whispering/serve.py
@@ -21,7 +21,7 @@ async def serve_with_websocket_main(websocket):
     )
 
     while True:
-        logger.debug(f"Segment #: {idx}")
+        logger.debug(f"Audio #: {idx}")
         try:
             message = await websocket.recv()
         except ConnectionClosedOK:
@@ -32,9 +32,9 @@ async def serve_with_websocket_main(websocket):
             continue
 
         logger.debug(f"Message size: {len(message)}")
-        segment = np.frombuffer(message, dtype=np.float32)
+        audio = np.frombuffer(message, dtype=np.float32)
         for chunk in g_wsp.transcribe(
-            segment=segment,  # type: ignore
+            audio=audio,  # type: ignore
             ctx=ctx,
         ):
             await websocket.send(chunk.json())
whispering/transcriber.py
@@ -3,9 +3,11 @@
 from logging import getLogger
 from typing import Final, Iterator, Optional, Union
 
+import numpy as np
 import torch
 from whisper import Whisper, load_model
 from whisper.audio import (
+    CHUNK_LENGTH,
     HOP_LENGTH,
     N_FRAMES,
     SAMPLE_RATE,
@@ -17,6 +19,7 @@ from whisper.tokenizer import get_tokenizer
 from whisper.utils import exact_div
 
 from whispering.schema import Context, ParsedChunk, WhisperConfig
+from whispering.vad import VAD
 
 logger = getLogger(__name__)
 
@@ -50,6 +53,8 @@ class WhisperStreamingTranscriber:
         self.time_precision: Final[float] = (
             self.input_stride * HOP_LENGTH / SAMPLE_RATE
         )  # time per output token: 0.02 (seconds)
+        self.duration_pre_one_mel: Final[float] = CHUNK_LENGTH / HOP_LENGTH
+        self.vad = VAD()
 
     def _get_decoding_options(
         self,
@@ -224,10 +229,25 @@ class WhisperStreamingTranscriber:
     def transcribe(
         self,
         *,
-        segment: torch.Tensor,
+        audio: np.ndarray,
         ctx: Context,
     ) -> Iterator[ParsedChunk]:
-        new_mel = log_mel_spectrogram(audio=segment)
+        logger.debug(f"{len(audio)}")
+
+        if ctx.vad:
+            x = [
+                v
+                for v in self.vad(
+                    audio=audio,
+                    total_block_number=1,
+                )
+            ]
+            if len(x) == 0:  # No speech
+                logger.debug("No speech")
+                ctx.timestamp += len(audio) / N_FRAMES * self.duration_pre_one_mel
+                return
+
+        new_mel = log_mel_spectrogram(audio=audio)
         logger.debug(f"Incoming new_mel.shape: {new_mel.shape}")
         if ctx.buffer_mel is None:
             mel = new_mel
@@ -239,7 +259,7 @@ class WhisperStreamingTranscriber:
 
         seek: int = 0
         while seek < mel.shape[-1]:
-            segment = (
+            segment: torch.Tensor = (
                 pad_or_trim(mel[:, seek:], N_FRAMES)
                 .to(self.model.device)  # type: ignore
                 .to(self.dtype)
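One non-obvious detail in the new VAD gate: when a no-speech chunk is skipped, ctx.timestamp is advanced by len(audio) / N_FRAMES * duration_pre_one_mel. With whisper's published constants, that factoring reduces to len(audio) / SAMPLE_RATE, i.e. exactly the skipped clip's duration in seconds. A quick check (the constant values are whisper's; the reduction is ours):

# whisper.audio constants
CHUNK_LENGTH = 30    # seconds of audio per window
HOP_LENGTH = 160     # samples per mel frame
N_FRAMES = 3000      # mel frames per window
SAMPLE_RATE = 16000  # samples per second

duration_pre_one_mel = CHUNK_LENGTH / HOP_LENGTH  # as defined in __init__ above

# Per-sample timestamp increment equals 1 / SAMPLE_RATE:
assert abs((1 / N_FRAMES) * duration_pre_one_mel - 1 / SAMPLE_RATE) < 1e-15

n_samples = 48000  # e.g. a 3-second chunk at 16 kHz
print(n_samples / N_FRAMES * duration_pre_one_mel)  # 3.0 seconds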
whispering/vad.py (new file; 67 lines)
@@ -0,0 +1,67 @@
+#!/usr/bin/env python3
+
+from typing import Iterator, Optional
+
+import numpy as np
+import torch
+from whisper.audio import N_FRAMES, SAMPLE_RATE
+
+from whispering.schema import SpeechSegment
+
+
+class VAD:
+    def __init__(
+        self,
+    ):
+        self.vad_model, _ = torch.hub.load(
+            repo_or_dir="snakers4/silero-vad",
+            model="silero_vad",
+        )
+
+    def __call__(
+        self,
+        *,
+        audio: np.ndarray,
+        thredhold: float = 0.5,
+        total_block_number: Optional[int] = None,
+    ) -> Iterator[SpeechSegment]:
+        # audio.shape should be multiple of (N_FRAMES,)
+
+        def my_ret(
+            *,
+            start_block_idx: int,
+            idx: int,
+        ) -> SpeechSegment:
+            return SpeechSegment(
+                start_block_idx=start_block_idx,
+                end_block_idx=idx,
+                audio=audio[N_FRAMES * start_block_idx : N_FRAMES * idx],
+            )
+
+        if total_block_number is None:
+            total_block_number = int(audio.shape[0] / N_FRAMES)
+        block_unit: int = audio.shape[0] // total_block_number
+
+        start_block_idx = None
+        for idx in range(total_block_number):
+            start: int = block_unit * idx
+            end: int = block_unit * (idx + 1)
+            vad_prob = self.vad_model(
+                torch.from_numpy(audio[start:end]),
+                SAMPLE_RATE,
+            ).item()
+            if vad_prob > thredhold:
+                if start_block_idx is None:
+                    start_block_idx = idx
+            else:
+                if start_block_idx is not None:
+                    yield my_ret(
+                        start_block_idx=start_block_idx,
+                        idx=idx,
+                    )
+                    start_block_idx = None
+        if start_block_idx is not None:
+            yield my_ret(
+                start_block_idx=start_block_idx,
+                idx=total_block_number,
+            )
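A minimal usage sketch of the new class, assuming the silero-vad revision this commit was written against (the torch.hub model then accepted variable-length 16 kHz chunks; newer silero-vad releases expect fixed 512-sample windows, so pin the hub revision if reproducing). One second of silence, so no segments are expected:

import numpy as np

from whispering.vad import VAD

vad = VAD()  # fetches snakers4/silero-vad via torch.hub on first use
silence = np.zeros(16000, dtype=np.float32)  # 1 s of mono audio at SAMPLE_RATE

# total_block_number=1 scores the whole buffer as a single block,
# which is exactly how transcribe() calls it in this commit.
segments = list(vad(audio=silence, total_block_number=1))
print(segments)  # expected: []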