From b0c27c7ca9851a918b4b06e0e025958781d383ea Mon Sep 17 00:00:00 2001
From: Yuta Hayashibe
Date: Fri, 30 Sep 2022 03:03:46 +0900
Subject: [PATCH 1/6] Fix patience setting

---
 whispering/transcriber.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/whispering/transcriber.py b/whispering/transcriber.py
index 059d3c4..ac2e3f1 100644
--- a/whispering/transcriber.py
+++ b/whispering/transcriber.py
@@ -93,7 +93,7 @@ class WhisperStreamingTranscriber:
             t=t,
             prompt=ctx.buffer_tokens,
             beam_size=ctx.beam_size,
-            patience=None,
+            patience=ctx.patience,
             best_of=None,
         )
         results: List[DecodingResult] = self.model.decode(segment, _decode_options1)  # type: ignore
@@ -114,7 +114,7 @@ class WhisperStreamingTranscriber:
                     t=t,
                     prompt=ctx.buffer_tokens,
                     beam_size=None,
-                    patience=ctx.patience,
+                    patience=None,
                     best_of=ctx.best_of,
                 )
                 retries: List[DecodingResult] = self.model.decode(

From 1989e023a5bfff3a20a31d9e36270130fae2411e Mon Sep 17 00:00:00 2001
From: Yuta Hayashibe
Date: Fri, 30 Sep 2022 03:06:41 +0900
Subject: [PATCH 2/6] Fix README

---
 README.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/README.md b/README.md
index fe8a76f..042e410 100644
--- a/README.md
+++ b/README.md
@@ -32,7 +32,7 @@ whispering --language en --model tiny
 - ``--model`` set the [model name](https://github.com/openai/whisper#available-models-and-languages) to use. Larger models will be more accurate, but may not be able to transcribe in real time.
 - ``--language`` sets the language to transcribe. The list of languages are shown with ``whispering -h``
 - ``--no-progress`` disables the progress message
-- ``-t`` sets temperatures to decode. You can set several like (``-t 0.0 -t 0.1 -t 0.5``), but too many temperatures exhaust decoding time
+- ``-t`` sets temperatures to decode. You can set several like ``-t 0.0 -t 0.1 -t 0.5``, but too many temperatures exhaust decoding time
 - ``--debug`` outputs logs for debug
 
 ### Parse interval

From 95dafe67408919ab389be5f1c1153e1b8f5dc972 Mon Sep 17 00:00:00 2001
From: Yuta Hayashibe
Date: Sat, 1 Oct 2022 22:45:45 +0900
Subject: [PATCH 3/6] Fix (openai/whisper@2b0c297)

---
 whispering/transcriber.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/whispering/transcriber.py b/whispering/transcriber.py
index ac2e3f1..4c2a959 100644
--- a/whispering/transcriber.py
+++ b/whispering/transcriber.py
@@ -203,7 +203,10 @@ class WhisperStreamingTranscriber:
             duration = segment_duration
             timestamps = tokens[timestamp_tokens.nonzero().flatten()]
             logger.debug(f"Length of consecutive: 0, timestamps: {timestamps}")
-            if len(timestamps) > 0:
+            if (
+                len(timestamps) > 0
+                and timestamps[-1].item() != self.tokenizer.timestamp_begin
+            ):
                 # no consecutive timestamps but it has a timestamp; use the last one.
                 # single timestamp at the end means no speech after the last timestamp.
                 last_timestamp_position = (

From f46e71f51dad8387bf012fe954915c9cf8009765 Mon Sep 17 00:00:00 2001
From: Yuta Hayashibe
Date: Sat, 1 Oct 2022 23:20:46 +0900
Subject: [PATCH 4/6] Fix (openai/whisper@7cb4cc2)

---
 poetry.lock               |  6 +--
 pyproject.toml            |  2 +-
 whispering/serve.py       |  5 ++-
 whispering/transcriber.py | 88 +++++++++++++++++----------------------
 4 files changed, 47 insertions(+), 54 deletions(-)

diff --git a/poetry.lock b/poetry.lock
index 4f43d3a..5d91e76 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -508,13 +508,13 @@ dev = ["pytest"]
 [package.source]
 type = "git"
 url = "https://github.com/openai/whisper.git"
-reference = '62fe7f1009a534986ac1d32a4aef8c244d029c28'
-resolved_reference = "62fe7f1009a534986ac1d32a4aef8c244d029c28"
+reference = '0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f'
+resolved_reference = "0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f"
 
 [metadata]
 lock-version = "1.1"
 python-versions = ">=3.8,<3.11"
-content-hash = "d041d21a202339f405cc37076403f92135ee1f113cdfece5a78c9ee12374be7b"
+content-hash = "f5395ffab6ce7d95246143218e948308d6614929f375489eb2b94a863e15fcc4"
 
 [metadata.files]
 black = [

diff --git a/pyproject.toml b/pyproject.toml
index 8dbf250..8a5f484 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -8,7 +8,7 @@ packages = [{include = "whispering"}]
 
 [tool.poetry.dependencies]
 python = ">=3.8,<3.11"
-whisper = {git = "https://github.com/openai/whisper.git", rev = '62fe7f1009a534986ac1d32a4aef8c244d029c28'}
+whisper = {git = "https://github.com/openai/whisper.git", rev = '0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f'}
 sounddevice = "^0.4.5"
 pydantic = "^1.10.2"
 websockets = "^10.3"

diff --git a/whispering/serve.py b/whispering/serve.py
index fab8952..198a4c8 100644
--- a/whispering/serve.py
+++ b/whispering/serve.py
@@ -26,7 +26,10 @@ async def serve_with_websocket_main(websocket):
 
         logger.debug(f"Message size: {len(message)}")
         segment = np.frombuffer(message, dtype=np.float32)
-        for chunk in g_wsp.transcribe(segment=segment, ctx=g_ctx):
+        for chunk in g_wsp.transcribe(
+            segment=segment,  # type: ignore
+            ctx=g_ctx,
+        ):
             await websocket.send(chunk.json())
         idx += 1
 
diff --git a/whispering/transcriber.py b/whispering/transcriber.py
index 4c2a959..f0c353a 100644
--- a/whispering/transcriber.py
+++ b/whispering/transcriber.py
@@ -1,9 +1,8 @@
 #!/usr/bin/env python3
 
 from logging import getLogger
-from typing import Final, Iterator, List, Optional, Union
+from typing import Final, Iterator, Optional, Union
 
-import numpy as np
 import torch
 from whisper import Whisper, load_model
 from whisper.audio import (
@@ -75,59 +74,51 @@ class WhisperStreamingTranscriber:
             suppress_blank=True,
             suppress_tokens="-1",
             without_timestamps=False,
-            max_initial_timestamp=0.0,
+            max_initial_timestamp=1.0,
             fp16=self.fp16,
         )
 
     def _decode_with_fallback(
         self,
         *,
-        segment: np.ndarray,
+        segment: torch.Tensor,
         ctx: Context,
-    ) -> List[DecodingResult]:
+    ) -> DecodingResult:
         assert len(ctx.temperatures) >= 1
-        t = ctx.temperatures[0]
-        logger.debug(f"temperature: {t}")
+        decode_result: Optional[DecodingResult] = None
 
-        _decode_options1: DecodingOptions = self._get_decoding_options(
-            t=t,
-            prompt=ctx.buffer_tokens,
-            beam_size=ctx.beam_size,
-            patience=ctx.patience,
-            best_of=None,
-        )
-        results: List[DecodingResult] = self.model.decode(segment, _decode_options1)  # type: ignore
+        for t in ctx.temperatures:
+            _decode_options: DecodingOptions = self._get_decoding_options(
+                t=t,
+                prompt=ctx.buffer_tokens,
+                beam_size=ctx.beam_size if t <= 0 else None,
+                patience=ctx.patience if t <= 0 else None,
+                best_of=ctx.best_of if t < 0 else None,
+            )
+            logger.debug(f"DecodeOptions: {_decode_options}")
+            decode_result = self.model.decode(
+                segment,
+                _decode_options,
+            )  # type: ignore
+            assert decode_result is not None
 
-        for t in ctx.temperatures[1:]:
-            needs_fallback = [
+            needs_fallback: bool = False
+            if (
                 ctx.compression_ratio_threshold is not None
-                and result.compression_ratio > ctx.compression_ratio_threshold
-                or ctx.logprob_threshold is not None
-                and result.avg_logprob < ctx.logprob_threshold
-                for result in results
-            ]
-            if any(needs_fallback):
-                logger.debug(
-                    f"Fall back with temperature: {t}, needs_fallback: {needs_fallback}"
-                )
-                _decode_options2: DecodingOptions = self._get_decoding_options(
-                    t=t,
-                    prompt=ctx.buffer_tokens,
-                    beam_size=None,
-                    patience=None,
-                    best_of=ctx.best_of,
-                )
-                retries: List[DecodingResult] = self.model.decode(
-                    segment[needs_fallback], _decode_options2  # type: ignore
-                )
-                for retry_index, original_index in enumerate(
-                    np.nonzero(needs_fallback)[0]
-                ):
-                    results[original_index] = retries[retry_index]
-            else:
+                and decode_result.compression_ratio > ctx.compression_ratio_threshold
+            ):
+                needs_fallback = True  # too repetitive
+            if (
+                ctx.logprob_threshold is not None
+                and decode_result.avg_logprob < ctx.logprob_threshold
+            ):
+                needs_fallback = True  # average log probability is too low
+
+            if not needs_fallback:
                 break
-        logger.debug(f"# of results: {len(results)}")
-        return results
+
+        assert isinstance(decode_result, DecodingResult)
+        return decode_result
 
     def _get_chunk(
         self,
@@ -233,10 +224,10 @@ class WhisperStreamingTranscriber:
     def transcribe(
         self,
         *,
-        segment: np.ndarray,
+        segment: torch.Tensor,
         ctx: Context,
     ) -> Iterator[ParsedChunk]:
-        new_mel = log_mel_spectrogram(audio=segment).unsqueeze(0)
+        new_mel = log_mel_spectrogram(audio=segment)
         logger.debug(f"Incoming new_mel.shape: {new_mel.shape}")
         if ctx.buffer_mel is None:
             mel = new_mel
@@ -249,7 +240,7 @@ class WhisperStreamingTranscriber:
         seek: int = 0
         while seek < mel.shape[-1]:
             segment = (
-                pad_or_trim(mel[:, :, seek:], N_FRAMES)
+                pad_or_trim(mel[:, seek:], N_FRAMES)
                 .to(self.model.device)  # type: ignore
                 .to(self.dtype)
             )
@@ -260,11 +251,10 @@ class WhisperStreamingTranscriber:
                 f"seek={seek}, timestamp={ctx.timestamp}, "
                 f"mel.shape: {mel.shape}, segment.shape: {segment.shape}"
             )
-            results = self._decode_with_fallback(
+            result = self._decode_with_fallback(
                 segment=segment,
                 ctx=ctx,
             )
-            result = results[0]
             logger.debug(
                 f"Result: temperature={result.temperature:.2f}, no_speech_prob={result.no_speech_prob:.2f}, "
                 f"avg_logprob={result.avg_logprob:.2f}"
@@ -304,7 +294,7 @@ class WhisperStreamingTranscriber:
         if mel.shape[-1] - seek <= 0:
             logger.debug(f"ctx.buffer_mel is None ({mel.shape}, {seek})")
             return
-        ctx.buffer_mel = mel[:, :, seek:]
+        ctx.buffer_mel = mel[:, seek:]
         assert ctx.buffer_mel is not None
         logger.debug(f"ctx.buffer_mel.shape: {ctx.buffer_mel.shape}")
         del mel

From b1f98f86ff058cbf9f678312dc831feaf671632a Mon Sep 17 00:00:00 2001
From: Yuta Hayashibe
Date: Sat, 1 Oct 2022 23:43:50 +0900
Subject: [PATCH 5/6] Catch KeyboardInterrupt

---
 whispering/cli.py | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/whispering/cli.py b/whispering/cli.py
index af93f98..8b60271 100644
--- a/whispering/cli.py
+++ b/whispering/cli.py
@@ -214,11 +214,14 @@ def main() -> None:
     if opts.mode == "client":
         assert opts.language is None
         assert opts.model is None
-        asyncio.run(
-            run_websocket_client(
-                opts=opts,
+        try:
+            asyncio.run(
+                run_websocket_client(
+                    opts=opts,
+                )
             )
-        )
+        except KeyboardInterrupt:
+            pass
     else:
         assert opts.language is not None
         assert opts.model is not None

From e44d1ef13d95c58f2f2269cbcdf05809191954aa Mon Sep 17 00:00:00 2001
From: Yuta Hayashibe
Date: Sat, 1 Oct 2022 23:55:12 +0900
Subject: [PATCH 6/6] Use different context for each connection

---
 whispering/serve.py | 11 +++++++++--
 1 file changed, 9 insertions(+), 2 deletions(-)

diff --git a/whispering/serve.py b/whispering/serve.py
index 198a4c8..20d50fb 100644
--- a/whispering/serve.py
+++ b/whispering/serve.py
@@ -5,6 +5,7 @@ from logging import getLogger
 
 import numpy as np
 import websockets
+from websockets.exceptions import ConnectionClosedOK
 
 from whispering.transcriber import Context, WhisperStreamingTranscriber
 
@@ -15,10 +16,16 @@ async def serve_with_websocket_main(websocket):
     global g_wsp
     global g_ctx
     idx: int = 0
+    ctx: Context = g_ctx.copy(
+        deep=True,
+    )
 
     while True:
         logger.debug(f"Segment #: {idx}")
-        message = await websocket.recv()
+        try:
+            message = await websocket.recv()
+        except ConnectionClosedOK:
+            break
 
         if isinstance(message, str):
             logger.debug(f"Got str: {message}")
@@ -28,7 +35,7 @@ async def serve_with_websocket_main(websocket):
         segment = np.frombuffer(message, dtype=np.float32)
         for chunk in g_wsp.transcribe(
             segment=segment,  # type: ignore
-            ctx=g_ctx,
+            ctx=ctx,
         ):
             await websocket.send(chunk.json())
             idx += 1
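
A minimal client sketch for the wire protocol visible in whispering/serve.py above (not part of the patches themselves): after PATCH 6/6 each websocket connection works on its own deep copy of the shared Context, so concurrent clients no longer share buffer state, and the server still expects raw float32 PCM bytes in and replies with chunk.json() messages. The ws://localhost:8000 address, the 16000-sample buffer, and the 10-second timeout below are illustrative assumptions, not values taken from these patches.

# Illustrative sketch only; server URI, sample count, and timeout are assumptions.
import asyncio
import json

import numpy as np
import websockets


async def send_one_segment(uri: str = "ws://localhost:8000") -> None:
    # One second of silence as float32 PCM, matching the server-side
    # np.frombuffer(message, dtype=np.float32) in whispering/serve.py.
    segment = np.zeros(16000, dtype=np.float32)
    async with websockets.connect(uri) as ws:
        await ws.send(segment.tobytes())
        try:
            # The server emits zero or more chunk.json() messages per segment,
            # so waiting for exactly one reply may time out.
            reply = await asyncio.wait_for(ws.recv(), timeout=10.0)
            print(json.loads(reply))
        except asyncio.TimeoutError:
            print("no transcription chunk for this segment")


if __name__ == "__main__":
    asyncio.run(send_one_segment())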