commit 08798f117a
parent 847eee5819
Author: Yuta Hayashibe
Date:   2022-10-02 20:30:51 +09:00

2 changed files with 23 additions and 8 deletions

@@ -7,6 +7,7 @@ import numpy as np
 import torch
 from whisper import Whisper, load_model
 from whisper.audio import (
+    CHUNK_LENGTH,
     HOP_LENGTH,
     N_FRAMES,
     SAMPLE_RATE,
@ -52,6 +53,7 @@ class WhisperStreamingTranscriber:
self.time_precision: Final[float] = ( self.time_precision: Final[float] = (
self.input_stride * HOP_LENGTH / SAMPLE_RATE self.input_stride * HOP_LENGTH / SAMPLE_RATE
) # time per output token: 0.02 (seconds) ) # time per output token: 0.02 (seconds)
self.duration_pre_one_mel: Final[float] = CHUNK_LENGTH / HOP_LENGTH
self.vad = VAD() self.vad = VAD()
def _get_decoding_options( def _get_decoding_options(
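Note: a quick sanity check of the arithmetic in this hunk, assuming the audio constants from whisper.audio at the time (HOP_LENGTH = 160, SAMPLE_RATE = 16000, CHUNK_LENGTH = 30, N_FRAMES = 3000) and an input_stride of 2 for the stock models; this snippet is illustrative and not part of the commit:

    HOP_LENGTH = 160      # samples per mel frame (whisper.audio)
    SAMPLE_RATE = 16000   # Hz (whisper.audio)
    CHUNK_LENGTH = 30     # seconds per chunk (whisper.audio)
    N_FRAMES = 3000       # mel frames per 30-second chunk (whisper.audio)
    input_stride = 2      # N_FRAMES // n_audio_ctx for the standard models

    time_precision = input_stride * HOP_LENGTH / SAMPLE_RATE
    assert time_precision == 0.02  # seconds per output token, as the comment says

    duration_pre_one_mel = CHUNK_LENGTH / HOP_LENGTH  # 30 / 160 = 0.1875
    # Used below as len(audio) / N_FRAMES * duration_pre_one_mel; for a full
    # 30-second buffer of 480000 samples: 480000 / 3000 * 0.1875 == 30.0 s.
    assert 480_000 / N_FRAMES * duration_pre_one_mel == 30.0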
@@ -230,8 +232,18 @@ class WhisperStreamingTranscriber:
         audio: np.ndarray,
         ctx: Context,
     ) -> Iterator[ParsedChunk]:
-        for speech_segment in self.vad(audio=audio):
-            logger.debug(f"{speech_segment}")
+        logger.debug(f"{len(audio)}")
+        x = [
+            v
+            for v in self.vad(
+                audio=audio,
+                total_block_number=1,
+            )
+        ]
+        if len(x) == 0:  # No speech
+            logger.debug("No speech")
+            ctx.timestamp += len(audio) / N_FRAMES * self.duration_pre_one_mel
+            return

         new_mel = log_mel_spectrogram(audio=audio)
         logger.debug(f"Incoming new_mel.shape: {new_mel.shape}")

@@ -1,6 +1,6 @@
 #!/usr/bin/env python3
-from typing import Iterator
+from typing import Iterator, Optional

 import numpy as np
 import torch
@@ -23,6 +23,7 @@ class VAD:
         *,
         audio: np.ndarray,
         thredhold: float = 0.5,
+        total_block_number: Optional[int] = None,
     ) -> Iterator[SpeechSegment]:
         # audio.shape should be multiple of (N_FRAMES,)
@@ -37,12 +38,14 @@ class VAD:
                 audio=audio[N_FRAMES * start_block_idx : N_FRAMES * idx],
             )

-        block_size: int = int(audio.shape[0] / N_FRAMES)
+        if total_block_number is None:
+            total_block_number = int(audio.shape[0] / N_FRAMES)
+        block_unit: int = audio.shape[0] // total_block_number
         start_block_idx = None
-        for idx in range(block_size):
-            start: int = N_FRAMES * idx
-            end: int = N_FRAMES * (idx + 1)
+        for idx in range(total_block_number):
+            start: int = block_unit * idx
+            end: int = block_unit * (idx + 1)
             vad_prob = self.vad_model(
                 torch.from_numpy(audio[start:end]),
                 SAMPLE_RATE,
@@ -60,5 +63,5 @@ class VAD:
         if start_block_idx is not None:
             yield my_ret(
                 start_block_idx=start_block_idx,
-                idx=block_size,
+                idx=total_block_number,
             )
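Note: with total_block_number supplied, the VAD now slices the buffer into that many equal blocks of block_unit samples instead of fixed N_FRAMES-sized blocks, and total_block_number=1 (what the transcriber now passes) scores the whole buffer in one shot. A sketch of just the partitioning arithmetic, independent of the Silero model call (partition is a hypothetical helper):

    from typing import List, Optional, Tuple

    N_FRAMES = 3000  # the fixed per-block size the old code always used

    def partition(n_samples: int, total_block_number: Optional[int] = None) -> List[Tuple[int, int]]:
        """Reproduce the loop bounds: total_block_number equal blocks of block_unit samples."""
        if total_block_number is None:
            total_block_number = int(n_samples / N_FRAMES)  # old behaviour
        block_unit = n_samples // total_block_number
        return [(block_unit * idx, block_unit * (idx + 1)) for idx in range(total_block_number)]

    # total_block_number=1 evaluates the whole buffer as a single block:
    assert partition(480_000, total_block_number=1) == [(0, 480_000)]
    # The default keeps the previous one-block-per-N_FRAMES slicing:
    assert partition(9_000) == [(0, 3_000), (3_000, 6_000), (6_000, 9_000)]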