Yuta Hayashibe 2022-10-01 23:20:46 +09:00
parent 95dafe6740
commit f46e71f51d
4 changed files with 47 additions and 54 deletions

poetry.lock (generated)

@@ -508,13 +508,13 @@ dev = ["pytest"]
 [package.source]
 type = "git"
 url = "https://github.com/openai/whisper.git"
-reference = '62fe7f1009a534986ac1d32a4aef8c244d029c28'
-resolved_reference = "62fe7f1009a534986ac1d32a4aef8c244d029c28"
+reference = '0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f'
+resolved_reference = "0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f"
 
 [metadata]
 lock-version = "1.1"
 python-versions = ">=3.8,<3.11"
-content-hash = "d041d21a202339f405cc37076403f92135ee1f113cdfece5a78c9ee12374be7b"
+content-hash = "f5395ffab6ce7d95246143218e948308d6614929f375489eb2b94a863e15fcc4"
 
 [metadata.files]
 black = [

pyproject.toml

@@ -8,7 +8,7 @@ packages = [{include = "whispering"}]
 
 [tool.poetry.dependencies]
 python = ">=3.8,<3.11"
-whisper = {git = "https://github.com/openai/whisper.git", rev = '62fe7f1009a534986ac1d32a4aef8c244d029c28'}
+whisper = {git = "https://github.com/openai/whisper.git", rev = '0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f'}
 sounddevice = "^0.4.5"
 pydantic = "^1.10.2"
 websockets = "^10.3"


@@ -26,7 +26,10 @@ async def serve_with_websocket_main(websocket):
         logger.debug(f"Message size: {len(message)}")
         segment = np.frombuffer(message, dtype=np.float32)
-        for chunk in g_wsp.transcribe(segment=segment, ctx=g_ctx):
+        for chunk in g_wsp.transcribe(
+            segment=segment,  # type: ignore
+            ctx=g_ctx,
+        ):
             await websocket.send(chunk.json())
             idx += 1
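
Note: the handler above expects every websocket message to be a raw float32 PCM buffer (that is what np.frombuffer(message, dtype=np.float32) decodes) and sends one JSON text message back per transcribed chunk. Below is a minimal client sketch under those assumptions, using the websockets package already pinned in pyproject.toml; the endpoint URL and the silent dummy audio are illustrative only and are not part of this commit.

import asyncio

import numpy as np
import websockets  # the same dependency pinned in pyproject.toml above


async def main() -> None:
    # Hypothetical endpoint; whispering's real host/port are chosen when the server is launched.
    async with websockets.connect("ws://localhost:8000") as ws:
        # One second of silence at 16 kHz as a stand-in for captured microphone audio.
        segment = np.zeros(16000, dtype=np.float32)
        # The server reads the message with np.frombuffer(..., dtype=np.float32),
        # so the payload must be the raw float32 byte buffer.
        await ws.send(segment.tobytes())
        # Each transcribed chunk comes back as one JSON text message (chunk.json()).
        print(await ws.recv())


asyncio.run(main())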


@@ -1,9 +1,8 @@
 #!/usr/bin/env python3
 
 from logging import getLogger
-from typing import Final, Iterator, List, Optional, Union
+from typing import Final, Iterator, Optional, Union
 
-import numpy as np
 import torch
 from whisper import Whisper, load_model
 from whisper.audio import (
@@ -75,59 +74,51 @@ class WhisperStreamingTranscriber:
             suppress_blank=True,
             suppress_tokens="-1",
             without_timestamps=False,
-            max_initial_timestamp=0.0,
+            max_initial_timestamp=1.0,
             fp16=self.fp16,
         )
 
     def _decode_with_fallback(
         self,
         *,
-        segment: np.ndarray,
+        segment: torch.Tensor,
         ctx: Context,
-    ) -> List[DecodingResult]:
+    ) -> DecodingResult:
         assert len(ctx.temperatures) >= 1
-        t = ctx.temperatures[0]
-        logger.debug(f"temperature: {t}")
-        _decode_options1: DecodingOptions = self._get_decoding_options(
-            t=t,
-            prompt=ctx.buffer_tokens,
-            beam_size=ctx.beam_size,
-            patience=ctx.patience,
-            best_of=None,
-        )
-        results: List[DecodingResult] = self.model.decode(segment, _decode_options1)  # type: ignore
-
-        for t in ctx.temperatures[1:]:
-            needs_fallback = [
-                ctx.compression_ratio_threshold is not None
-                and result.compression_ratio > ctx.compression_ratio_threshold
-                or ctx.logprob_threshold is not None
-                and result.avg_logprob < ctx.logprob_threshold
-                for result in results
-            ]
-            if any(needs_fallback):
-                logger.debug(
-                    f"Fall back with temperature: {t}, needs_fallback: {needs_fallback}"
-                )
-                _decode_options2: DecodingOptions = self._get_decoding_options(
-                    t=t,
-                    prompt=ctx.buffer_tokens,
-                    beam_size=None,
-                    patience=None,
-                    best_of=ctx.best_of,
-                )
-                retries: List[DecodingResult] = self.model.decode(
-                    segment[needs_fallback], _decode_options2  # type: ignore
-                )
-                for retry_index, original_index in enumerate(
-                    np.nonzero(needs_fallback)[0]
-                ):
-                    results[original_index] = retries[retry_index]
-            else:
-                break
-        logger.debug(f"# of results: {len(results)}")
-        return results
+        decode_result: Optional[DecodingResult] = None
+
+        for t in ctx.temperatures:
+            _decode_options: DecodingOptions = self._get_decoding_options(
+                t=t,
+                prompt=ctx.buffer_tokens,
+                beam_size=ctx.beam_size if t <= 0 else None,
+                patience=ctx.patience if t <= 0 else None,
+                best_of=ctx.best_of if t < 0 else None,
+            )
+            logger.debug(f"DecodeOptions: {_decode_options}")
+            decode_result = self.model.decode(
+                segment,
+                _decode_options,
+            )  # type: ignore
+            assert decode_result is not None
+
+            needs_fallback: bool = False
+            if (
+                ctx.compression_ratio_threshold is not None
+                and decode_result.compression_ratio > ctx.compression_ratio_threshold
+            ):
+                needs_fallback = True  # too repetitive
+            if (
+                ctx.logprob_threshold is not None
+                and decode_result.avg_logprob < ctx.logprob_threshold
+            ):
+                needs_fallback = True  # average log probability is too low
+            if not needs_fallback:
+                break
+
+        assert isinstance(decode_result, DecodingResult)
+        return decode_result
 
     def _get_chunk(
         self,
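
The rewritten _decode_with_fallback above follows the retry-with-rising-temperature pattern used by openai/whisper's own transcribe(): decode once per temperature and stop at the first result that is neither too repetitive (compression ratio above the threshold) nor too low in average log probability. The following is a self-contained sketch of just that control flow; Result and decode_once are hypothetical stand-ins for whisper's DecodingResult and the model.decode() call, and the default thresholds mirror whisper's usual 2.4 / -1.0 values.

from dataclasses import dataclass
from typing import Callable, Optional, Sequence


@dataclass
class Result:
    # Minimal stand-in for whisper.DecodingResult: only the two fields the checks read.
    compression_ratio: float
    avg_logprob: float


def decode_with_fallback(
    decode_once: Callable[[float], Result],
    temperatures: Sequence[float] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
    compression_ratio_threshold: Optional[float] = 2.4,
    logprob_threshold: Optional[float] = -1.0,
) -> Result:
    """Decode at rising temperatures until the result passes both sanity checks."""
    assert len(temperatures) >= 1
    result: Optional[Result] = None
    for t in temperatures:
        result = decode_once(t)  # stand-in for one model.decode() call at temperature t
        too_repetitive = (
            compression_ratio_threshold is not None
            and result.compression_ratio > compression_ratio_threshold
        )
        low_confidence = (
            logprob_threshold is not None
            and result.avg_logprob < logprob_threshold
        )
        if not (too_repetitive or low_confidence):
            break  # the first acceptable result wins; otherwise retry hotter
    assert result is not None
    return result


# Example: a decoder that returns a healthy-looking result is accepted at temperature 0.
print(decode_with_fallback(lambda t: Result(compression_ratio=1.5, avg_logprob=-0.3)))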
@@ -233,10 +224,10 @@ class WhisperStreamingTranscriber:
     def transcribe(
         self,
         *,
-        segment: np.ndarray,
+        segment: torch.Tensor,
         ctx: Context,
     ) -> Iterator[ParsedChunk]:
-        new_mel = log_mel_spectrogram(audio=segment).unsqueeze(0)
+        new_mel = log_mel_spectrogram(audio=segment)
         logger.debug(f"Incoming new_mel.shape: {new_mel.shape}")
         if ctx.buffer_mel is None:
             mel = new_mel
@@ -249,7 +240,7 @@ class WhisperStreamingTranscriber:
         seek: int = 0
         while seek < mel.shape[-1]:
             segment = (
-                pad_or_trim(mel[:, :, seek:], N_FRAMES)
+                pad_or_trim(mel[:, seek:], N_FRAMES)
                 .to(self.model.device)  # type: ignore
                 .to(self.dtype)
             )
@@ -260,11 +251,10 @@ class WhisperStreamingTranscriber:
                 f"seek={seek}, timestamp={ctx.timestamp}, "
                 f"mel.shape: {mel.shape}, segment.shape: {segment.shape}"
             )
-            results = self._decode_with_fallback(
+            result = self._decode_with_fallback(
                 segment=segment,
                 ctx=ctx,
             )
-            result = results[0]
             logger.debug(
                 f"Result: temperature={result.temperature:.2f}, no_speech_prob={result.no_speech_prob:.2f}, "
                 f"avg_logprob={result.avg_logprob:.2f}"
@@ -304,7 +294,7 @@ class WhisperStreamingTranscriber:
         if mel.shape[-1] - seek <= 0:
             logger.debug(f"ctx.buffer_mel is None ({mel.shape}, {seek})")
             return
-        ctx.buffer_mel = mel[:, :, seek:]
+        ctx.buffer_mel = mel[:, seek:]
         assert ctx.buffer_mel is not None
         logger.debug(f"ctx.buffer_mel.shape: {ctx.buffer_mel.shape}")
         del mel
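
The indexing changes in transcribe() (mel[:, :, seek:] becoming mel[:, seek:], and the dropped .unsqueeze(0)) reflect that whisper.audio.log_mel_spectrogram returns a 2-D (n_mels, n_frames) tensor for 1-D audio, so the code no longer carries an explicit batch axis. A small shape-check sketch under that assumption, using a silent dummy buffer rather than real streamed audio:

import torch
from whisper.audio import N_FRAMES, SAMPLE_RATE, log_mel_spectrogram, pad_or_trim

# Five seconds of silence as a stand-in for a real streamed segment.
audio = torch.zeros(SAMPLE_RATE * 5, dtype=torch.float32)

mel = log_mel_spectrogram(audio=audio)
print(mel.shape)  # 2-D: (n_mels, n_frames), no leading batch axis

# Slicing and padding now happen along the last (frame) axis only,
# which is why mel[:, :, seek:] became mel[:, seek:] in the diff.
seek = 0
segment = pad_or_trim(mel[:, seek:], N_FRAMES)
print(segment.shape)  # (n_mels, N_FRAMES)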