diff --git a/whisper_streaming/transcriber.py b/whisper_streaming/transcriber.py
index a11f2f8..aad86f4 100644
--- a/whisper_streaming/transcriber.py
+++ b/whisper_streaming/transcriber.py
@@ -1,6 +1,6 @@
 #!/usr/bin/env python3
 
-from typing import Iterator, List, Optional
+from typing import Iterator, List, Optional, Union
 
 import numpy as np
 import torch
@@ -126,7 +126,9 @@ class WhisperStreamingTranscriber:
             no_speech_prob=result.no_speech_prob,
         )
 
-    def _deal_timestamp(self, *, result, segment_duration) -> Iterator[ParsedChunk]:
+    def _deal_timestamp(
+        self, *, result, segment_duration
+    ) -> Iterator[Union[ParsedChunk, int]]:
         tokens = torch.tensor(result.tokens)
         timestamp_tokens: torch.Tensor = tokens.ge(self.tokenizer.timestamp_begin)
 
@@ -156,11 +158,13 @@ class WhisperStreamingTranscriber:
                 if chunk is not None:
                     yield chunk
                 last_slice = current_slice
-            last_timestamp_position = (
-                tokens[last_slice - 1].item() - self.tokenizer.timestamp_begin
+            last_timestamp_position0: int = (
+                tokens[last_slice - 1].item()
+                - self.tokenizer.timestamp_begin  # type:ignore
             )
             self.buffer_tokens.extend(tokens[: last_slice + 1].tolist())
-            self.timestamp += last_timestamp_position * self.time_precision
+            self.timestamp += last_timestamp_position0 * self.time_precision
+            yield last_timestamp_position0
         else:
             duration = segment_duration
             timestamps = tokens[timestamp_tokens.nonzero().flatten()]
@@ -207,6 +211,9 @@ class WhisperStreamingTranscriber:
         ):
             return
         segment_duration = segment.shape[-1] * HOP_LENGTH / SAMPLE_RATE
-        yield from self._deal_timestamp(
-            result=result, segment_duration=segment_duration
-        )
+        for v in self._deal_timestamp(result=result, segment_duration=segment_duration):
+            if isinstance(v, int):
+                # FIXME: save log_spec
+                pass
+            else:
+                yield v
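
Note on the change: `_deal_timestamp` now yields `int` values (the last timestamp position) interleaved with `ParsedChunk`s, and the caller in the transcribe path filters them out with `isinstance`. The toy script below is a minimal, self-contained sketch of that Union-yielding generator pattern only; the names (`Chunk`, `produce`, `consume`) are hypothetical stand-ins and not part of whisper_streaming.

#!/usr/bin/env python3
# Toy illustration of the pattern introduced in the diff: a generator typed as
# Iterator[Union[Chunk, int]] whose consumer splits the stream with isinstance().
# All names here are hypothetical, not whisper_streaming APIs.
from dataclasses import dataclass
from typing import Iterator, List, Union


@dataclass
class Chunk:
    text: str


def produce() -> Iterator[Union[Chunk, int]]:
    # Yield parsed chunks, then an int marker, mirroring how _deal_timestamp
    # now yields ParsedChunk values plus the last timestamp position.
    yield Chunk(text="hello")
    yield Chunk(text="world")
    yield 42


def consume() -> List[Chunk]:
    chunks: List[Chunk] = []
    for v in produce():
        if isinstance(v, int):
            # The diff leaves a FIXME at this point; here we simply skip the int.
            continue
        chunks.append(v)
    return chunks


if __name__ == "__main__":
    print(consume())  # [Chunk(text='hello'), Chunk(text='world')]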