Need to save log_spec

This commit is contained in:
Yuta Hayashibe 2022-09-23 20:41:44 +09:00
parent 8b5615cefa
commit 957a3ffe18

View file

@@ -1,6 +1,6 @@
#!/usr/bin/env python3
from typing import Iterator, List, Optional
from typing import Iterator, List, Optional, Union
import numpy as np
import torch
@@ -126,7 +126,9 @@ class WhisperStreamingTranscriber:
no_speech_prob=result.no_speech_prob,
)
def _deal_timestamp(self, *, result, segment_duration) -> Iterator[ParsedChunk]:
def _deal_timestamp(
self, *, result, segment_duration
) -> Iterator[Union[ParsedChunk, int]]:
tokens = torch.tensor(result.tokens)
timestamp_tokens: torch.Tensor = tokens.ge(self.tokenizer.timestamp_begin)
@@ -156,11 +158,13 @@ class WhisperStreamingTranscriber:
if chunk is not None:
yield chunk
last_slice = current_slice
last_timestamp_position = (
tokens[last_slice - 1].item() - self.tokenizer.timestamp_begin
last_timestamp_position0: int = (
tokens[last_slice - 1].item()
- self.tokenizer.timestamp_begin # type:ignore
)
self.buffer_tokens.extend(tokens[: last_slice + 1].tolist())
self.timestamp += last_timestamp_position * self.time_precision
self.timestamp += last_timestamp_position0 * self.time_precision
yield last_timestamp_position0
else:
duration = segment_duration
timestamps = tokens[timestamp_tokens.nonzero().flatten()]
@@ -207,6 +211,9 @@ class WhisperStreamingTranscriber:
):
return
segment_duration = segment.shape[-1] * HOP_LENGTH / SAMPLE_RATE
yield from self._deal_timestamp(
result=result, segment_duration=segment_duration
)
for v in self._deal_timestamp(result=result, segment_duration=segment_duration):
if isinstance(v, int):
# FIXME: save log_spec
pass
else:
yield v