Need to save log_spec

This commit is contained in:
Yuta Hayashibe 2022-09-23 20:41:44 +09:00
parent 8b5615cefa
commit 957a3ffe18

View file

@ -1,6 +1,6 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
from typing import Iterator, List, Optional from typing import Iterator, List, Optional, Union
import numpy as np import numpy as np
import torch import torch
@ -126,7 +126,9 @@ class WhisperStreamingTranscriber:
no_speech_prob=result.no_speech_prob, no_speech_prob=result.no_speech_prob,
) )
def _deal_timestamp(self, *, result, segment_duration) -> Iterator[ParsedChunk]: def _deal_timestamp(
self, *, result, segment_duration
) -> Iterator[Union[ParsedChunk, int]]:
tokens = torch.tensor(result.tokens) tokens = torch.tensor(result.tokens)
timestamp_tokens: torch.Tensor = tokens.ge(self.tokenizer.timestamp_begin) timestamp_tokens: torch.Tensor = tokens.ge(self.tokenizer.timestamp_begin)
@ -156,11 +158,13 @@ class WhisperStreamingTranscriber:
if chunk is not None: if chunk is not None:
yield chunk yield chunk
last_slice = current_slice last_slice = current_slice
last_timestamp_position = ( last_timestamp_position0: int = (
tokens[last_slice - 1].item() - self.tokenizer.timestamp_begin tokens[last_slice - 1].item()
- self.tokenizer.timestamp_begin # type:ignore
) )
self.buffer_tokens.extend(tokens[: last_slice + 1].tolist()) self.buffer_tokens.extend(tokens[: last_slice + 1].tolist())
self.timestamp += last_timestamp_position * self.time_precision self.timestamp += last_timestamp_position0 * self.time_precision
yield last_timestamp_position0
else: else:
duration = segment_duration duration = segment_duration
timestamps = tokens[timestamp_tokens.nonzero().flatten()] timestamps = tokens[timestamp_tokens.nonzero().flatten()]
@ -207,6 +211,9 @@ class WhisperStreamingTranscriber:
): ):
return return
segment_duration = segment.shape[-1] * HOP_LENGTH / SAMPLE_RATE segment_duration = segment.shape[-1] * HOP_LENGTH / SAMPLE_RATE
yield from self._deal_timestamp( for v in self._deal_timestamp(result=result, segment_duration=segment_duration):
result=result, segment_duration=segment_duration if isinstance(v, int):
) # FIXME: save log_spec
pass
else:
yield v