Yuta Hayashibe 2022-10-02 19:47:17 +09:00
parent 45eb0bc34d
commit 936d5d0c45
4 changed files with 17 additions and 14 deletions

View file

@@ -48,16 +48,16 @@ def transcribe_from_mic(
     ):
         idx: int = 0
         while True:
-            logger.debug(f"Segment #: {idx}, The rest of queue: {q.qsize()}")
+            logger.debug(f"Audio #: {idx}, The rest of queue: {q.qsize()}")
             if no_progress:
-                segment = q.get()
+                audio = q.get()
             else:
                 pbar_thread = ProgressBar(
                     num_block=num_block,  # TODO: set more accurate value
                 )
                 try:
-                    segment = q.get()
+                    audio = q.get()
                 except KeyboardInterrupt:
                     pbar_thread.kill()
                     return
@@ -68,7 +68,7 @@ def transcribe_from_mic(
                 sys.stderr.write("Analyzing")
                 sys.stderr.flush()
-            for chunk in wsp.transcribe(segment=segment, ctx=ctx):
+            for chunk in wsp.transcribe(audio=audio, ctx=ctx):
                 if not no_progress:
                     sys.stderr.write("\r")
                     sys.stderr.flush()
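For orientation, this loop consumes raw microphone blocks from a queue and, after this change, keeps them as NumPy arrays end to end. A minimal sketch of that producer/consumer pattern, assuming a `sounddevice` input stream; the constants and the `sd_callback` name are illustrative, not necessarily what the repository uses:

```python
import queue

import numpy as np
import sounddevice as sd

q: queue.Queue = queue.Queue()

def sd_callback(indata, frames, time, status) -> None:
    # indata is a float32 ndarray of shape (blocksize, channels);
    # flatten it so the consumer receives a plain 1-D waveform block
    q.put(indata.ravel())

with sd.InputStream(samplerate=16000, blocksize=16000, channels=1,
                    dtype="float32", callback=sd_callback):
    while True:
        audio: np.ndarray = q.get()  # np.ndarray after this commit, no torch.Tensor involved
        # ... hand `audio` to the transcriber here
```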

View file

@@ -2,6 +2,7 @@
 from typing import List, Optional
 
+import numpy as np
 import torch
 from pydantic import BaseModel, root_validator
@@ -55,4 +56,4 @@ class ParsedChunk(BaseModel):
 class SpeechSegment(BaseModel, arbitrary_types_allowed=True):
     start_block_idx: int
     end_block_idx: int
-    segment: torch.Tensor
+    audio: np.ndarray
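Because `SpeechSegment` now stores a NumPy array instead of a `torch.Tensor`, pydantic only needs the `arbitrary_types_allowed=True` the class already declares to accept the field. A small sketch of constructing it with the new field type (standalone, so the class is repeated rather than imported; the sample values are purely illustrative):

```python
import numpy as np
from pydantic import BaseModel


class SpeechSegment(BaseModel, arbitrary_types_allowed=True):
    start_block_idx: int
    end_block_idx: int
    audio: np.ndarray  # raw waveform samples; was torch.Tensor before this commit


# one block of silence at 16 kHz, just to show the field accepts ndarrays directly
seg = SpeechSegment(start_block_idx=0, end_block_idx=1,
                    audio=np.zeros(16000, dtype=np.float32))
print(type(seg.audio), seg.audio.shape)
```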

View file

@@ -3,6 +3,7 @@
 from logging import getLogger
 from typing import Final, Iterator, Optional, Union
 
+import numpy as np
 import torch
 from whisper import Whisper, load_model
 from whisper.audio import (
@@ -226,13 +227,13 @@ class WhisperStreamingTranscriber:
     def transcribe(
         self,
         *,
-        segment: torch.Tensor,
+        audio: np.ndarray,
         ctx: Context,
     ) -> Iterator[ParsedChunk]:
-        for speech_segment in self.vad(segment=segment):
+        for speech_segment in self.vad(audio=audio):
             logger.debug(f"{speech_segment}")
-            new_mel = log_mel_spectrogram(audio=segment)
+            new_mel = log_mel_spectrogram(audio=audio)
             logger.debug(f"Incoming new_mel.shape: {new_mel.shape}")
             if ctx.buffer_mel is None:
                 mel = new_mel
@@ -244,7 +245,7 @@ class WhisperStreamingTranscriber:
             seek: int = 0
             while seek < mel.shape[-1]:
-                segment = (
+                segment: torch.Tensor = (
                     pad_or_trim(mel[:, seek:], N_FRAMES)
                     .to(self.model.device)  # type: ignore
                     .to(self.dtype)
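Callers now pass NumPy audio under the renamed `audio=` keyword, while `segment` survives only as the internal mel window, now explicitly typed as `torch.Tensor`. A hedged sketch of a caller; construction of `WhisperStreamingTranscriber` and `Context` is left out since their constructors are not shown in this diff, and the `ParsedChunk` fields are assumed from the CLI usage above:

```python
import numpy as np


def consume(wsp, ctx, audio: np.ndarray) -> None:
    # audio: 1-D float32 waveform; the keyword changed from segment= to audio=
    for chunk in wsp.transcribe(audio=audio, ctx=ctx):
        # assumed ParsedChunk fields: start, end, text
        print(f"{chunk.start:.2f} -> {chunk.end:.2f}\t{chunk.text}")
```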

View file

@@ -2,6 +2,7 @@
 from typing import Iterator
 
+import numpy as np
 import torch
 from whisper.audio import N_FRAMES, SAMPLE_RATE
@@ -20,10 +21,10 @@ class VAD:
     def __call__(
         self,
         *,
-        segment: torch.Tensor,
+        audio: np.ndarray,
         thredhold: float = 0.5,
     ) -> Iterator[SpeechSegment]:
-        # segment.shape should be multiple of (N_FRAMES,)
+        # audio.shape should be multiple of (N_FRAMES,)
 
         def my_ret(
             *,
@@ -33,17 +34,17 @@ class VAD:
             return SpeechSegment(
                 start_block_idx=start_block_idx,
                 end_block_idx=idx,
-                segment=segment[N_FRAMES * start_block_idx : N_FRAMES * idx],
+                audio=audio[N_FRAMES * start_block_idx : N_FRAMES * idx],
             )
 
-        block_size: int = int(segment.shape[0] / N_FRAMES)
+        block_size: int = int(audio.shape[0] / N_FRAMES)
         start_block_idx = None
         for idx in range(block_size):
             start: int = N_FRAMES * idx
             end: int = N_FRAMES * (idx + 1)
             vad_prob = self.vad_model(
-                torch.from_numpy(segment[start:end]),
+                torch.from_numpy(audio[start:end]),
                 SAMPLE_RATE,
             ).item()
             if vad_prob > thredhold:
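The renamed variables can obscure what `VAD.__call__` actually does: slice the waveform into `N_FRAMES`-sized blocks, score each block, and yield one `SpeechSegment` per contiguous run above the threshold. A standalone sketch of that grouping idea with a stubbed probability function; the names, the hard-coded `N_FRAMES` value, and the run-termination details are illustrative, not the repository's exact behavior:

```python
from typing import Callable, Iterator, Tuple

import numpy as np

N_FRAMES = 3000  # whisper.audio.N_FRAMES; hard-coded so the sketch stays standalone


def speech_runs(
    audio: np.ndarray,
    speech_prob: Callable[[np.ndarray], float],
    threshold: float = 0.5,
) -> Iterator[Tuple[int, int, np.ndarray]]:
    """Yield (start_block_idx, end_block_idx, audio_slice) for runs of speech-like blocks."""
    block_size = audio.shape[0] // N_FRAMES
    start_block_idx = None
    for idx in range(block_size):
        block = audio[N_FRAMES * idx : N_FRAMES * (idx + 1)]
        if speech_prob(block) > threshold:
            if start_block_idx is None:
                start_block_idx = idx  # a speech run begins here
        elif start_block_idx is not None:
            # the run ended with the previous block
            yield start_block_idx, idx, audio[N_FRAMES * start_block_idx : N_FRAMES * idx]
            start_block_idx = None
    if start_block_idx is not None:  # run still open at the end of the buffer
        yield start_block_idx, block_size, audio[N_FRAMES * start_block_idx :]


# toy check: treat high-energy blocks as speech
audio = np.concatenate(
    [np.zeros(N_FRAMES), np.ones(2 * N_FRAMES), np.zeros(N_FRAMES)]
).astype(np.float32)
for start, end, chunk in speech_runs(audio, speech_prob=lambda b: float(np.abs(b).mean())):
    print(start, end, chunk.shape)  # -> 1 3 (6000,)
```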