This commit is contained in:
Yuta Hayashibe 2022-10-01 23:20:46 +09:00
parent 95dafe6740
commit f46e71f51d
4 changed files with 47 additions and 54 deletions

6
poetry.lock generated
View file

@ -508,13 +508,13 @@ dev = ["pytest"]
[package.source]
type = "git"
url = "https://github.com/openai/whisper.git"
reference = '62fe7f1009a534986ac1d32a4aef8c244d029c28'
resolved_reference = "62fe7f1009a534986ac1d32a4aef8c244d029c28"
reference = '0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f'
resolved_reference = "0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f"
[metadata]
lock-version = "1.1"
python-versions = ">=3.8,<3.11"
content-hash = "d041d21a202339f405cc37076403f92135ee1f113cdfece5a78c9ee12374be7b"
content-hash = "f5395ffab6ce7d95246143218e948308d6614929f375489eb2b94a863e15fcc4"
[metadata.files]
black = [

View file

@ -8,7 +8,7 @@ packages = [{include = "whispering"}]
[tool.poetry.dependencies]
python = ">=3.8,<3.11"
whisper = {git = "https://github.com/openai/whisper.git", rev = '62fe7f1009a534986ac1d32a4aef8c244d029c28'}
whisper = {git = "https://github.com/openai/whisper.git", rev = '0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f'}
sounddevice = "^0.4.5"
pydantic = "^1.10.2"
websockets = "^10.3"

View file

@ -26,7 +26,10 @@ async def serve_with_websocket_main(websocket):
logger.debug(f"Message size: {len(message)}")
segment = np.frombuffer(message, dtype=np.float32)
for chunk in g_wsp.transcribe(segment=segment, ctx=g_ctx):
for chunk in g_wsp.transcribe(
segment=segment, # type: ignore
ctx=g_ctx,
):
await websocket.send(chunk.json())
idx += 1

View file

@ -1,9 +1,8 @@
#!/usr/bin/env python3
from logging import getLogger
from typing import Final, Iterator, List, Optional, Union
from typing import Final, Iterator, Optional, Union
import numpy as np
import torch
from whisper import Whisper, load_model
from whisper.audio import (
@ -75,59 +74,51 @@ class WhisperStreamingTranscriber:
suppress_blank=True,
suppress_tokens="-1",
without_timestamps=False,
max_initial_timestamp=0.0,
max_initial_timestamp=1.0,
fp16=self.fp16,
)
def _decode_with_fallback(
self,
*,
segment: np.ndarray,
segment: torch.Tensor,
ctx: Context,
) -> List[DecodingResult]:
) -> DecodingResult:
assert len(ctx.temperatures) >= 1
t = ctx.temperatures[0]
logger.debug(f"temperature: {t}")
decode_result: Optional[DecodingResult] = None
_decode_options1: DecodingOptions = self._get_decoding_options(
t=t,
prompt=ctx.buffer_tokens,
beam_size=ctx.beam_size,
patience=ctx.patience,
best_of=None,
)
results: List[DecodingResult] = self.model.decode(segment, _decode_options1) # type: ignore
for t in ctx.temperatures:
_decode_options: DecodingOptions = self._get_decoding_options(
t=t,
prompt=ctx.buffer_tokens,
beam_size=ctx.beam_size if t <= 0 else None,
patience=ctx.patience if t <= 0 else None,
best_of=ctx.best_of if t < 0 else None,
)
logger.debug(f"DecodeOptions: {_decode_options}")
decode_result = self.model.decode(
segment,
_decode_options,
) # type: ignore
assert decode_result is not None
for t in ctx.temperatures[1:]:
needs_fallback = [
needs_fallback: bool = False
if (
ctx.compression_ratio_threshold is not None
and result.compression_ratio > ctx.compression_ratio_threshold
or ctx.logprob_threshold is not None
and result.avg_logprob < ctx.logprob_threshold
for result in results
]
if any(needs_fallback):
logger.debug(
f"Fall back with temperature: {t}, needs_fallback: {needs_fallback}"
)
_decode_options2: DecodingOptions = self._get_decoding_options(
t=t,
prompt=ctx.buffer_tokens,
beam_size=None,
patience=None,
best_of=ctx.best_of,
)
retries: List[DecodingResult] = self.model.decode(
segment[needs_fallback], _decode_options2 # type: ignore
)
for retry_index, original_index in enumerate(
np.nonzero(needs_fallback)[0]
):
results[original_index] = retries[retry_index]
else:
and decode_result.compression_ratio > ctx.compression_ratio_threshold
):
needs_fallback = True # too repetitive
if (
ctx.logprob_threshold is not None
and decode_result.avg_logprob < ctx.logprob_threshold
):
needs_fallback = True # average log probability is too low
if not needs_fallback:
break
logger.debug(f"# of results: {len(results)}")
return results
assert isinstance(decode_result, DecodingResult)
return decode_result
def _get_chunk(
self,
@ -233,10 +224,10 @@ class WhisperStreamingTranscriber:
def transcribe(
self,
*,
segment: np.ndarray,
segment: torch.Tensor,
ctx: Context,
) -> Iterator[ParsedChunk]:
new_mel = log_mel_spectrogram(audio=segment).unsqueeze(0)
new_mel = log_mel_spectrogram(audio=segment)
logger.debug(f"Incoming new_mel.shape: {new_mel.shape}")
if ctx.buffer_mel is None:
mel = new_mel
@ -249,7 +240,7 @@ class WhisperStreamingTranscriber:
seek: int = 0
while seek < mel.shape[-1]:
segment = (
pad_or_trim(mel[:, :, seek:], N_FRAMES)
pad_or_trim(mel[:, seek:], N_FRAMES)
.to(self.model.device) # type: ignore
.to(self.dtype)
)
@ -260,11 +251,10 @@ class WhisperStreamingTranscriber:
f"seek={seek}, timestamp={ctx.timestamp}, "
f"mel.shape: {mel.shape}, segment.shape: {segment.shape}"
)
results = self._decode_with_fallback(
result = self._decode_with_fallback(
segment=segment,
ctx=ctx,
)
result = results[0]
logger.debug(
f"Result: temperature={result.temperature:.2f}, no_speech_prob={result.no_speech_prob:.2f}, "
f"avg_logprob={result.avg_logprob:.2f}"
@ -304,7 +294,7 @@ class WhisperStreamingTranscriber:
if mel.shape[-1] - seek <= 0:
logger.debug(f"ctx.buffer_mel is None ({mel.shape}, {seek})")
return
ctx.buffer_mel = mel[:, :, seek:]
ctx.buffer_mel = mel[:, seek:]
assert ctx.buffer_mel is not None
logger.debug(f"ctx.buffer_mel.shape: {ctx.buffer_mel.shape}")
del mel