Mirror of https://github.com/shirayu/whispering.git (synced 2025-02-16 10:35:16 +00:00)
Fix (openai/whisper@7cb4cc2)
parent 95dafe6740
commit f46e71f51d

4 changed files with 47 additions and 54 deletions

poetry.lock (generated)

@@ -508,13 +508,13 @@ dev = ["pytest"]
 [package.source]
 type = "git"
 url = "https://github.com/openai/whisper.git"
-reference = '62fe7f1009a534986ac1d32a4aef8c244d029c28'
-resolved_reference = "62fe7f1009a534986ac1d32a4aef8c244d029c28"
+reference = '0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f'
+resolved_reference = "0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f"
 
 [metadata]
 lock-version = "1.1"
 python-versions = ">=3.8,<3.11"
-content-hash = "d041d21a202339f405cc37076403f92135ee1f113cdfece5a78c9ee12374be7b"
+content-hash = "f5395ffab6ce7d95246143218e948308d6614929f375489eb2b94a863e15fcc4"
 
 [metadata.files]
 black = [

pyproject.toml

@@ -8,7 +8,7 @@ packages = [{include = "whispering"}]
 
 [tool.poetry.dependencies]
 python = ">=3.8,<3.11"
-whisper = {git = "https://github.com/openai/whisper.git", rev = '62fe7f1009a534986ac1d32a4aef8c244d029c28'}
+whisper = {git = "https://github.com/openai/whisper.git", rev = '0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f'}
 sounddevice = "^0.4.5"
 pydantic = "^1.10.2"
 websockets = "^10.3"

@@ -26,7 +26,10 @@ async def serve_with_websocket_main(websocket):
 
         logger.debug(f"Message size: {len(message)}")
         segment = np.frombuffer(message, dtype=np.float32)
-        for chunk in g_wsp.transcribe(segment=segment, ctx=g_ctx):
+        for chunk in g_wsp.transcribe(
+            segment=segment,  # type: ignore
+            ctx=g_ctx,
+        ):
             await websocket.send(chunk.json())
             idx += 1
 
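
The handler above expects each incoming websocket message to be raw float32 PCM, decodes it with np.frombuffer, and streams the resulting chunks back as JSON. A minimal client sketch under those assumptions; the URI, port, and 16 kHz sample rate are illustrative and not taken from this commit:

import asyncio
import json

import numpy as np
import websockets


async def send_silence(uri: str = "ws://localhost:8000") -> None:
    async with websockets.connect(uri) as ws:
        # One second of silence, float32, matching the server-side
        # np.frombuffer(message, dtype=np.float32) call (16 kHz is assumed here).
        audio = np.zeros(16000, dtype=np.float32)
        await ws.send(audio.tobytes())
        # The server replies with JSON-serialized chunks (chunk.json()).
        reply = await ws.recv()
        print(json.loads(reply))


asyncio.run(send_silence())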

@@ -1,9 +1,8 @@
 #!/usr/bin/env python3
 
 from logging import getLogger
-from typing import Final, Iterator, List, Optional, Union
+from typing import Final, Iterator, Optional, Union
 
 import numpy as np
 import torch
 from whisper import Whisper, load_model
 from whisper.audio import (

@@ -75,59 +74,51 @@ class WhisperStreamingTranscriber:
             suppress_blank=True,
             suppress_tokens="-1",
             without_timestamps=False,
-            max_initial_timestamp=0.0,
+            max_initial_timestamp=1.0,
             fp16=self.fp16,
         )
 
     def _decode_with_fallback(
         self,
         *,
-        segment: np.ndarray,
+        segment: torch.Tensor,
         ctx: Context,
-    ) -> List[DecodingResult]:
+    ) -> DecodingResult:
         assert len(ctx.temperatures) >= 1
-        t = ctx.temperatures[0]
-        logger.debug(f"temperature: {t}")
+        decode_result: Optional[DecodingResult] = None
 
-        _decode_options1: DecodingOptions = self._get_decoding_options(
-            t=t,
-            prompt=ctx.buffer_tokens,
-            beam_size=ctx.beam_size,
-            patience=ctx.patience,
-            best_of=None,
-        )
-        results: List[DecodingResult] = self.model.decode(segment, _decode_options1)  # type: ignore
+        for t in ctx.temperatures:
+            _decode_options: DecodingOptions = self._get_decoding_options(
+                t=t,
+                prompt=ctx.buffer_tokens,
+                beam_size=ctx.beam_size if t <= 0 else None,
+                patience=ctx.patience if t <= 0 else None,
+                best_of=ctx.best_of if t < 0 else None,
+            )
+            logger.debug(f"DecodeOptions: {_decode_options}")
+            decode_result = self.model.decode(
+                segment,
+                _decode_options,
+            )  # type: ignore
+            assert decode_result is not None
 
-        for t in ctx.temperatures[1:]:
-            needs_fallback = [
-                ctx.compression_ratio_threshold is not None
-                and result.compression_ratio > ctx.compression_ratio_threshold
-                or ctx.logprob_threshold is not None
-                and result.avg_logprob < ctx.logprob_threshold
-                for result in results
-            ]
-            if any(needs_fallback):
-                logger.debug(
-                    f"Fall back with temperature: {t}, needs_fallback: {needs_fallback}"
-                )
-                _decode_options2: DecodingOptions = self._get_decoding_options(
-                    t=t,
-                    prompt=ctx.buffer_tokens,
-                    beam_size=None,
-                    patience=None,
-                    best_of=ctx.best_of,
-                )
-                retries: List[DecodingResult] = self.model.decode(
-                    segment[needs_fallback], _decode_options2  # type: ignore
-                )
-                for retry_index, original_index in enumerate(
-                    np.nonzero(needs_fallback)[0]
-                ):
-                    results[original_index] = retries[retry_index]
-            else:
-                break
-        logger.debug(f"# of results: {len(results)}")
-        return results
+            needs_fallback: bool = False
+            if (
+                ctx.compression_ratio_threshold is not None
+                and decode_result.compression_ratio > ctx.compression_ratio_threshold
+            ):
+                needs_fallback = True  # too repetitive
+            if (
+                ctx.logprob_threshold is not None
+                and decode_result.avg_logprob < ctx.logprob_threshold
+            ):
+                needs_fallback = True  # average log probability is too low
+
+            if not needs_fallback:
+                break
+
+        assert isinstance(decode_result, DecodingResult)
+        return decode_result
 
     def _get_chunk(
         self,
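
The rewritten _decode_with_fallback no longer decodes a batch and returns a List[DecodingResult]; it decodes the single segment once per temperature in ctx.temperatures and stops at the first result whose compression ratio and average log probability pass the configured thresholds. A self-contained sketch of that fallback pattern; DummyResult, the temperatures, and the threshold defaults are illustrative stand-ins, not the project's actual types or values:

from dataclasses import dataclass
from typing import Optional, Sequence


@dataclass
class DummyResult:
    compression_ratio: float
    avg_logprob: float


def decode_with_fallback(
    results_by_temperature: Sequence[DummyResult],
    temperatures: Sequence[float] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
    compression_ratio_threshold: Optional[float] = 2.4,
    logprob_threshold: Optional[float] = -1.0,
) -> DummyResult:
    # Each entry of results_by_temperature stands in for one model.decode() call
    # at the corresponding temperature.
    result: Optional[DummyResult] = None
    for _t, result in zip(temperatures, results_by_temperature):
        needs_fallback = False
        if (
            compression_ratio_threshold is not None
            and result.compression_ratio > compression_ratio_threshold
        ):
            needs_fallback = True  # too repetitive
        if (
            logprob_threshold is not None
            and result.avg_logprob < logprob_threshold
        ):
            needs_fallback = True  # average log probability is too low
        if not needs_fallback:
            break
    assert result is not None
    return result


# The first decode is too repetitive (compression ratio 3.0), so the second one is kept.
print(decode_with_fallback([DummyResult(3.0, -0.5), DummyResult(1.2, -0.4)]))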

@@ -233,10 +224,10 @@ class WhisperStreamingTranscriber:
     def transcribe(
         self,
         *,
-        segment: np.ndarray,
+        segment: torch.Tensor,
         ctx: Context,
     ) -> Iterator[ParsedChunk]:
-        new_mel = log_mel_spectrogram(audio=segment).unsqueeze(0)
+        new_mel = log_mel_spectrogram(audio=segment)
         logger.debug(f"Incoming new_mel.shape: {new_mel.shape}")
         if ctx.buffer_mel is None:
             mel = new_mel

@@ -249,7 +240,7 @@ class WhisperStreamingTranscriber:
         seek: int = 0
         while seek < mel.shape[-1]:
             segment = (
-                pad_or_trim(mel[:, :, seek:], N_FRAMES)
+                pad_or_trim(mel[:, seek:], N_FRAMES)
                 .to(self.model.device)  # type: ignore
                 .to(self.dtype)
             )
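
The mel-handling changes in this file all follow from dropping the explicit batch axis: the mel returned by log_mel_spectrogram is kept as a 2-D (n_mels, n_frames) tensor, so seeking and buffering slice the last (frame) axis with mel[:, seek:] and pad_or_trim pads along that same axis. A small sketch, assuming the pinned openai/whisper revision behaves this way:

import numpy as np
from whisper.audio import N_FRAMES, log_mel_spectrogram, pad_or_trim

# One second of silence at whisper's 16 kHz sample rate.
audio = np.zeros(16000, dtype=np.float32)

mel = log_mel_spectrogram(audio)            # 2-D: (n_mels, n_frames), no batch axis
window = pad_or_trim(mel[:, 0:], N_FRAMES)  # pad/trim along the last (frame) axis
print(mel.shape, window.shape)              # e.g. torch.Size([80, 100]) torch.Size([80, 3000])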

@@ -260,11 +251,10 @@ class WhisperStreamingTranscriber:
                 f"seek={seek}, timestamp={ctx.timestamp}, "
                 f"mel.shape: {mel.shape}, segment.shape: {segment.shape}"
             )
-            results = self._decode_with_fallback(
+            result = self._decode_with_fallback(
                 segment=segment,
                 ctx=ctx,
             )
-            result = results[0]
             logger.debug(
                 f"Result: temperature={result.temperature:.2f}, no_speech_prob={result.no_speech_prob:.2f}, "
                 f"avg_logprob={result.avg_logprob:.2f}"

@@ -304,7 +294,7 @@ class WhisperStreamingTranscriber:
         if mel.shape[-1] - seek <= 0:
             logger.debug(f"ctx.buffer_mel is None ({mel.shape}, {seek})")
             return
-        ctx.buffer_mel = mel[:, :, seek:]
+        ctx.buffer_mel = mel[:, seek:]
         assert ctx.buffer_mel is not None
         logger.debug(f"ctx.buffer_mel.shape: {ctx.buffer_mel.shape}")
         del mel