Yuta Hayashibe 2022-10-02 19:47:17 +09:00
parent 45eb0bc34d
commit 936d5d0c45
4 changed files with 17 additions and 14 deletions

View file

@@ -48,16 +48,16 @@ def transcribe_from_mic(
 ):
     idx: int = 0
     while True:
-        logger.debug(f"Segment #: {idx}, The rest of queue: {q.qsize()}")
+        logger.debug(f"Audio #: {idx}, The rest of queue: {q.qsize()}")
         if no_progress:
-            segment = q.get()
+            audio = q.get()
         else:
             pbar_thread = ProgressBar(
                 num_block=num_block,  # TODO: set more accurate value
             )
             try:
-                segment = q.get()
+                audio = q.get()
             except KeyboardInterrupt:
                 pbar_thread.kill()
                 return
@@ -68,7 +68,7 @@ def transcribe_from_mic(
             sys.stderr.write("Analyzing")
             sys.stderr.flush()
-        for chunk in wsp.transcribe(segment=segment, ctx=ctx):
+        for chunk in wsp.transcribe(audio=audio, ctx=ctx):
             if not no_progress:
                 sys.stderr.write("\r")
                 sys.stderr.flush()
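
For readers following the rename: the loop above now hands a raw NumPy block straight to the transcriber. A minimal consumer sketch under that assumption, where the queue filling and the printing are made up for illustration and only wsp.transcribe(audio=..., ctx=...) comes from this diff:

import queue

import numpy as np

def consume(q: "queue.Queue[np.ndarray]", wsp, ctx) -> None:
    idx = 0
    while True:
        audio = q.get()  # one float32 block pushed by the recording callback
        for chunk in wsp.transcribe(audio=audio, ctx=ctx):
            # each ParsedChunk carries the recognized text for this stretch of audio
            print(chunk.text)
        idx += 1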

View file

@@ -2,6 +2,7 @@
 from typing import List, Optional
+import numpy as np
 import torch
 from pydantic import BaseModel, root_validator
@@ -55,4 +56,4 @@ class ParsedChunk(BaseModel):
 class SpeechSegment(BaseModel, arbitrary_types_allowed=True):
     start_block_idx: int
     end_block_idx: int
-    segment: torch.Tensor
+    audio: np.ndarray
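
SpeechSegment now stores the raw waveform as a NumPy array; arbitrary_types_allowed=True is what lets pydantic accept a type it has no built-in validator for. A small construction sketch, assuming SpeechSegment is imported from the module above (the sample sizes here are made up):

import numpy as np

# from the schema module shown above import SpeechSegment  # path not shown in this view

seg = SpeechSegment(
    start_block_idx=0,
    end_block_idx=2,
    audio=np.zeros(2 * 1600, dtype=np.float32),  # any 1-D float block works
)
assert isinstance(seg.audio, np.ndarray)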

View file

@@ -3,6 +3,7 @@
 from logging import getLogger
 from typing import Final, Iterator, Optional, Union
+import numpy as np
 import torch
 from whisper import Whisper, load_model
 from whisper.audio import (
@@ -226,13 +227,13 @@ class WhisperStreamingTranscriber:
     def transcribe(
         self,
         *,
-        segment: torch.Tensor,
+        audio: np.ndarray,
         ctx: Context,
     ) -> Iterator[ParsedChunk]:
-        for speech_segment in self.vad(segment=segment):
+        for speech_segment in self.vad(audio=audio):
             logger.debug(f"{speech_segment}")
-            new_mel = log_mel_spectrogram(audio=segment)
+            new_mel = log_mel_spectrogram(audio=audio)
             logger.debug(f"Incoming new_mel.shape: {new_mel.shape}")
             if ctx.buffer_mel is None:
                 mel = new_mel
@@ -244,7 +245,7 @@ class WhisperStreamingTranscriber:
             seek: int = 0
             while seek < mel.shape[-1]:
-                segment = (
+                segment: torch.Tensor = (
                     pad_or_trim(mel[:, seek:], N_FRAMES)
                     .to(self.model.device)  # type: ignore
                     .to(self.dtype)
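
With the public argument now a NumPy array, the local name segment is reserved for the padded mel window that goes to the model. A rough sketch of that windowing step in isolation, using whisper's own helpers; the real method advances seek from the decode results and handles device/dtype moves, which are omitted here:

import numpy as np
from whisper.audio import N_FRAMES, log_mel_spectrogram, pad_or_trim

def mel_windows(audio: np.ndarray):
    # log_mel_spectrogram accepts a NumPy array directly, so no torch
    # conversion is needed at the call site.
    mel = log_mel_spectrogram(audio=audio)
    seek = 0
    while seek < mel.shape[-1]:
        # each window is padded or trimmed to exactly N_FRAMES mel frames
        yield pad_or_trim(mel[:, seek:], N_FRAMES)
        seek += N_FRAMES  # simplified; the real loop steps by the decoded length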

View file

@@ -2,6 +2,7 @@
 from typing import Iterator
+import numpy as np
 import torch
 from whisper.audio import N_FRAMES, SAMPLE_RATE
@@ -20,10 +21,10 @@ class VAD:
     def __call__(
         self,
         *,
-        segment: torch.Tensor,
+        audio: np.ndarray,
         thredhold: float = 0.5,
     ) -> Iterator[SpeechSegment]:
-        # segment.shape should be multiple of (N_FRAMES,)
+        # audio.shape should be multiple of (N_FRAMES,)

         def my_ret(
             *,
@ -33,17 +34,17 @@ class VAD:
return SpeechSegment(
start_block_idx=start_block_idx,
end_block_idx=idx,
segment=segment[N_FRAMES * start_block_idx : N_FRAMES * idx],
audio=audio[N_FRAMES * start_block_idx : N_FRAMES * idx],
)
block_size: int = int(segment.shape[0] / N_FRAMES)
block_size: int = int(audio.shape[0] / N_FRAMES)
start_block_idx = None
for idx in range(block_size):
start: int = N_FRAMES * idx
end: int = N_FRAMES * (idx + 1)
vad_prob = self.vad_model(
torch.from_numpy(segment[start:end]),
torch.from_numpy(audio[start:end]),
SAMPLE_RATE,
).item()
if vad_prob > thredhold:
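
The VAD keeps the audio as NumPy and converts each fixed-size block to a tensor only for the model call. A condensed sketch of that block-wise gating, with vad_model standing in for the Silero-style model the class holds (it must return a speech probability for a block at the given sample rate); the thredhold spelling follows the code above:

from typing import Iterator, Tuple

import numpy as np
import torch

def speech_blocks(
    audio: np.ndarray,
    vad_model,
    n_frames: int,
    sample_rate: int,
    thredhold: float = 0.5,
) -> Iterator[Tuple[int, int, np.ndarray]]:
    # Walk the audio in n_frames-sized blocks and merge consecutive
    # blocks whose speech probability exceeds the threshold.
    block_size = audio.shape[0] // n_frames
    start = None
    for idx in range(block_size):
        block = audio[n_frames * idx : n_frames * (idx + 1)]
        prob = vad_model(torch.from_numpy(block), sample_rate).item()
        if prob > thredhold:
            if start is None:
                start = idx
        elif start is not None:
            yield start, idx, audio[n_frames * start : n_frames * idx]
            start = None
    if start is not None:
        yield start, block_size, audio[n_frames * start : n_frames * block_size]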