mirror of
https://github.com/shirayu/whispering.git
synced 2024-06-12 18:29:21 +00:00
Need to save log_spec
This commit is contained in:
parent
8b5615cefa
commit
957a3ffe18
|
@ -1,6 +1,6 @@
|
||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
from typing import Iterator, List, Optional
|
from typing import Iterator, List, Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
@ -126,7 +126,9 @@ class WhisperStreamingTranscriber:
|
||||||
no_speech_prob=result.no_speech_prob,
|
no_speech_prob=result.no_speech_prob,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _deal_timestamp(self, *, result, segment_duration) -> Iterator[ParsedChunk]:
|
def _deal_timestamp(
|
||||||
|
self, *, result, segment_duration
|
||||||
|
) -> Iterator[Union[ParsedChunk, int]]:
|
||||||
tokens = torch.tensor(result.tokens)
|
tokens = torch.tensor(result.tokens)
|
||||||
timestamp_tokens: torch.Tensor = tokens.ge(self.tokenizer.timestamp_begin)
|
timestamp_tokens: torch.Tensor = tokens.ge(self.tokenizer.timestamp_begin)
|
||||||
|
|
||||||
|
@ -156,11 +158,13 @@ class WhisperStreamingTranscriber:
|
||||||
if chunk is not None:
|
if chunk is not None:
|
||||||
yield chunk
|
yield chunk
|
||||||
last_slice = current_slice
|
last_slice = current_slice
|
||||||
last_timestamp_position = (
|
last_timestamp_position0: int = (
|
||||||
tokens[last_slice - 1].item() - self.tokenizer.timestamp_begin
|
tokens[last_slice - 1].item()
|
||||||
|
- self.tokenizer.timestamp_begin # type:ignore
|
||||||
)
|
)
|
||||||
self.buffer_tokens.extend(tokens[: last_slice + 1].tolist())
|
self.buffer_tokens.extend(tokens[: last_slice + 1].tolist())
|
||||||
self.timestamp += last_timestamp_position * self.time_precision
|
self.timestamp += last_timestamp_position0 * self.time_precision
|
||||||
|
yield last_timestamp_position0
|
||||||
else:
|
else:
|
||||||
duration = segment_duration
|
duration = segment_duration
|
||||||
timestamps = tokens[timestamp_tokens.nonzero().flatten()]
|
timestamps = tokens[timestamp_tokens.nonzero().flatten()]
|
||||||
|
@ -207,6 +211,9 @@ class WhisperStreamingTranscriber:
|
||||||
):
|
):
|
||||||
return
|
return
|
||||||
segment_duration = segment.shape[-1] * HOP_LENGTH / SAMPLE_RATE
|
segment_duration = segment.shape[-1] * HOP_LENGTH / SAMPLE_RATE
|
||||||
yield from self._deal_timestamp(
|
for v in self._deal_timestamp(result=result, segment_duration=segment_duration):
|
||||||
result=result, segment_duration=segment_duration
|
if isinstance(v, int):
|
||||||
)
|
# FIXME: save log_spec
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
yield v
|
||||||
|
|
Loading…
Reference in a new issue