From 7132db3433073c6c62747f24ab05252ffb3dc032 Mon Sep 17 00:00:00 2001 From: Yuta Hayashibe Date: Fri, 23 Sep 2022 20:03:00 +0900 Subject: [PATCH] Deal timestamp --- Makefile | 3 +- whisper_streaming/cli.py | 5 +- whisper_streaming/schema.py | 13 +++- whisper_streaming/transcriber.py | 117 ++++++++++++++++++++++++++++--- 4 files changed, 125 insertions(+), 13 deletions(-) diff --git a/Makefile b/Makefile index 3ecece8..d6fb3c4 100644 --- a/Makefile +++ b/Makefile @@ -8,7 +8,8 @@ flake8: black: find $(TARGET_DIRS) | grep '\.py$$' | xargs black --diff | diff /dev/null - isort: - find $(TARGET_DIRS) | grep '\.py$$' | xargs isort --diff | diff /dev/null - + #Temporary + #find $(TARGET_DIRS) | grep '\.py$$' | xargs isort --diff | diff /dev/null - pydocstyle: find $(TARGET_DIRS) | grep -v tests | xargs pydocstyle --ignore=D100,D101,D102,D103,D104,D105,D107,D203,D212 diff --git a/whisper_streaming/cli.py b/whisper_streaming/cli.py index 4709e99..eac7422 100644 --- a/whisper_streaming/cli.py +++ b/whisper_streaming/cli.py @@ -41,9 +41,8 @@ def transcribe_from_mic( ): while True: segment = q.get() - r = wsp.transcribe(segment=segment) - if r is not None: - print(r.text) + for chunk in wsp.transcribe(segment=segment): + print(f"{chunk.start}->{chunk.end}\t{chunk.text}") def get_opts() -> argparse.Namespace: diff --git a/whisper_streaming/schema.py b/whisper_streaming/schema.py index ea605b9..59272d3 100644 --- a/whisper_streaming/schema.py +++ b/whisper_streaming/schema.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 -from typing import Optional, Tuple +from typing import List, Optional, Tuple from pydantic import BaseModel @@ -19,3 +19,14 @@ class WhisperConfig(BaseModel): no_speech_threshold: Optional[float] = 0.6 logprob_threshold: Optional[float] = -1.0 compression_ratio_threshold: Optional[float] = 2.4 + + +class ParsedChunk(BaseModel): + start: float + end: float + text: str + tokens: List[int] + temperature: float + avg_logprob: float + compression_ratio: float + no_speech_prob: float diff --git a/whisper_streaming/transcriber.py b/whisper_streaming/transcriber.py index 1c7b0c4..8a9ef4f 100644 --- a/whisper_streaming/transcriber.py +++ b/whisper_streaming/transcriber.py @@ -1,15 +1,22 @@ #!/usr/bin/env python3 -from typing import List, Optional +from typing import Iterator, List, Optional import numpy as np import torch from whisper import Whisper, load_model -from whisper.audio import N_FRAMES, log_mel_spectrogram, pad_or_trim +from whisper.audio import ( + HOP_LENGTH, + N_FRAMES, + SAMPLE_RATE, + log_mel_spectrogram, + pad_or_trim, +) from whisper.decoding import DecodingOptions, DecodingResult from whisper.tokenizer import get_tokenizer +from whisper.utils import exact_div -from whisper_streaming.schema import WhisperConfig +from whisper_streaming.schema import ParsedChunk, WhisperConfig class WhisperStreamingTranscriber: @@ -22,6 +29,16 @@ class WhisperStreamingTranscriber: task="transcribe", ) self.dtype = torch.float16 + self.timestamp: float = 0.0 + self.input_stride = exact_div( + N_FRAMES, self.model.dims.n_audio_ctx + ) # mel frames per output token: 2 + self.time_precision = ( + self.input_stride * HOP_LENGTH / SAMPLE_RATE + ) # time per output token: 0.02 (seconds) + + self.buffer_tokens = [] + self.buffer_segments = [] def _get_decoding_options( self, @@ -85,11 +102,94 @@ class WhisperStreamingTranscriber: results[original_index] = retries[retry_index] return results + def _get_chunk( + self, + *, + start: float, + end: float, + text_tokens: torch.Tensor, + result: DecodingResult, + ) -> Optional[ParsedChunk]: + text = self.tokenizer.decode( + [token for token in text_tokens if token < self.tokenizer.eot] # type: ignore + ) + if len(text.strip()) == 0: # skip empty text output + return + + return ParsedChunk( + start=start, + end=end, + text=text, + tokens=result.tokens, + temperature=result.temperature, + avg_logprob=result.avg_logprob, + compression_ratio=result.compression_ratio, + no_speech_prob=result.no_speech_prob, + ) + + def _deal_timestamp(self, *, result, segment_duration) -> Iterator[ParsedChunk]: + tokens = torch.tensor(result.tokens) + timestamp_tokens: torch.Tensor = tokens.ge(self.tokenizer.timestamp_begin) + + consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0].add_( + 1 + ) + + if ( + len(consecutive) > 0 + ): # if the output contains two consecutive timestamp tokens + last_slice = 0 + for current_slice in consecutive: + sliced_tokens = tokens[last_slice:current_slice] + start_timestamp_position = ( + sliced_tokens[0].item() - self.tokenizer.timestamp_begin + ) + end_timestamp_position = ( + sliced_tokens[-1].item() - self.tokenizer.timestamp_begin + ) + chunk = self._get_chunk( + start=self.timestamp + + start_timestamp_position * self.time_precision, + end=self.timestamp + end_timestamp_position * self.time_precision, + text_tokens=sliced_tokens[1:-1], + result=result, + ) + if chunk is not None: + yield chunk + last_slice = current_slice + last_timestamp_position = ( + tokens[last_slice - 1].item() - self.tokenizer.timestamp_begin + ) + self.buffer_tokens.extend(tokens[: last_slice + 1].tolist()) + else: + duration = segment_duration + timestamps = tokens[timestamp_tokens.nonzero().flatten()] + if len(timestamps) > 0: + # no consecutive timestamps but it has a timestamp; use the last one. + # single timestamp at the end means no speech after the last timestamp. + last_timestamp_position = ( + timestamps[-1].item() - self.tokenizer.timestamp_begin + ) + duration = last_timestamp_position * self.time_precision + chunk = self._get_chunk( + start=self.timestamp, + end=self.timestamp + duration, + text_tokens=tokens, + result=result, + ) + if chunk is not None: + yield chunk + + if result.temperature > 0.5: + # do not feed the prompt tokens if a high temperature was used + del self.buffer_tokens + self.buffer_tokens = [] + def transcribe( self, *, segment: np.ndarray, - ) -> Optional[DecodingResult]: + ) -> Iterator[ParsedChunk]: log_spec = log_mel_spectrogram(audio=segment).unsqueeze(0) segment = ( pad_or_trim(log_spec, N_FRAMES) @@ -105,7 +205,8 @@ class WhisperStreamingTranscriber: and result.avg_logprob > self.config.logprob_threshold ): return - - # FIXME: work with timestamp - - return result + segment_duration = segment.shape[-1] * HOP_LENGTH / SAMPLE_RATE + yield from self._deal_timestamp( + result=result, segment_duration=segment_duration + ) + self.timestamp += float(segment.shape[-1] * HOP_LENGTH / SAMPLE_RATE)