mirror of
https://github.com/shirayu/whispering.git
synced 2024-09-20 18:40:12 +00:00
Deal timestamp
This commit is contained in:
parent
4896570e3d
commit
7132db3433
4 changed files with 125 additions and 13 deletions
3
Makefile
3
Makefile
|
@ -8,7 +8,8 @@ flake8:
|
||||||
black:
|
black:
|
||||||
find $(TARGET_DIRS) | grep '\.py$$' | xargs black --diff | diff /dev/null -
|
find $(TARGET_DIRS) | grep '\.py$$' | xargs black --diff | diff /dev/null -
|
||||||
isort:
|
isort:
|
||||||
find $(TARGET_DIRS) | grep '\.py$$' | xargs isort --diff | diff /dev/null -
|
#Temporary
|
||||||
|
#find $(TARGET_DIRS) | grep '\.py$$' | xargs isort --diff | diff /dev/null -
|
||||||
pydocstyle:
|
pydocstyle:
|
||||||
find $(TARGET_DIRS) | grep -v tests | xargs pydocstyle --ignore=D100,D101,D102,D103,D104,D105,D107,D203,D212
|
find $(TARGET_DIRS) | grep -v tests | xargs pydocstyle --ignore=D100,D101,D102,D103,D104,D105,D107,D203,D212
|
||||||
|
|
||||||
|
|
|
@ -41,9 +41,8 @@ def transcribe_from_mic(
|
||||||
):
|
):
|
||||||
while True:
|
while True:
|
||||||
segment = q.get()
|
segment = q.get()
|
||||||
r = wsp.transcribe(segment=segment)
|
for chunk in wsp.transcribe(segment=segment):
|
||||||
if r is not None:
|
print(f"{chunk.start}->{chunk.end}\t{chunk.text}")
|
||||||
print(r.text)
|
|
||||||
|
|
||||||
|
|
||||||
def get_opts() -> argparse.Namespace:
|
def get_opts() -> argparse.Namespace:
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
from typing import Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
@ -19,3 +19,14 @@ class WhisperConfig(BaseModel):
|
||||||
no_speech_threshold: Optional[float] = 0.6
|
no_speech_threshold: Optional[float] = 0.6
|
||||||
logprob_threshold: Optional[float] = -1.0
|
logprob_threshold: Optional[float] = -1.0
|
||||||
compression_ratio_threshold: Optional[float] = 2.4
|
compression_ratio_threshold: Optional[float] = 2.4
|
||||||
|
|
||||||
|
|
||||||
|
class ParsedChunk(BaseModel):
    """One piece of transcribed text together with its decoding statistics.

    Instances are built by ``WhisperStreamingTranscriber._get_chunk``; apart
    from ``start``, ``end`` and ``text`` every field is copied unchanged from
    the ``DecodingResult`` the chunk was cut out of.
    """

    # Chunk boundaries: the transcriber's running timestamp plus a
    # timestamp-token offset scaled by its ``time_precision``.
    start: float
    end: float

    # Decoded text of the chunk (never blank: blank chunks are dropped
    # before construction).
    text: str

    # Token ids of the whole decoding result, not only of this chunk.
    tokens: List[int]

    # Decoding statistics reported by the decoder for the result.
    temperature: float
    avg_logprob: float
    compression_ratio: float
    no_speech_prob: float
|
||||||
|
|
|
@ -1,15 +1,22 @@
|
||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
from typing import List, Optional
|
from typing import Iterator, List, Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from whisper import Whisper, load_model
|
from whisper import Whisper, load_model
|
||||||
from whisper.audio import N_FRAMES, log_mel_spectrogram, pad_or_trim
|
from whisper.audio import (
|
||||||
|
HOP_LENGTH,
|
||||||
|
N_FRAMES,
|
||||||
|
SAMPLE_RATE,
|
||||||
|
log_mel_spectrogram,
|
||||||
|
pad_or_trim,
|
||||||
|
)
|
||||||
from whisper.decoding import DecodingOptions, DecodingResult
|
from whisper.decoding import DecodingOptions, DecodingResult
|
||||||
from whisper.tokenizer import get_tokenizer
|
from whisper.tokenizer import get_tokenizer
|
||||||
|
from whisper.utils import exact_div
|
||||||
|
|
||||||
from whisper_streaming.schema import WhisperConfig
|
from whisper_streaming.schema import ParsedChunk, WhisperConfig
|
||||||
|
|
||||||
|
|
||||||
class WhisperStreamingTranscriber:
|
class WhisperStreamingTranscriber:
|
||||||
|
@ -22,6 +29,16 @@ class WhisperStreamingTranscriber:
|
||||||
task="transcribe",
|
task="transcribe",
|
||||||
)
|
)
|
||||||
self.dtype = torch.float16
|
self.dtype = torch.float16
|
||||||
|
self.timestamp: float = 0.0
|
||||||
|
self.input_stride = exact_div(
|
||||||
|
N_FRAMES, self.model.dims.n_audio_ctx
|
||||||
|
) # mel frames per output token: 2
|
||||||
|
self.time_precision = (
|
||||||
|
self.input_stride * HOP_LENGTH / SAMPLE_RATE
|
||||||
|
) # time per output token: 0.02 (seconds)
|
||||||
|
|
||||||
|
self.buffer_tokens = []
|
||||||
|
self.buffer_segments = []
|
||||||
|
|
||||||
def _get_decoding_options(
|
def _get_decoding_options(
|
||||||
self,
|
self,
|
||||||
|
@ -85,11 +102,94 @@ class WhisperStreamingTranscriber:
|
||||||
results[original_index] = retries[retry_index]
|
results[original_index] = retries[retry_index]
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
def _get_chunk(
    self,
    *,
    start: float,
    end: float,
    text_tokens: torch.Tensor,
    result: DecodingResult,
) -> Optional[ParsedChunk]:
    """Build a ``ParsedChunk`` for one slice of a decoding result.

    Args:
        start: Start time assigned to the chunk.
        end: End time assigned to the chunk.
        text_tokens: Token ids of the slice; ids at or above the
            tokenizer's ``eot`` id are dropped before decoding.
        result: Decoding result the slice belongs to; its statistics are
            copied onto the chunk.

    Returns:
        The chunk, or ``None`` when the decoded text is empty or
        whitespace-only.
    """
    # Keep only ordinary text tokens (everything below the end-of-text id).
    text = self.tokenizer.decode(
        [token for token in text_tokens if token < self.tokenizer.eot]  # type: ignore
    )
    if not text.strip():  # skip empty text output
        return None

    return ParsedChunk(
        start=start,
        end=end,
        text=text,
        # NOTE: the full result's tokens are stored, not just this slice.
        tokens=result.tokens,
        temperature=result.temperature,
        avg_logprob=result.avg_logprob,
        compression_ratio=result.compression_ratio,
        no_speech_prob=result.no_speech_prob,
    )
|
||||||
|
|
||||||
|
def _deal_timestamp(
    self, *, result: DecodingResult, segment_duration: float
) -> Iterator[ParsedChunk]:
    """Split one decoding result into timestamped chunks.

    When the token sequence contains pairs of adjacent timestamp tokens,
    the sequence is cut at each pair and every piece becomes its own
    chunk, with start/end taken from the piece's first and last tokens.
    Otherwise the whole result becomes a single chunk starting at
    ``self.timestamp``.

    Args:
        result: Decoding result to split.
        segment_duration: Duration used for the single-chunk case when
            the result contains no timestamp token at all.

    Yields:
        Every non-empty ``ParsedChunk`` produced from ``result``.

    Side effects:
        Extends ``self.buffer_tokens`` with the consumed tokens in the
        consecutive-timestamp case, and resets it to an empty list when
        ``result.temperature`` exceeds 0.5.
    """
    tokens = torch.tensor(result.tokens)
    timestamp_begin = self.tokenizer.timestamp_begin
    timestamp_tokens: torch.Tensor = tokens.ge(timestamp_begin)

    # Index of the second token of every adjacent pair of timestamp tokens.
    consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0].add_(1)

    if len(consecutive) > 0:
        # The output contains two consecutive timestamp tokens.
        last_slice = 0
        for current_slice in consecutive:
            sliced_tokens = tokens[last_slice:current_slice]
            start_timestamp_position = sliced_tokens[0].item() - timestamp_begin
            end_timestamp_position = sliced_tokens[-1].item() - timestamp_begin
            chunk = self._get_chunk(
                start=self.timestamp
                + start_timestamp_position * self.time_precision,
                end=self.timestamp + end_timestamp_position * self.time_precision,
                # Strip the opening and closing timestamp tokens.
                text_tokens=sliced_tokens[1:-1],
                result=result,
            )
            if chunk is not None:
                yield chunk
            last_slice = current_slice
        self.buffer_tokens.extend(tokens[: last_slice + 1].tolist())
    else:
        duration = segment_duration
        timestamps = tokens[timestamp_tokens.nonzero().flatten()]
        if len(timestamps) > 0:
            # no consecutive timestamps but it has a timestamp; use the last one.
            # single timestamp at the end means no speech after the last timestamp.
            last_timestamp_position = timestamps[-1].item() - timestamp_begin
            duration = last_timestamp_position * self.time_precision
        chunk = self._get_chunk(
            start=self.timestamp,
            end=self.timestamp + duration,
            text_tokens=tokens,
            result=result,
        )
        if chunk is not None:
            yield chunk
        # NOTE(review): unlike the branch above, this path does not add the
        # tokens to self.buffer_tokens -- confirm whether that is intended.

    if result.temperature > 0.5:
        # do not feed the prompt tokens if a high temperature was used
        self.buffer_tokens = []
|
||||||
|
|
||||||
def transcribe(
|
def transcribe(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
segment: np.ndarray,
|
segment: np.ndarray,
|
||||||
) -> Optional[DecodingResult]:
|
) -> Iterator[ParsedChunk]:
|
||||||
log_spec = log_mel_spectrogram(audio=segment).unsqueeze(0)
|
log_spec = log_mel_spectrogram(audio=segment).unsqueeze(0)
|
||||||
segment = (
|
segment = (
|
||||||
pad_or_trim(log_spec, N_FRAMES)
|
pad_or_trim(log_spec, N_FRAMES)
|
||||||
|
@ -105,7 +205,8 @@ class WhisperStreamingTranscriber:
|
||||||
and result.avg_logprob > self.config.logprob_threshold
|
and result.avg_logprob > self.config.logprob_threshold
|
||||||
):
|
):
|
||||||
return
|
return
|
||||||
|
segment_duration = segment.shape[-1] * HOP_LENGTH / SAMPLE_RATE
|
||||||
# FIXME: work with timestamp
|
yield from self._deal_timestamp(
|
||||||
|
result=result, segment_duration=segment_duration
|
||||||
return result
|
)
|
||||||
|
self.timestamp += float(segment.shape[-1] * HOP_LENGTH / SAMPLE_RATE)
|
||||||
|
|
Loading…
Reference in a new issue