Merge remote-tracking branch 'origin/master' into vad

Yuta Hayashibe 2022-10-02 19:39:33 +09:00
commit 45eb0bc34d
8 changed files with 91 additions and 73 deletions

View file

@@ -32,7 +32,7 @@ whispering --language en --model tiny
 - ``--model`` set the [model name](https://github.com/openai/whisper#available-models-and-languages) to use. Larger models will be more accurate, but may not be able to transcribe in real time.
 - ``--language`` sets the language to transcribe. The list of languages are shown with ``whispering -h``
 - ``--no-progress`` disables the progress message
-- ``-t`` sets temperatures to decode. You can set several like (``-t 0.0 -t 0.1 -t 0.5``), but too many temperatures exhaust decoding time
+- ``-t`` sets temperatures to decode. You can set several like ``-t 0.0 -t 0.1 -t 0.5``, but too many temperatures exhaust decoding time
 - ``--debug`` outputs logs for debug

 ### Parse interval

poetry.lock generated
View file

@@ -519,13 +519,13 @@ dev = ["pytest"]
 [package.source]
 type = "git"
 url = "https://github.com/openai/whisper.git"
-reference = '62fe7f1009a534986ac1d32a4aef8c244d029c28'
-resolved_reference = "62fe7f1009a534986ac1d32a4aef8c244d029c28"
+reference = '0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f'
+resolved_reference = "0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f"

 [metadata]
 lock-version = "1.1"
 python-versions = ">=3.8,<3.11"
-content-hash = "75e53434d1d46d54a886ca7a896a2f0ba0072a1848f90d5b6dc46ea2c5b47191"
+content-hash = "ab527970383bc2245dee005627d0695812601115a36e15a5ef9e66d1185791bf"

 [metadata.files]
 black = [

View file

@@ -8,7 +8,7 @@ packages = [{include = "whispering"}]

 [tool.poetry.dependencies]
 python = ">=3.8,<3.11"
-whisper = {git = "https://github.com/openai/whisper.git", rev = '62fe7f1009a534986ac1d32a4aef8c244d029c28'}
+whisper = {git = "https://github.com/openai/whisper.git", rev = '0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f'}
 sounddevice = "^0.4.5"
 pydantic = "^1.10.2"
 websockets = "^10.3"

View file

@@ -214,11 +214,14 @@ def main() -> None:
     if opts.mode == "client":
         assert opts.language is None
         assert opts.model is None
-        asyncio.run(
-            run_websocket_client(
-                opts=opts,
-            )
-        )
+        try:
+            asyncio.run(
+                run_websocket_client(
+                    opts=opts,
+                )
+            )
+        except KeyboardInterrupt:
+            pass
     else:
         assert opts.language is not None
         assert opts.model is not None
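
The client-side change above is a standard pattern: asyncio.run() re-raises Ctrl-C as KeyboardInterrupt, so catching it at the top level turns a traceback into a quiet exit. A minimal, self-contained sketch of the same idea (run_forever is a hypothetical stand-in for run_websocket_client):

```python
import asyncio


async def run_forever() -> None:
    # Hypothetical stand-in for run_websocket_client(opts=...).
    while True:
        await asyncio.sleep(1.0)


def main() -> None:
    try:
        asyncio.run(run_forever())
    except KeyboardInterrupt:
        # Ctrl-C propagates out of asyncio.run(); swallowing it here
        # lets the client exit cleanly instead of printing a traceback.
        pass


if __name__ == "__main__":
    main()
```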

View file

@@ -2,7 +2,6 @@

 from typing import List, Optional

-import numpy as np
 import torch
 from pydantic import BaseModel, root_validator
@@ -56,4 +55,4 @@ class ParsedChunk(BaseModel):
 class SpeechSegment(BaseModel, arbitrary_types_allowed=True):
     start_block_idx: int
     end_block_idx: int
-    segment: np.ndarray
+    segment: torch.Tensor
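
SpeechSegment now carries a torch.Tensor instead of an np.ndarray. pydantic v1 cannot validate arbitrary classes, which is why the model is declared with arbitrary_types_allowed=True; pydantic then only checks isinstance. A minimal sketch of that behavior, assuming pydantic 1.x and torch are installed:

```python
import torch
from pydantic import BaseModel


class SpeechSegment(BaseModel, arbitrary_types_allowed=True):
    start_block_idx: int
    end_block_idx: int
    # With arbitrary_types_allowed, pydantic accepts any torch.Tensor
    # instance here but performs no further validation on it.
    segment: torch.Tensor


seg = SpeechSegment(
    start_block_idx=0,
    end_block_idx=1,
    segment=torch.zeros(16),
)
print(seg.segment.shape)  # torch.Size([16])
```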

View file

@@ -5,6 +5,7 @@ from logging import getLogger

 import numpy as np
 import websockets
+from websockets.exceptions import ConnectionClosedOK

 from whispering.transcriber import Context, WhisperStreamingTranscriber
@@ -15,10 +16,16 @@ async def serve_with_websocket_main(websocket):
     global g_wsp
     global g_ctx

     idx: int = 0
+    ctx: Context = g_ctx.copy(
+        deep=True,
+    )
     while True:
         logger.debug(f"Segment #: {idx}")
-        message = await websocket.recv()
+        try:
+            message = await websocket.recv()
+        except ConnectionClosedOK:
+            break

         if isinstance(message, str):
             logger.debug(f"Got str: {message}")
@@ -26,7 +33,10 @@ async def serve_with_websocket_main(websocket):
         logger.debug(f"Message size: {len(message)}")
         segment = np.frombuffer(message, dtype=np.float32)
-        for chunk in g_wsp.transcribe(segment=segment, ctx=g_ctx):
+        for chunk in g_wsp.transcribe(
+            segment=segment,  # type: ignore
+            ctx=ctx,
+        ):
             await websocket.send(chunk.json())
         idx += 1
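
Two fixes land in the server loop: each connection now decodes against its own deep copy of the shared Context, so concurrent clients cannot corrupt each other's buffers, and a client that closes the socket normally (ConnectionClosedOK) ends the loop instead of crashing the handler. A minimal sketch of the per-connection copy, with a hypothetical pydantic state model in place of the real Context:

```python
from typing import List

from pydantic import BaseModel


class Context(BaseModel):
    # Hypothetical stand-in for whispering's per-session decoding state.
    timestamp: float = 0.0
    buffer_tokens: List[int] = []


g_ctx = Context()  # shared template, as in the server above


def new_session_context() -> Context:
    # copy(deep=True) duplicates nested mutable fields too, so one
    # session's buffer growth can never leak into another session.
    return g_ctx.copy(deep=True)


a = new_session_context()
b = new_session_context()
a.buffer_tokens.append(42)
assert b.buffer_tokens == []  # sessions stay independent
```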

View file

@@ -1,9 +1,8 @@
 #!/usr/bin/env python3
 from logging import getLogger
-from typing import Final, Iterator, List, Optional, Union
+from typing import Final, Iterator, Optional, Union

-import numpy as np
 import torch
 from whisper import Whisper, load_model
 from whisper.audio import (
@@ -77,59 +76,51 @@ class WhisperStreamingTranscriber:
             suppress_blank=True,
             suppress_tokens="-1",
             without_timestamps=False,
-            max_initial_timestamp=0.0,
+            max_initial_timestamp=1.0,
             fp16=self.fp16,
         )

     def _decode_with_fallback(
         self,
         *,
-        segment: np.ndarray,
+        segment: torch.Tensor,
         ctx: Context,
-    ) -> List[DecodingResult]:
+    ) -> DecodingResult:
         assert len(ctx.temperatures) >= 1
-        t = ctx.temperatures[0]
-        logger.debug(f"temperature: {t}")
-
-        _decode_options1: DecodingOptions = self._get_decoding_options(
-            t=t,
-            prompt=ctx.buffer_tokens,
-            beam_size=ctx.beam_size,
-            patience=None,
-            best_of=None,
-        )
-        results: List[DecodingResult] = self.model.decode(segment, _decode_options1)  # type: ignore
+        decode_result: Optional[DecodingResult] = None

-        for t in ctx.temperatures[1:]:
-            needs_fallback = [
-                ctx.compression_ratio_threshold is not None
-                and result.compression_ratio > ctx.compression_ratio_threshold
-                or ctx.logprob_threshold is not None
-                and result.avg_logprob < ctx.logprob_threshold
-                for result in results
-            ]
-            if any(needs_fallback):
-                logger.debug(
-                    f"Fall back with temperature: {t}, needs_fallback: {needs_fallback}"
-                )
-                _decode_options2: DecodingOptions = self._get_decoding_options(
-                    t=t,
-                    prompt=ctx.buffer_tokens,
-                    beam_size=None,
-                    patience=ctx.patience,
-                    best_of=ctx.best_of,
-                )
-                retries: List[DecodingResult] = self.model.decode(
-                    segment[needs_fallback], _decode_options2  # type: ignore
-                )
-                for retry_index, original_index in enumerate(
-                    np.nonzero(needs_fallback)[0]
-                ):
-                    results[original_index] = retries[retry_index]
-            else:
+        for t in ctx.temperatures:
+            _decode_options: DecodingOptions = self._get_decoding_options(
+                t=t,
+                prompt=ctx.buffer_tokens,
+                beam_size=ctx.beam_size if t <= 0 else None,
+                patience=ctx.patience if t <= 0 else None,
+                best_of=ctx.best_of if t < 0 else None,
+            )
+            logger.debug(f"DecodeOptions: {_decode_options}")
+            decode_result = self.model.decode(
+                segment,
+                _decode_options,
+            )  # type: ignore
+            assert decode_result is not None
+
+            needs_fallback: bool = False
+            if (
+                ctx.compression_ratio_threshold is not None
+                and decode_result.compression_ratio > ctx.compression_ratio_threshold
+            ):
+                needs_fallback = True  # too repetitive
+            if (
+                ctx.logprob_threshold is not None
+                and decode_result.avg_logprob < ctx.logprob_threshold
+            ):
+                needs_fallback = True  # average log probability is too low
+
+            if not needs_fallback:
                 break
-        logger.debug(f"# of results: {len(results)}")
-        return results
+
+        assert isinstance(decode_result, DecodingResult)
+        return decode_result

     def _get_chunk(
         self,
@@ -205,7 +196,10 @@ class WhisperStreamingTranscriber:
             duration = segment_duration
             timestamps = tokens[timestamp_tokens.nonzero().flatten()]
             logger.debug(f"Length of consecutive: 0, timestamps: {timestamps}")
-            if len(timestamps) > 0:
+            if (
+                len(timestamps) > 0
+                and timestamps[-1].item() != self.tokenizer.timestamp_begin
+            ):
                 # no consecutive timestamps but it has a timestamp; use the last one.
                 # single timestamp at the end means no speech after the last timestamp.
                 last_timestamp_position = (
@@ -232,13 +226,13 @@ class WhisperStreamingTranscriber:
     def transcribe(
         self,
         *,
-        segment: np.ndarray,
+        segment: torch.Tensor,
         ctx: Context,
     ) -> Iterator[ParsedChunk]:
-        vad_probs = self.vad(segment)
-        logger.debug(f"{vad_probs}")
-        new_mel = log_mel_spectrogram(audio=segment).unsqueeze(0)
+        for speech_segment in self.vad(segment=segment):
+            logger.debug(f"{speech_segment}")
+        new_mel = log_mel_spectrogram(audio=segment)
         logger.debug(f"Incoming new_mel.shape: {new_mel.shape}")
         if ctx.buffer_mel is None:
             mel = new_mel
@@ -251,7 +245,7 @@ class WhisperStreamingTranscriber:
         seek: int = 0
         while seek < mel.shape[-1]:
             segment = (
-                pad_or_trim(mel[:, :, seek:], N_FRAMES)
+                pad_or_trim(mel[:, seek:], N_FRAMES)
                 .to(self.model.device)  # type: ignore
                 .to(self.dtype)
             )
@@ -262,11 +256,10 @@ class WhisperStreamingTranscriber:
                 f"seek={seek}, timestamp={ctx.timestamp}, "
                 f"mel.shape: {mel.shape}, segment.shape: {segment.shape}"
             )
-            results = self._decode_with_fallback(
+            result = self._decode_with_fallback(
                 segment=segment,
                 ctx=ctx,
             )
-            result = results[0]
             logger.debug(
                 f"Result: temperature={result.temperature:.2f}, no_speech_prob={result.no_speech_prob:.2f}, "
                 f"avg_logprob={result.avg_logprob:.2f}"
@@ -306,7 +299,7 @@ class WhisperStreamingTranscriber:
             if mel.shape[-1] - seek <= 0:
                 logger.debug(f"ctx.buffer_mel is None ({mel.shape}, {seek})")
                 return
-            ctx.buffer_mel = mel[:, :, seek:]
+            ctx.buffer_mel = mel[:, seek:]
             assert ctx.buffer_mel is not None
             logger.debug(f"ctx.buffer_mel.shape: {ctx.buffer_mel.shape}")
             del mel
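
_decode_with_fallback now mirrors upstream Whisper's fallback strategy: decode the whole segment at the lowest temperature first, and only climb the temperature ladder while the result looks degenerate (compression ratio too high, i.e. too repetitive, or average log-probability too low). A standalone sketch of that control flow; Result and decode_fn are hypothetical stand-ins for whisper's DecodingResult and model.decode, and the defaults shown are whisper's usual thresholds:

```python
from dataclasses import dataclass
from typing import Callable, Optional, Sequence


@dataclass
class Result:
    compression_ratio: float
    avg_logprob: float


def decode_with_fallback(
    decode_fn: Callable[[float], Result],  # decodes at temperature t
    temperatures: Sequence[float] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
    compression_ratio_threshold: Optional[float] = 2.4,
    logprob_threshold: Optional[float] = -1.0,
) -> Result:
    assert len(temperatures) >= 1
    result: Optional[Result] = None
    for t in temperatures:
        result = decode_fn(t)
        needs_fallback = False
        if (
            compression_ratio_threshold is not None
            and result.compression_ratio > compression_ratio_threshold
        ):
            needs_fallback = True  # too repetitive
        if (
            logprob_threshold is not None
            and result.avg_logprob < logprob_threshold
        ):
            needs_fallback = True  # average log probability is too low
        if not needs_fallback:
            break  # accept the first non-degenerate result
    assert result is not None
    return result
```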

View file

@@ -2,7 +2,6 @@

 from typing import Iterator

-import numpy as np
 import torch
 from whisper.audio import N_FRAMES, SAMPLE_RATE
@@ -21,16 +20,26 @@ class VAD:
     def __call__(
         self,
         *,
-        segment: np.ndarray,
+        segment: torch.Tensor,
         thredhold: float = 0.5,
-    ) -> Iterator[SpeechBlock]:
+    ) -> Iterator[SpeechSegment]:
         # segment.shape should be multiple of (N_FRAMES,)

+        def my_ret(
+            *,
+            start_block_idx: int,
+            idx: int,
+        ) -> SpeechSegment:
+            return SpeechSegment(
+                start_block_idx=start_block_idx,
+                end_block_idx=idx,
+                segment=segment[N_FRAMES * start_block_idx : N_FRAMES * idx],
+            )
+
         block_size: int = int(segment.shape[0] / N_FRAMES)

         start_block_idx = None
-        for idx in range(block_size + 1):
-            if idx < block_size:
+        for idx in range(block_size):
             start: int = N_FRAMES * idx
             end: int = N_FRAMES * (idx + 1)
             vad_prob = self.vad_model(
@@ -42,9 +51,13 @@ class VAD:
                     start_block_idx = idx
             else:
                 if start_block_idx is not None:
-                    yield SpeechSegment(
+                    yield my_ret(
                         start_block_idx=start_block_idx,
-                        end_block_idx=idx,
-                        segment=segment[N_FRAMES * start_block_idx : N_FRAMES * idx],
+                        idx=idx,
                     )
                     start_block_idx = None
+        if start_block_idx is not None:
+            yield my_ret(
+                start_block_idx=start_block_idx,
+                idx=block_size,
+            )
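
The reworked VAD walks the buffer in N_FRAMES-sized blocks, opens a run when the silero-vad probability crosses the threshold, closes it on the first non-speech block, and, new in this commit, flushes a run still open when the buffer ends (the case the old range(block_size + 1) loop dropped). A minimal sketch of the same run-grouping over precomputed per-block probabilities (speech_probs is a hypothetical stand-in for the silero-vad scores):

```python
from typing import Iterator, List, Tuple


def speech_runs(
    speech_probs: List[float],  # hypothetical per-block VAD probabilities
    threshold: float = 0.5,
) -> Iterator[Tuple[int, int]]:
    """Yield (start_block_idx, end_block_idx) for each run of speech blocks."""
    start = None
    for idx, prob in enumerate(speech_probs):
        if prob > threshold:
            if start is None:
                start = idx  # a speech run begins
        else:
            if start is not None:
                yield (start, idx)  # run ended at a non-speech block
                start = None
    if start is not None:
        # Flush a run that reaches the end of the buffer: the case
        # the trailing yield added in the diff above now handles.
        yield (start, len(speech_probs))


print(list(speech_runs([0.1, 0.9, 0.8, 0.2, 0.7])))  # [(1, 3), (4, 5)]
```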