diff --git a/whispering/cli.py b/whispering/cli.py
index 863c5c7..040e1d9 100644
--- a/whispering/cli.py
+++ b/whispering/cli.py
@@ -12,7 +12,7 @@
 from whisper import available_models
 from whisper.audio import N_FRAMES, SAMPLE_RATE
 from whisper.tokenizer import LANGUAGES, TO_LANGUAGE_CODE
-from whispering.schema import WhisperConfig
+from whispering.schema import Context, WhisperConfig
 from whispering.serve import serve_with_websocket
 from whispering.transcriber import WhisperStreamingTranscriber
 from whispering.websocket_client import run_websocket_client
@@ -33,6 +33,7 @@ def transcribe_from_mic(
             logger.warning(status)
         q.put(indata.ravel())
 
+    ctx: Context = Context()
     logger.info("Ready to transcribe")
     with sd.InputStream(
         samplerate=SAMPLE_RATE,
@@ -46,7 +47,7 @@
         while True:
             logger.debug(f"Segment #: {idx}, The rest of queue: {q.qsize()}")
             segment = q.get()
-            for chunk in wsp.transcribe(segment=segment):
+            for chunk in wsp.transcribe(segment=segment, ctx=ctx):
                 print(f"{chunk.start:.2f}->{chunk.end:.2f}\t{chunk.text}")
             idx += 1
 
diff --git a/whispering/schema.py b/whispering/schema.py
index 1bcbc5e..018bc22 100644
--- a/whispering/schema.py
+++ b/whispering/schema.py
@@ -2,6 +2,7 @@
 
 from typing import List, Optional
 
+import torch
 from pydantic import BaseModel
 
 
@@ -23,6 +24,12 @@ class WhisperConfig(BaseModel):
     compression_ratio_threshold: Optional[float] = 2.4
 
 
+class Context(BaseModel, arbitrary_types_allowed=True):
+    timestamp: float = 0.0
+    buffer_tokens: List[torch.Tensor] = []
+    buffer_mel: Optional[torch.Tensor] = None
+
+
 class ParsedChunk(BaseModel):
     start: float
     end: float
diff --git a/whispering/serve.py b/whispering/serve.py
index eee2c94..b38a26d 100644
--- a/whispering/serve.py
+++ b/whispering/serve.py
@@ -6,7 +6,7 @@ from logging import getLogger
 import numpy as np
 import websockets
 
-from whispering.transcriber import WhisperStreamingTranscriber
+from whispering.transcriber import Context, WhisperStreamingTranscriber
 
 logger = getLogger(__name__)
 
@@ -14,6 +14,7 @@ logger = getLogger(__name__)
 async def serve_with_websocket_main(websocket):
     global g_wsp
     idx: int = 0
+    ctx: Context = Context()
 
     while True:
         logger.debug(f"Segment #: {idx}")
@@ -25,7 +26,7 @@
 
         logger.debug(f"Message size: {len(message)}")
         segment = np.frombuffer(message, dtype=np.float32)
-        for chunk in g_wsp.transcribe(segment=segment):
+        for chunk in g_wsp.transcribe(segment=segment, ctx=ctx):
             await websocket.send(chunk.json())
         idx += 1
 
diff --git a/whispering/transcriber.py b/whispering/transcriber.py
index 3e2140a..5af84a6 100644
--- a/whispering/transcriber.py
+++ b/whispering/transcriber.py
@@ -1,7 +1,7 @@
 #!/usr/bin/env python3
 
 from logging import getLogger
-from typing import Iterator, List, Optional, Union
+from typing import Final, Iterator, List, Optional, Union
 
 import numpy as np
 import torch
@@ -17,7 +17,7 @@ from whisper.decoding import DecodingOptions, DecodingResult
 from whisper.tokenizer import get_tokenizer
 from whisper.utils import exact_div
 
-from whispering.schema import ParsedChunk, WhisperConfig
+from whispering.schema import Context, ParsedChunk, WhisperConfig
 
 logger = getLogger(__name__)
 
@@ -37,25 +37,21 @@ class WhisperStreamingTranscriber:
         self.fp16 = False
 
     def __init__(self, *, config: WhisperConfig):
-        self.config: WhisperConfig = config
-        self.model: Whisper = load_model(config.model_name, device=config.device)
+        self.config: Final[WhisperConfig] = config
+        self.model: Final[Whisper] = load_model(config.model_name, device=config.device)
         self.tokenizer = get_tokenizer(
             self.model.is_multilingual,
             language=config.language,
             task="transcribe",
         )
         self._set_dtype(config.fp16)
-        self.timestamp: float = 0.0
-        self.input_stride = exact_div(
+        self.input_stride: Final[int] = exact_div(
             N_FRAMES, self.model.dims.n_audio_ctx
         )  # mel frames per output token: 2
-        self.time_precision = (
+        self.time_precision: Final[float] = (
             self.input_stride * HOP_LENGTH / SAMPLE_RATE
         )  # time per output token: 0.02 (seconds)
 
-        self.buffer_tokens = []
-        self.buffer_mel = None
-
     def _get_decoding_options(
         self,
         *,
@@ -87,6 +83,7 @@
         self,
         *,
         segment: np.ndarray,
+        ctx: Context,
     ) -> List[DecodingResult]:
         assert len(self.config.temperatures) >= 1
         t = self.config.temperatures[0]
@@ -94,7 +91,7 @@
 
         _decode_options1: DecodingOptions = self._get_decoding_options(
             t=t,
-            prompt=self.buffer_tokens,
+            prompt=ctx.buffer_tokens,
             beam_size=self.config.beam_size,
             patience=0.0,
             best_of=None,
@@ -115,7 +112,7 @@
         )
         _decode_options2: DecodingOptions = self._get_decoding_options(
             t=t,
-            prompt=self.buffer_tokens,
+            prompt=ctx.buffer_tokens,
             beam_size=None,
             patience=0.0,
             best_of=self.config.best_of,
@@ -158,7 +155,11 @@
         )
 
     def _deal_timestamp(
-        self, *, result, segment_duration
+        self,
+        *,
+        result,
+        segment_duration,
+        ctx: Context,
     ) -> Iterator[Union[ParsedChunk, int]]:
         tokens = torch.tensor(result.tokens)
         timestamp_tokens: torch.Tensor = tokens.ge(self.tokenizer.timestamp_begin)
@@ -182,9 +183,9 @@
                     sliced_tokens[-1].item() - self.tokenizer.timestamp_begin
                 )
                 chunk = self._get_chunk(
-                    start=self.timestamp
+                    start=ctx.timestamp
                     + start_timestamp_position * self.time_precision,
-                    end=self.timestamp + end_timestamp_position * self.time_precision,
+                    end=ctx.timestamp + end_timestamp_position * self.time_precision,
                     text_tokens=sliced_tokens[1:-1],
                     result=result,
                 )
@@ -195,8 +196,8 @@
             last_timestamp_position0: int = (
                 tokens[last_slice - 1].item() - self.tokenizer.timestamp_begin  # type:ignore
             )
-            self.buffer_tokens.extend(tokens[: last_slice + 1].tolist())
-            self.timestamp += last_timestamp_position0 * self.time_precision
+            ctx.buffer_tokens.extend(tokens[: last_slice + 1].tolist())
+            ctx.timestamp += last_timestamp_position0 * self.time_precision
             yield last_timestamp_position0
         else:
             duration = segment_duration
@@ -211,34 +212,35 @@
                 duration = last_timestamp_position * self.time_precision
         logger.debug(f"segment_duration: {segment_duration}, Duration: {duration}")
         chunk = self._get_chunk(
-            start=self.timestamp,
-            end=self.timestamp + duration,
+            start=ctx.timestamp,
+            end=ctx.timestamp + duration,
             text_tokens=tokens,
             result=result,
         )
         if chunk is not None:
             yield chunk
-        self.timestamp += duration
+        ctx.timestamp += duration
 
         if result.temperature > 0.5:
             # do not feed the prompt tokens if a high temperature was used
-            del self.buffer_tokens
-            self.buffer_tokens = []
-        logger.debug(f"Length of buffer: {len(self.buffer_tokens)}")
+            del ctx.buffer_tokens
+            ctx.buffer_tokens = []
+        logger.debug(f"Length of buffer: {len(ctx.buffer_tokens)}")
 
     def transcribe(
         self,
         *,
         segment: np.ndarray,
+        ctx: Context,
     ) -> Iterator[ParsedChunk]:
         new_mel = log_mel_spectrogram(audio=segment).unsqueeze(0)
         logger.debug(f"Incoming new_mel.shape: {new_mel.shape}")
-        if self.buffer_mel is None:
+        if ctx.buffer_mel is None:
             mel = new_mel
         else:
-            logger.debug(f"buffer_mel.shape: {self.buffer_mel.shape}")
logger.debug(f"buffer_mel.shape: {self.buffer_mel.shape}") - mel = torch.cat([self.buffer_mel, new_mel], dim=-1) - self.buffer_mel = None + logger.debug(f"buffer_mel.shape: {ctx.buffer_mel.shape}") + mel = torch.cat([ctx.buffer_mel, new_mel], dim=-1) + ctx.buffer_mel = None logger.debug(f"mel.shape: {mel.shape}") seek: int = 0 @@ -252,11 +254,12 @@ class WhisperStreamingTranscriber: logger.warning("Padding is not expected while speaking") logger.debug( - f"seek={seek}, timestamp={self.timestamp}, " + f"seek={seek}, timestamp={ctx.timestamp}, " f"mel.shape: {mel.shape}, segment.shape: {segment.shape}" ) results = self._decode_with_fallback( segment=segment, + ctx=ctx, ) result = results[0] logger.debug( @@ -278,7 +281,9 @@ class WhisperStreamingTranscriber: segment_duration = segment.shape[-1] * HOP_LENGTH / SAMPLE_RATE last_timestamp_position: Optional[int] = None for v in self._deal_timestamp( - result=result, segment_duration=segment_duration + result=result, + segment_duration=segment_duration, + ctx=ctx, ): if isinstance(v, int): last_timestamp_position = v @@ -294,8 +299,9 @@ class WhisperStreamingTranscriber: break if mel.shape[-1] - seek <= 0: - logger.debug(f"self.buffer_mel is None ({mel.shape}, {seek})") + logger.debug(f"ctx.buffer_mel is None ({mel.shape}, {seek})") return - self.buffer_mel = mel[:, :, seek:] - logger.debug(f"self.buffer_mel.shape: {self.buffer_mel.shape}") + ctx.buffer_mel = mel[:, :, seek:] + assert ctx.buffer_mel is not None + logger.debug(f"ctx.buffer_mel.shape: {ctx.buffer_mel.shape}") del mel