Mirror of https://github.com/shirayu/whispering.git (synced 2025-01-22 06:38:13 +00:00)
Fix

parent 45eb0bc34d
commit 936d5d0c45
4 changed files with 17 additions and 14 deletions
@@ -48,16 +48,16 @@ def transcribe_from_mic(
 ):
     idx: int = 0
     while True:
-        logger.debug(f"Segment #: {idx}, The rest of queue: {q.qsize()}")
+        logger.debug(f"Audio #: {idx}, The rest of queue: {q.qsize()}")
 
         if no_progress:
-            segment = q.get()
+            audio = q.get()
         else:
             pbar_thread = ProgressBar(
                 num_block=num_block,  # TODO: set more accurate value
             )
             try:
-                segment = q.get()
+                audio = q.get()
             except KeyboardInterrupt:
                 pbar_thread.kill()
                 return
@@ -68,7 +68,7 @@ def transcribe_from_mic(
         sys.stderr.write("Analyzing")
         sys.stderr.flush()
 
-        for chunk in wsp.transcribe(segment=segment, ctx=ctx):
+        for chunk in wsp.transcribe(audio=audio, ctx=ctx):
             if not no_progress:
                 sys.stderr.write("\r")
                 sys.stderr.flush()
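For context, the loop above drains a queue of raw audio blocks; after this commit each item is a NumPy array rather than a tensor. A minimal stand-alone sketch of that consumer pattern follows, where the queue setup and block length are illustrative assumptions, not code from this repository:

import queue

import numpy as np

# Hypothetical producer/consumer mirroring transcribe_from_mic's queue:
# each item is one block of float32 samples from the microphone callback.
q: "queue.Queue[np.ndarray]" = queue.Queue()
q.put(np.zeros(480000, dtype=np.float32))  # assumed block length

audio = q.get()  # the renamed variable from the diff above
print(audio.shape, audio.dtype)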
@@ -2,6 +2,7 @@
 
 from typing import List, Optional
 
+import numpy as np
 import torch
 from pydantic import BaseModel, root_validator
 
@@ -55,4 +56,4 @@ class ParsedChunk(BaseModel):
 class SpeechSegment(BaseModel, arbitrary_types_allowed=True):
     start_block_idx: int
     end_block_idx: int
-    segment: torch.Tensor
+    audio: np.ndarray
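SpeechSegment now carries the raw NumPy buffer; arbitrary_types_allowed=True is what lets pydantic accept an np.ndarray field. A minimal construction sketch, assuming the model lives in whispering.schema and that a one-block buffer of zeros is acceptable:

import numpy as np

from whispering.schema import SpeechSegment  # assumed module path

seg = SpeechSegment(
    start_block_idx=0,
    end_block_idx=1,
    audio=np.zeros(16000, dtype=np.float32),  # hypothetical buffer length
)
print(seg.start_block_idx, seg.end_block_idx, seg.audio.shape)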
@@ -3,6 +3,7 @@
 from logging import getLogger
 from typing import Final, Iterator, Optional, Union
 
+import numpy as np
 import torch
 from whisper import Whisper, load_model
 from whisper.audio import (
@@ -226,13 +227,13 @@ class WhisperStreamingTranscriber:
     def transcribe(
         self,
         *,
-        segment: torch.Tensor,
+        audio: np.ndarray,
         ctx: Context,
     ) -> Iterator[ParsedChunk]:
-        for speech_segment in self.vad(segment=segment):
+        for speech_segment in self.vad(audio=audio):
             logger.debug(f"{speech_segment}")
 
-            new_mel = log_mel_spectrogram(audio=segment)
+            new_mel = log_mel_spectrogram(audio=audio)
             logger.debug(f"Incoming new_mel.shape: {new_mel.shape}")
             if ctx.buffer_mel is None:
                 mel = new_mel
@@ -244,7 +245,7 @@ class WhisperStreamingTranscriber:
 
         seek: int = 0
         while seek < mel.shape[-1]:
-            segment = (
+            segment: torch.Tensor = (
                 pad_or_trim(mel[:, seek:], N_FRAMES)
                 .to(self.model.device)  # type: ignore
                 .to(self.dtype)
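After this change, callers hand transcribe() a raw NumPy waveform instead of a torch.Tensor; the tensor conversion happens internally (per the VAD hunk below). A sketch of driving the renamed API, where only the transcribe(audio=..., ctx=...) call shape comes from this diff and the import paths are assumptions:

import numpy as np

from whispering.schema import Context  # assumed module path
from whispering.transcriber import WhisperStreamingTranscriber  # assumed module path

def drain(wsp: WhisperStreamingTranscriber, ctx: Context) -> None:
    # One block of silence as a stand-in for microphone input.
    audio = np.zeros(480000, dtype=np.float32)  # hypothetical block length
    for chunk in wsp.transcribe(audio=audio, ctx=ctx):  # call shape from this diff
        print(chunk)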
@@ -2,6 +2,7 @@
 
 from typing import Iterator
 
+import numpy as np
 import torch
 from whisper.audio import N_FRAMES, SAMPLE_RATE
 
@@ -20,10 +21,10 @@ class VAD:
     def __call__(
         self,
         *,
-        segment: torch.Tensor,
+        audio: np.ndarray,
         thredhold: float = 0.5,
     ) -> Iterator[SpeechSegment]:
-        # segment.shape should be multiple of (N_FRAMES,)
+        # audio.shape should be multiple of (N_FRAMES,)
 
         def my_ret(
             *,
@@ -33,17 +34,17 @@ class VAD:
             return SpeechSegment(
                 start_block_idx=start_block_idx,
                 end_block_idx=idx,
-                segment=segment[N_FRAMES * start_block_idx : N_FRAMES * idx],
+                audio=audio[N_FRAMES * start_block_idx : N_FRAMES * idx],
             )
 
-        block_size: int = int(segment.shape[0] / N_FRAMES)
+        block_size: int = int(audio.shape[0] / N_FRAMES)
 
         start_block_idx = None
         for idx in range(block_size):
             start: int = N_FRAMES * idx
             end: int = N_FRAMES * (idx + 1)
             vad_prob = self.vad_model(
-                torch.from_numpy(segment[start:end]),
+                torch.from_numpy(audio[start:end]),
                 SAMPLE_RATE,
             ).item()
             if vad_prob > thredhold:
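The VAD now slices the NumPy buffer per block and converts only each slice via torch.from_numpy for the model call, instead of receiving a tensor up front. A usage sketch under stated assumptions: vad is an already-constructed VAD instance, the import path is a guess, and thredhold keeps the source's own spelling:

import numpy as np
from whisper.audio import N_FRAMES

from whispering.vad import VAD  # assumed module path

def find_speech(vad: VAD, audio: np.ndarray) -> None:
    # Per the comment in the diff, audio.shape should be a multiple of (N_FRAMES,).
    assert audio.shape[0] % N_FRAMES == 0
    for speech_segment in vad(audio=audio, thredhold=0.5):  # parameter spelling as in the source
        print(speech_segment.start_block_idx, speech_segment.end_block_idx)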