Merge remote-tracking branch 'origin/master' into vad

Yuta Hayashibe 2022-10-02 19:39:33 +09:00
commit 45eb0bc34d
8 changed files with 91 additions and 73 deletions

README.md

@@ -32,7 +32,7 @@ whispering --language en --model tiny
 - ``--model`` set the [model name](https://github.com/openai/whisper#available-models-and-languages) to use. Larger models will be more accurate, but may not be able to transcribe in real time.
 - ``--language`` sets the language to transcribe. The list of languages are shown with ``whispering -h``
 - ``--no-progress`` disables the progress message
-- ``-t`` sets temperatures to decode. You can set several like (``-t 0.0 -t 0.1 -t 0.5``), but too many temperatures exhaust decoding time
+- ``-t`` sets temperatures to decode. You can set several like ``-t 0.0 -t 0.1 -t 0.5``, but too many temperatures exhaust decoding time
 - ``--debug`` outputs logs for debug

 ### Parse interval

poetry.lock (generated)

@@ -519,13 +519,13 @@ dev = ["pytest"]
 [package.source]
 type = "git"
 url = "https://github.com/openai/whisper.git"
-reference = '62fe7f1009a534986ac1d32a4aef8c244d029c28'
-resolved_reference = "62fe7f1009a534986ac1d32a4aef8c244d029c28"
+reference = '0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f'
+resolved_reference = "0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f"

 [metadata]
 lock-version = "1.1"
 python-versions = ">=3.8,<3.11"
-content-hash = "75e53434d1d46d54a886ca7a896a2f0ba0072a1848f90d5b6dc46ea2c5b47191"
+content-hash = "ab527970383bc2245dee005627d0695812601115a36e15a5ef9e66d1185791bf"

 [metadata.files]
 black = [

pyproject.toml

@@ -8,7 +8,7 @@ packages = [{include = "whispering"}]
 [tool.poetry.dependencies]
 python = ">=3.8,<3.11"
-whisper = {git = "https://github.com/openai/whisper.git", rev = '62fe7f1009a534986ac1d32a4aef8c244d029c28'}
+whisper = {git = "https://github.com/openai/whisper.git", rev = '0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f'}
 sounddevice = "^0.4.5"
 pydantic = "^1.10.2"
 websockets = "^10.3"

whispering/cli.py

@@ -214,11 +214,14 @@ def main() -> None:
     if opts.mode == "client":
         assert opts.language is None
         assert opts.model is None
-        asyncio.run(
-            run_websocket_client(
-                opts=opts,
+        try:
+            asyncio.run(
+                run_websocket_client(
+                    opts=opts,
+                )
             )
-        )
+        except KeyboardInterrupt:
+            pass
     else:
         assert opts.language is not None
         assert opts.model is not None
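
Note: the try/except added above makes Ctrl-C exit the client quietly. KeyboardInterrupt propagates out of asyncio.run(), so catching it at the call site avoids a traceback on a normal manual stop. A minimal sketch of the same pattern, with a stand-in coroutine in place of run_websocket_client:

import asyncio


async def run_client() -> None:
    # Stand-in for run_websocket_client(): wait until interrupted.
    while True:
        await asyncio.sleep(1)


def main() -> None:
    try:
        asyncio.run(run_client())
    except KeyboardInterrupt:
        # Ctrl-C surfaces here; swallow it for a clean exit
        # instead of a traceback.
        pass


if __name__ == "__main__":
    main()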

whispering/schema.py

@@ -2,7 +2,6 @@
 from typing import List, Optional

-import numpy as np
 import torch
 from pydantic import BaseModel, root_validator

@@ -56,4 +55,4 @@ class ParsedChunk(BaseModel):
 class SpeechSegment(BaseModel, arbitrary_types_allowed=True):
     start_block_idx: int
     end_block_idx: int
-    segment: np.ndarray
+    segment: torch.Tensor
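
Note: SpeechSegment now carries a raw torch.Tensor. pydantic (v1, pinned above) has no validator for tensors, so arbitrary_types_allowed=True is required; with it, pydantic only checks isinstance() for that field. A minimal runnable sketch of the pattern:

import torch
from pydantic import BaseModel


class SpeechSegment(BaseModel, arbitrary_types_allowed=True):
    # arbitrary_types_allowed lets pydantic accept torch.Tensor,
    # a type it cannot validate beyond an isinstance() check.
    start_block_idx: int
    end_block_idx: int
    segment: torch.Tensor


seg = SpeechSegment(
    start_block_idx=0,
    end_block_idx=1,
    segment=torch.zeros(1600),
)
print(seg.segment.shape)  # torch.Size([1600])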

whispering/serve.py

@@ -5,6 +5,7 @@ from logging import getLogger
 import numpy as np
 import websockets
+from websockets.exceptions import ConnectionClosedOK

 from whispering.transcriber import Context, WhisperStreamingTranscriber

@@ -15,10 +16,16 @@ async def serve_with_websocket_main(websocket):
     global g_wsp
     global g_ctx

     idx: int = 0
+    ctx: Context = g_ctx.copy(
+        deep=True,
+    )

     while True:
         logger.debug(f"Segment #: {idx}")
-        message = await websocket.recv()
+        try:
+            message = await websocket.recv()
+        except ConnectionClosedOK:
+            break
         if isinstance(message, str):
             logger.debug(f"Got str: {message}")

@@ -26,7 +33,10 @@ async def serve_with_websocket_main(websocket):
             logger.debug(f"Message size: {len(message)}")
             segment = np.frombuffer(message, dtype=np.float32)
-            for chunk in g_wsp.transcribe(segment=segment, ctx=g_ctx):
+            for chunk in g_wsp.transcribe(
+                segment=segment,  # type: ignore
+                ctx=ctx,
+            ):
                 await websocket.send(chunk.json())
                 idx += 1
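
Note: two behavioral fixes land here. A client closing the socket normally now ends the loop via ConnectionClosedOK instead of raising, and each connection transcribes against its own deep copy of the global Context, so concurrent clients no longer share mutable decoding state. A small sketch of why the deep copy matters, with a simplified Context (the real one also holds buffer_mel, timestamps, temperatures, and thresholds):

from typing import List

from pydantic import BaseModel


class Context(BaseModel):
    # Simplified stand-in for whispering.transcriber.Context
    buffer_tokens: List[int] = []


g_ctx = Context()  # template configured once at startup

# Each connection gets its own deep copy ...
ctx_a = g_ctx.copy(deep=True)
ctx_b = g_ctx.copy(deep=True)

ctx_a.buffer_tokens.append(42)

# ... so one client's buffered state never leaks into another's.
assert ctx_b.buffer_tokens == []
assert g_ctx.buffer_tokens == []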

whispering/transcriber.py

@@ -1,9 +1,8 @@
 #!/usr/bin/env python3
 from logging import getLogger
-from typing import Final, Iterator, List, Optional, Union
+from typing import Final, Iterator, Optional, Union

-import numpy as np
 import torch
 from whisper import Whisper, load_model
 from whisper.audio import (
@@ -77,59 +76,51 @@ class WhisperStreamingTranscriber:
             suppress_blank=True,
             suppress_tokens="-1",
             without_timestamps=False,
-            max_initial_timestamp=0.0,
+            max_initial_timestamp=1.0,
             fp16=self.fp16,
         )

     def _decode_with_fallback(
         self,
         *,
-        segment: np.ndarray,
+        segment: torch.Tensor,
         ctx: Context,
-    ) -> List[DecodingResult]:
+    ) -> DecodingResult:
         assert len(ctx.temperatures) >= 1
-        t = ctx.temperatures[0]
-        logger.debug(f"temperature: {t}")
+        decode_result: Optional[DecodingResult] = None

-        _decode_options1: DecodingOptions = self._get_decoding_options(
-            t=t,
-            prompt=ctx.buffer_tokens,
-            beam_size=ctx.beam_size,
-            patience=None,
-            best_of=None,
-        )
-        results: List[DecodingResult] = self.model.decode(segment, _decode_options1)  # type: ignore
+        for t in ctx.temperatures:
+            _decode_options: DecodingOptions = self._get_decoding_options(
+                t=t,
+                prompt=ctx.buffer_tokens,
+                beam_size=ctx.beam_size if t <= 0 else None,
+                patience=ctx.patience if t <= 0 else None,
+                best_of=ctx.best_of if t < 0 else None,
+            )
+            logger.debug(f"DecodeOptions: {_decode_options}")
+            decode_result = self.model.decode(
+                segment,
+                _decode_options,
+            )  # type: ignore
+            assert decode_result is not None

-        for t in ctx.temperatures[1:]:
-            needs_fallback = [
+            needs_fallback: bool = False
+            if (
                 ctx.compression_ratio_threshold is not None
-                and result.compression_ratio > ctx.compression_ratio_threshold
-                or ctx.logprob_threshold is not None
-                and result.avg_logprob < ctx.logprob_threshold
-                for result in results
-            ]
-            if any(needs_fallback):
-                logger.debug(
-                    f"Fall back with temperature: {t}, needs_fallback: {needs_fallback}"
-                )
-                _decode_options2: DecodingOptions = self._get_decoding_options(
-                    t=t,
-                    prompt=ctx.buffer_tokens,
-                    beam_size=None,
-                    patience=ctx.patience,
-                    best_of=ctx.best_of,
-                )
-                retries: List[DecodingResult] = self.model.decode(
-                    segment[needs_fallback], _decode_options2  # type: ignore
-                )
-                for retry_index, original_index in enumerate(
-                    np.nonzero(needs_fallback)[0]
-                ):
-                    results[original_index] = retries[retry_index]
-            else:
+                and decode_result.compression_ratio > ctx.compression_ratio_threshold
+            ):
+                needs_fallback = True  # too repetitive
+            if (
+                ctx.logprob_threshold is not None
+                and decode_result.avg_logprob < ctx.logprob_threshold
+            ):
+                needs_fallback = True  # average log probability is too low
+            if not needs_fallback:
                 break
-        logger.debug(f"# of results: {len(results)}")
-        return results
+        assert isinstance(decode_result, DecodingResult)
+        return decode_result

     def _get_chunk(
         self,
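
Note: _decode_with_fallback was rewritten around the fact that the streaming path decodes a single segment, not a batch. The old code, adapted from whisper's transcribe(), re-decoded only the batch entries that failed; the new code decodes once per temperature, walking up ctx.temperatures until the output passes both the compression-ratio gate (repetition) and the average-log-probability gate (confidence), with beam search only at t <= 0. A self-contained sketch of that control flow, with a dummy decode in place of model.decode():

from dataclasses import dataclass
from typing import List, Optional


@dataclass
class DecodingResult:
    temperature: float
    compression_ratio: float
    avg_logprob: float


def decode_with_fallback(
    temperatures: List[float],
    compression_ratio_threshold: Optional[float] = 2.4,
    logprob_threshold: Optional[float] = -1.0,
) -> DecodingResult:
    decode_result: Optional[DecodingResult] = None
    for t in temperatures:
        # Dummy stand-in for self.model.decode(segment, options);
        # here results simply "improve" as temperature rises.
        decode_result = DecodingResult(
            temperature=t,
            compression_ratio=3.0 - 2.0 * t,
            avg_logprob=-1.5 + t,
        )
        needs_fallback = False
        if (
            compression_ratio_threshold is not None
            and decode_result.compression_ratio > compression_ratio_threshold
        ):
            needs_fallback = True  # too repetitive
        if (
            logprob_threshold is not None
            and decode_result.avg_logprob < logprob_threshold
        ):
            needs_fallback = True  # average log probability too low
        if not needs_fallback:
            break
    assert decode_result is not None
    return decode_result


print(decode_with_fallback([0.0, 0.2, 0.4, 0.6, 0.8, 1.0]))
# -> DecodingResult(temperature=0.6, compression_ratio=1.8, avg_logprob=-0.9)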
@@ -205,7 +196,10 @@ class WhisperStreamingTranscriber:
             duration = segment_duration
         timestamps = tokens[timestamp_tokens.nonzero().flatten()]
         logger.debug(f"Length of consecutive: 0, timestamps: {timestamps}")
-        if len(timestamps) > 0:
+        if (
+            len(timestamps) > 0
+            and timestamps[-1].item() != self.tokenizer.timestamp_begin
+        ):
             # no consecutive timestamps but it has a timestamp; use the last one.
             # single timestamp at the end means no speech after the last timestamp.
             last_timestamp_position = (
@@ -232,13 +226,13 @@
     def transcribe(
         self,
         *,
-        segment: np.ndarray,
+        segment: torch.Tensor,
         ctx: Context,
     ) -> Iterator[ParsedChunk]:
-        vad_probs = self.vad(segment)
-        logger.debug(f"{vad_probs}")
+        for speech_segment in self.vad(segment=segment):
+            logger.debug(f"{speech_segment}")

-        new_mel = log_mel_spectrogram(audio=segment).unsqueeze(0)
+        new_mel = log_mel_spectrogram(audio=segment)
         logger.debug(f"Incoming new_mel.shape: {new_mel.shape}")
         if ctx.buffer_mel is None:
             mel = new_mel
@@ -251,7 +245,7 @@
         seek: int = 0
         while seek < mel.shape[-1]:
             segment = (
-                pad_or_trim(mel[:, :, seek:], N_FRAMES)
+                pad_or_trim(mel[:, seek:], N_FRAMES)
                 .to(self.model.device)  # type: ignore
                 .to(self.dtype)
             )
@@ -262,11 +256,10 @@
                 f"seek={seek}, timestamp={ctx.timestamp}, "
                 f"mel.shape: {mel.shape}, segment.shape: {segment.shape}"
             )
-            results = self._decode_with_fallback(
+            result = self._decode_with_fallback(
                 segment=segment,
                 ctx=ctx,
            )
-            result = results[0]
             logger.debug(
                 f"Result: temperature={result.temperature:.2f}, no_speech_prob={result.no_speech_prob:.2f}, "
                 f"avg_logprob={result.avg_logprob:.2f}"
@@ -306,7 +299,7 @@ class WhisperStreamingTranscriber:
         if mel.shape[-1] - seek <= 0:
             logger.debug(f"ctx.buffer_mel is None ({mel.shape}, {seek})")
             return
-        ctx.buffer_mel = mel[:, :, seek:]
+        ctx.buffer_mel = mel[:, seek:]
         assert ctx.buffer_mel is not None
         logger.debug(f"ctx.buffer_mel.shape: {ctx.buffer_mel.shape}")
         del mel
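
Note: the remaining changes all follow from dropping the batch axis: log_mel_spectrogram() returns a 2-D (n_mels, n_frames) tensor, so the old .unsqueeze(0) and three-axis slices like mel[:, :, seek:] were superfluous. A quick shape check using whisper's own audio helpers (behavior as of the revision pinned above; 30 s of silence as dummy input):

import torch
from whisper.audio import N_FRAMES, SAMPLE_RATE, log_mel_spectrogram, pad_or_trim

audio = torch.zeros(30 * SAMPLE_RATE)  # 30 seconds of silence

mel = log_mel_spectrogram(audio=audio)
print(mel.shape)  # torch.Size([80, 3000]): (n_mels, n_frames), no batch axis

# Frames are therefore indexed on the last of two axes:
segment = pad_or_trim(mel[:, 1000:], N_FRAMES)
print(segment.shape)  # torch.Size([80, 3000])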

whispering/vad.py

@@ -2,7 +2,6 @@
 from typing import Iterator

-import numpy as np
 import torch
 from whisper.audio import N_FRAMES, SAMPLE_RATE

@@ -21,16 +20,26 @@ class VAD:
     def __call__(
         self,
         *,
-        segment: np.ndarray,
+        segment: torch.Tensor,
         thredhold: float = 0.5,
-    ) -> Iterator[SpeechBlock]:
+    ) -> Iterator[SpeechSegment]:
         # segment.shape should be multiple of (N_FRAMES,)

+        def my_ret(
+            *,
+            start_block_idx: int,
+            idx: int,
+        ) -> SpeechSegment:
+            return SpeechSegment(
+                start_block_idx=start_block_idx,
+                end_block_idx=idx,
+                segment=segment[N_FRAMES * start_block_idx : N_FRAMES * idx],
+            )

         block_size: int = int(segment.shape[0] / N_FRAMES)
         start_block_idx = None
-        for idx in range(block_size + 1):
-            if idx < block_size:
+        for idx in range(block_size):
             start: int = N_FRAMES * idx
             end: int = N_FRAMES * (idx + 1)
             vad_prob = self.vad_model(

@@ -42,9 +51,13 @@ class VAD:
                 start_block_idx = idx
             else:
                 if start_block_idx is not None:
-                    yield SpeechSegment(
+                    yield my_ret(
                         start_block_idx=start_block_idx,
-                        end_block_idx=idx,
-                        segment=segment[N_FRAMES * start_block_idx : N_FRAMES * idx],
+                        idx=idx,
                     )
                     start_block_idx = None
+        if start_block_idx is not None:
+            yield my_ret(
+                start_block_idx=start_block_idx,
+                idx=block_size,
+            )
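
Note: the VAD now emits SpeechSegment values by merging consecutive speech-positive N_FRAMES blocks, and the trailing if flushes a run that is still open when the audio ends (previously approximated by iterating to block_size + 1 with a guard). The my_ret closure just builds the SpeechSegment for a half-open block range. A standalone sketch of the run-merging logic over precomputed per-block probabilities:

from typing import Iterator, List, Tuple


def speech_runs(
    vad_probs: List[float],
    threshold: float = 0.5,
) -> Iterator[Tuple[int, int]]:
    # Merge consecutive blocks with prob >= threshold into
    # half-open (start_block_idx, end_block_idx) runs.
    start_block_idx = None
    for idx, prob in enumerate(vad_probs):
        if prob >= threshold:
            if start_block_idx is None:
                start_block_idx = idx
        else:
            if start_block_idx is not None:
                yield (start_block_idx, idx)
                start_block_idx = None
    # Flush a run still open at end-of-audio: the case the
    # added trailing `if` covers.
    if start_block_idx is not None:
        yield (start_block_idx, len(vad_probs))


print(list(speech_runs([0.1, 0.9, 0.8, 0.2, 0.7, 0.9])))
# [(1, 3), (4, 6)]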