mirror of
https://github.com/shirayu/whispering.git
synced 2025-01-22 06:38:13 +00:00
Fix
This commit is contained in:
parent
45eb0bc34d
commit
936d5d0c45
4 changed files with 17 additions and 14 deletions
|
@ -48,16 +48,16 @@ def transcribe_from_mic(
|
||||||
):
|
):
|
||||||
idx: int = 0
|
idx: int = 0
|
||||||
while True:
|
while True:
|
||||||
logger.debug(f"Segment #: {idx}, The rest of queue: {q.qsize()}")
|
logger.debug(f"Audio #: {idx}, The rest of queue: {q.qsize()}")
|
||||||
|
|
||||||
if no_progress:
|
if no_progress:
|
||||||
segment = q.get()
|
audio = q.get()
|
||||||
else:
|
else:
|
||||||
pbar_thread = ProgressBar(
|
pbar_thread = ProgressBar(
|
||||||
num_block=num_block, # TODO: set more accurate value
|
num_block=num_block, # TODO: set more accurate value
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
segment = q.get()
|
audio = q.get()
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
pbar_thread.kill()
|
pbar_thread.kill()
|
||||||
return
|
return
|
||||||
|
@ -68,7 +68,7 @@ def transcribe_from_mic(
|
||||||
sys.stderr.write("Analyzing")
|
sys.stderr.write("Analyzing")
|
||||||
sys.stderr.flush()
|
sys.stderr.flush()
|
||||||
|
|
||||||
for chunk in wsp.transcribe(segment=segment, ctx=ctx):
|
for chunk in wsp.transcribe(audio=audio, ctx=ctx):
|
||||||
if not no_progress:
|
if not no_progress:
|
||||||
sys.stderr.write("\r")
|
sys.stderr.write("\r")
|
||||||
sys.stderr.flush()
|
sys.stderr.flush()
|
||||||
|
|
|
@ -2,6 +2,7 @@
|
||||||
|
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from pydantic import BaseModel, root_validator
|
from pydantic import BaseModel, root_validator
|
||||||
|
|
||||||
|
@ -55,4 +56,4 @@ class ParsedChunk(BaseModel):
|
||||||
class SpeechSegment(BaseModel, arbitrary_types_allowed=True):
|
class SpeechSegment(BaseModel, arbitrary_types_allowed=True):
|
||||||
start_block_idx: int
|
start_block_idx: int
|
||||||
end_block_idx: int
|
end_block_idx: int
|
||||||
segment: torch.Tensor
|
audio: np.ndarray
|
||||||
|
|
|
@ -3,6 +3,7 @@
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from typing import Final, Iterator, Optional, Union
|
from typing import Final, Iterator, Optional, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from whisper import Whisper, load_model
|
from whisper import Whisper, load_model
|
||||||
from whisper.audio import (
|
from whisper.audio import (
|
||||||
|
@ -226,13 +227,13 @@ class WhisperStreamingTranscriber:
|
||||||
def transcribe(
|
def transcribe(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
segment: torch.Tensor,
|
audio: np.ndarray,
|
||||||
ctx: Context,
|
ctx: Context,
|
||||||
) -> Iterator[ParsedChunk]:
|
) -> Iterator[ParsedChunk]:
|
||||||
for speech_segment in self.vad(segment=segment):
|
for speech_segment in self.vad(audio=audio):
|
||||||
logger.debug(f"{speech_segment}")
|
logger.debug(f"{speech_segment}")
|
||||||
|
|
||||||
new_mel = log_mel_spectrogram(audio=segment)
|
new_mel = log_mel_spectrogram(audio=audio)
|
||||||
logger.debug(f"Incoming new_mel.shape: {new_mel.shape}")
|
logger.debug(f"Incoming new_mel.shape: {new_mel.shape}")
|
||||||
if ctx.buffer_mel is None:
|
if ctx.buffer_mel is None:
|
||||||
mel = new_mel
|
mel = new_mel
|
||||||
|
@ -244,7 +245,7 @@ class WhisperStreamingTranscriber:
|
||||||
|
|
||||||
seek: int = 0
|
seek: int = 0
|
||||||
while seek < mel.shape[-1]:
|
while seek < mel.shape[-1]:
|
||||||
segment = (
|
segment: torch.Tensor = (
|
||||||
pad_or_trim(mel[:, seek:], N_FRAMES)
|
pad_or_trim(mel[:, seek:], N_FRAMES)
|
||||||
.to(self.model.device) # type: ignore
|
.to(self.model.device) # type: ignore
|
||||||
.to(self.dtype)
|
.to(self.dtype)
|
||||||
|
|
|
@ -2,6 +2,7 @@
|
||||||
|
|
||||||
from typing import Iterator
|
from typing import Iterator
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from whisper.audio import N_FRAMES, SAMPLE_RATE
|
from whisper.audio import N_FRAMES, SAMPLE_RATE
|
||||||
|
|
||||||
|
@ -20,10 +21,10 @@ class VAD:
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
segment: torch.Tensor,
|
audio: np.ndarray,
|
||||||
thredhold: float = 0.5,
|
thredhold: float = 0.5,
|
||||||
) -> Iterator[SpeechSegment]:
|
) -> Iterator[SpeechSegment]:
|
||||||
# segment.shape should be multiple of (N_FRAMES,)
|
# audio.shape should be multiple of (N_FRAMES,)
|
||||||
|
|
||||||
def my_ret(
|
def my_ret(
|
||||||
*,
|
*,
|
||||||
|
@ -33,17 +34,17 @@ class VAD:
|
||||||
return SpeechSegment(
|
return SpeechSegment(
|
||||||
start_block_idx=start_block_idx,
|
start_block_idx=start_block_idx,
|
||||||
end_block_idx=idx,
|
end_block_idx=idx,
|
||||||
segment=segment[N_FRAMES * start_block_idx : N_FRAMES * idx],
|
audio=audio[N_FRAMES * start_block_idx : N_FRAMES * idx],
|
||||||
)
|
)
|
||||||
|
|
||||||
block_size: int = int(segment.shape[0] / N_FRAMES)
|
block_size: int = int(audio.shape[0] / N_FRAMES)
|
||||||
|
|
||||||
start_block_idx = None
|
start_block_idx = None
|
||||||
for idx in range(block_size):
|
for idx in range(block_size):
|
||||||
start: int = N_FRAMES * idx
|
start: int = N_FRAMES * idx
|
||||||
end: int = N_FRAMES * (idx + 1)
|
end: int = N_FRAMES * (idx + 1)
|
||||||
vad_prob = self.vad_model(
|
vad_prob = self.vad_model(
|
||||||
torch.from_numpy(segment[start:end]),
|
torch.from_numpy(audio[start:end]),
|
||||||
SAMPLE_RATE,
|
SAMPLE_RATE,
|
||||||
).item()
|
).item()
|
||||||
if vad_prob > thredhold:
|
if vad_prob > thredhold:
|
||||||
|
|
Loading…
Reference in a new issue