From 936d5d0c45d27bce3cf7e3e3c0fbbeb781cd9d13 Mon Sep 17 00:00:00 2001
From: Yuta Hayashibe
Date: Sun, 2 Oct 2022 19:47:17 +0900
Subject: [PATCH] Fix

---
 whispering/cli.py         |  8 ++++----
 whispering/schema.py      |  3 ++-
 whispering/transcriber.py |  9 +++++----
 whispering/vad.py         | 11 ++++++-----
 4 files changed, 17 insertions(+), 14 deletions(-)

diff --git a/whispering/cli.py b/whispering/cli.py
index 8b60271..fcbf9e4 100644
--- a/whispering/cli.py
+++ b/whispering/cli.py
@@ -48,16 +48,16 @@ def transcribe_from_mic(
     ):
         idx: int = 0
         while True:
-            logger.debug(f"Segment #: {idx}, The rest of queue: {q.qsize()}")
+            logger.debug(f"Audio #: {idx}, The rest of queue: {q.qsize()}")
             if no_progress:
-                segment = q.get()
+                audio = q.get()
             else:
                 pbar_thread = ProgressBar(
                     num_block=num_block,  # TODO: set more accurate value
                 )
                 try:
-                    segment = q.get()
+                    audio = q.get()
                 except KeyboardInterrupt:
                     pbar_thread.kill()
                     return
@@ -68,7 +68,7 @@ def transcribe_from_mic(
                 sys.stderr.write("Analyzing")
                 sys.stderr.flush()
 
-            for chunk in wsp.transcribe(segment=segment, ctx=ctx):
+            for chunk in wsp.transcribe(audio=audio, ctx=ctx):
                 if not no_progress:
                     sys.stderr.write("\r")
                     sys.stderr.flush()
diff --git a/whispering/schema.py b/whispering/schema.py
index 4d7f9af..d31d4df 100644
--- a/whispering/schema.py
+++ b/whispering/schema.py
@@ -2,6 +2,7 @@
 
 from typing import List, Optional
 
+import numpy as np
 import torch
 from pydantic import BaseModel, root_validator
 
@@ -55,4 +56,4 @@ class ParsedChunk(BaseModel):
 class SpeechSegment(BaseModel, arbitrary_types_allowed=True):
     start_block_idx: int
     end_block_idx: int
-    segment: torch.Tensor
+    audio: np.ndarray
diff --git a/whispering/transcriber.py b/whispering/transcriber.py
index ac5d3b9..107d30a 100644
--- a/whispering/transcriber.py
+++ b/whispering/transcriber.py
@@ -3,6 +3,7 @@
 from logging import getLogger
 from typing import Final, Iterator, Optional, Union
 
+import numpy as np
 import torch
 from whisper import Whisper, load_model
 from whisper.audio import (
@@ -226,13 +227,13 @@ class WhisperStreamingTranscriber:
     def transcribe(
         self,
         *,
-        segment: torch.Tensor,
+        audio: np.ndarray,
         ctx: Context,
     ) -> Iterator[ParsedChunk]:
-        for speech_segment in self.vad(segment=segment):
+        for speech_segment in self.vad(audio=audio):
             logger.debug(f"{speech_segment}")
 
-        new_mel = log_mel_spectrogram(audio=segment)
+        new_mel = log_mel_spectrogram(audio=audio)
         logger.debug(f"Incoming new_mel.shape: {new_mel.shape}")
         if ctx.buffer_mel is None:
             mel = new_mel
@@ -244,7 +245,7 @@ class WhisperStreamingTranscriber:
 
         seek: int = 0
         while seek < mel.shape[-1]:
-            segment = (
+            segment: torch.Tensor = (
                 pad_or_trim(mel[:, seek:], N_FRAMES)
                 .to(self.model.device)  # type: ignore
                 .to(self.dtype)
diff --git a/whispering/vad.py b/whispering/vad.py
index 8d992de..f740b66 100644
--- a/whispering/vad.py
+++ b/whispering/vad.py
@@ -2,6 +2,7 @@
 
 from typing import Iterator
 
+import numpy as np
 import torch
 from whisper.audio import N_FRAMES, SAMPLE_RATE
 
@@ -20,10 +21,10 @@ class VAD:
     def __call__(
         self,
         *,
-        segment: torch.Tensor,
+        audio: np.ndarray,
         thredhold: float = 0.5,
     ) -> Iterator[SpeechSegment]:
-        # segment.shape should be multiple of (N_FRAMES,)
+        # audio.shape should be multiple of (N_FRAMES,)
 
         def my_ret(
             *,
@@ -33,17 +34,17 @@ class VAD:
             return SpeechSegment(
                 start_block_idx=start_block_idx,
                 end_block_idx=idx,
-                segment=segment[N_FRAMES * start_block_idx : N_FRAMES * idx],
+                audio=audio[N_FRAMES * start_block_idx : N_FRAMES * idx],
             )
 
-        block_size: int = int(segment.shape[0] / N_FRAMES)
+        block_size: int = int(audio.shape[0] / N_FRAMES)
         start_block_idx = None
         for idx in range(block_size):
             start: int = N_FRAMES * idx
             end: int = N_FRAMES * (idx + 1)
             vad_prob = self.vad_model(
-                torch.from_numpy(segment[start:end]),
+                torch.from_numpy(audio[start:end]),
                 SAMPLE_RATE,
             ).item()
             if vad_prob > thredhold:
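
For reference, a minimal sketch of the `SpeechSegment` schema after this patch: the field is now `audio` and holds a NumPy array instead of a `torch.Tensor`, matching what the microphone callback actually produces. The zero-filled buffer below is a stand-in for real capture, not part of the patch.

```python
import numpy as np
from whisper.audio import N_FRAMES

from whispering.schema import SpeechSegment

# Stand-in for microphone capture; the mic queue already yields float32
# NumPy data, which is why the schema now declares np.ndarray.
audio = np.zeros(N_FRAMES * 2, dtype=np.float32)

# The renamed field: `audio` (np.ndarray), formerly `segment` (torch.Tensor).
seg = SpeechSegment(start_block_idx=0, end_block_idx=2, audio=audio)
print(seg.audio.shape)  # (6000,)
```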
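And a self-contained sketch of the per-block loop that `VAD.__call__` runs over the renamed `audio` argument: the input is cut into fixed blocks of `N_FRAMES` samples and each block gets one speech probability. Here a dummy scorer stands in for the real `self.vad_model`; `block_probs` and `score` are illustrative names, not part of the patch.

```python
from typing import Callable, List

import numpy as np
from whisper.audio import N_FRAMES


def block_probs(
    audio: np.ndarray,
    score: Callable[[np.ndarray], float],
) -> List[float]:
    # audio.shape should be a multiple of (N_FRAMES,), as the patched
    # comment notes; each N_FRAMES-sample block yields one probability.
    block_size = audio.shape[0] // N_FRAMES
    return [
        score(audio[N_FRAMES * i : N_FRAMES * (i + 1)])
        for i in range(block_size)
    ]


# Dummy scorer: mean absolute amplitude instead of a real VAD probability.
probs = block_probs(
    np.zeros(N_FRAMES * 3, dtype=np.float32),
    lambda block: float(np.abs(block).mean()),
)
print(probs)  # [0.0, 0.0, 0.0]
```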