Add Context to manage context

Yuta Hayashibe 2022-09-29 20:14:56 +09:00
parent f940fef3b3
commit f1d85762fc
4 changed files with 51 additions and 36 deletions
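The change moves all per-stream mutable state (the running timestamp, the prompt token buffer, and the leftover mel buffer) out of WhisperStreamingTranscriber into a new pydantic Context model that callers construct and pass explicitly to transcribe(). The attributes that remain on the transcriber are marked Final, so a single loaded model can serve any number of independent audio streams without their state colliding.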

whispering/cli.py View file

@@ -12,7 +12,7 @@ from whisper import available_models
 from whisper.audio import N_FRAMES, SAMPLE_RATE
 from whisper.tokenizer import LANGUAGES, TO_LANGUAGE_CODE
-from whispering.schema import WhisperConfig
+from whispering.schema import Context, WhisperConfig
 from whispering.serve import serve_with_websocket
 from whispering.transcriber import WhisperStreamingTranscriber
 from whispering.websocket_client import run_websocket_client
@@ -33,6 +33,7 @@ def transcribe_from_mic(
             logger.warning(status)
         q.put(indata.ravel())
 
+    ctx: Context = Context()
     logger.info("Ready to transcribe")
     with sd.InputStream(
         samplerate=SAMPLE_RATE,
@@ -46,7 +47,7 @@ def transcribe_from_mic(
         while True:
             logger.debug(f"Segment #: {idx}, The rest of queue: {q.qsize()}")
             segment = q.get()
-            for chunk in wsp.transcribe(segment=segment):
+            for chunk in wsp.transcribe(segment=segment, ctx=ctx):
                 print(f"{chunk.start:.2f}->{chunk.end:.2f}\t{chunk.text}")
             idx += 1
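On the caller side the pattern is now: create one Context per audio stream and thread it through every transcribe() call. A minimal sketch of that pattern with a hypothetical file-based driver (transcribe_array is invented for this sketch, not part of the commit), assuming 16 kHz mono float32 audio:

import numpy as np

from whispering.schema import Context, WhisperConfig
from whispering.transcriber import WhisperStreamingTranscriber


def transcribe_array(wsp: WhisperStreamingTranscriber, audio: np.ndarray) -> None:
    # Hypothetical driver: feed fixed-size chunks from a pre-loaded
    # waveform instead of a microphone.
    ctx = Context()  # one Context per independent stream
    chunk_size = 16000 * 30  # 30-second chunks at Whisper's 16 kHz rate
    for offset in range(0, len(audio), chunk_size):
        segment = audio[offset : offset + chunk_size].astype(np.float32)
        for chunk in wsp.transcribe(segment=segment, ctx=ctx):
            print(f"{chunk.start:.2f}->{chunk.end:.2f}\t{chunk.text}")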

whispering/schema.py View file

@@ -2,6 +2,7 @@
 from typing import List, Optional
 
+import torch
 from pydantic import BaseModel
@@ -23,6 +24,12 @@ class WhisperConfig(BaseModel):
     compression_ratio_threshold: Optional[float] = 2.4
 
 
+class Context(BaseModel, arbitrary_types_allowed=True):
+    timestamp: float = 0.0
+    buffer_tokens: List[torch.Tensor] = []
+    buffer_mel: Optional[torch.Tensor] = None
+
+
 class ParsedChunk(BaseModel):
     start: float
     end: float
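arbitrary_types_allowed=True is needed because pydantic has no built-in validator for torch.Tensor; without it the model definition fails at import time. Pydantic also copies mutable defaults per instance, so the buffer_tokens = [] default gives every Context its own empty list. A small sketch of the intended semantics (the tensor shape is illustrative, not prescribed by the commit):

import torch

from whispering.schema import Context

a = Context()
b = Context()
a.buffer_tokens.append(torch.tensor([1, 2, 3]))
assert b.buffer_tokens == []  # defaults are copied per instance, not shared
assert a.timestamp == 0.0     # the running clock starts at zero

# buffer_mel carries the unconsumed tail of the log-mel spectrogram between
# transcribe() calls; the diff slices it as a 3-D tensor, and 80 mel bins
# is Whisper's default.
a.buffer_mel = torch.zeros(1, 80, 100)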

whispering/serve.py View file

@@ -6,7 +6,7 @@ from logging import getLogger
 import numpy as np
 import websockets
 
-from whispering.transcriber import WhisperStreamingTranscriber
+from whispering.transcriber import Context, WhisperStreamingTranscriber
 
 logger = getLogger(__name__)
@@ -14,6 +14,7 @@ logger = getLogger(__name__)
 async def serve_with_websocket_main(websocket):
     global g_wsp
     idx: int = 0
+    ctx: Context = Context()
 
     while True:
         logger.debug(f"Segment #: {idx}")
@@ -25,7 +26,7 @@ async def serve_with_websocket_main(websocket):
         logger.debug(f"Message size: {len(message)}")
         segment = np.frombuffer(message, dtype=np.float32)
-        for chunk in g_wsp.transcribe(segment=segment):
+        for chunk in g_wsp.transcribe(segment=segment, ctx=ctx):
            await websocket.send(chunk.json())
         idx += 1
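Because the handler now builds a fresh Context per websocket connection, concurrent clients served by the shared global g_wsp no longer overwrite each other's timestamps and buffers. A hypothetical illustration of that isolation (serve_two_streams and the segment lists are invented for this sketch):

from whispering.schema import Context


def serve_two_streams(wsp, segments_a, segments_b):
    # Two interleaved clients, one shared transcriber: each connection's
    # state lives in its own Context, so the streams cannot interfere.
    ctx_a, ctx_b = Context(), Context()
    for seg_a, seg_b in zip(segments_a, segments_b):
        for chunk in wsp.transcribe(segment=seg_a, ctx=ctx_a):
            print("A:", chunk.text)
        for chunk in wsp.transcribe(segment=seg_b, ctx=ctx_b):
            print("B:", chunk.text)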

whispering/transcriber.py View file

@@ -1,7 +1,7 @@
 #!/usr/bin/env python3
 
 from logging import getLogger
-from typing import Iterator, List, Optional, Union
+from typing import Final, Iterator, List, Optional, Union
 
 import numpy as np
 import torch
@@ -17,7 +17,7 @@ from whisper.decoding import DecodingOptions, DecodingResult
 from whisper.tokenizer import get_tokenizer
 from whisper.utils import exact_div
 
-from whispering.schema import ParsedChunk, WhisperConfig
+from whispering.schema import Context, ParsedChunk, WhisperConfig
 
 logger = getLogger(__name__)
@@ -37,25 +37,21 @@ class WhisperStreamingTranscriber:
             self.fp16 = False
 
     def __init__(self, *, config: WhisperConfig):
-        self.config: WhisperConfig = config
-        self.model: Whisper = load_model(config.model_name, device=config.device)
+        self.config: Final[WhisperConfig] = config
+        self.model: Final[Whisper] = load_model(config.model_name, device=config.device)
         self.tokenizer = get_tokenizer(
             self.model.is_multilingual,
             language=config.language,
             task="transcribe",
         )
         self._set_dtype(config.fp16)
-        self.timestamp: float = 0.0
-        self.input_stride = exact_div(
+        self.input_stride: Final[int] = exact_div(
             N_FRAMES, self.model.dims.n_audio_ctx
         )  # mel frames per output token: 2
-        self.time_precision = (
+        self.time_precision: Final[float] = (
             self.input_stride * HOP_LENGTH / SAMPLE_RATE
         )  # time per output token: 0.02 (seconds)
-        self.buffer_tokens = []
-        self.buffer_mel = None
 
     def _get_decoding_options(
         self,
         *,
@@ -87,6 +83,7 @@
         self,
         *,
         segment: np.ndarray,
+        ctx: Context,
     ) -> List[DecodingResult]:
         assert len(self.config.temperatures) >= 1
         t = self.config.temperatures[0]
@ -94,7 +91,7 @@ class WhisperStreamingTranscriber:
_decode_options1: DecodingOptions = self._get_decoding_options(
t=t,
prompt=self.buffer_tokens,
prompt=ctx.buffer_tokens,
beam_size=self.config.beam_size,
patience=0.0,
best_of=None,
@@ -115,7 +112,7 @@
         )
         _decode_options2: DecodingOptions = self._get_decoding_options(
             t=t,
-            prompt=self.buffer_tokens,
+            prompt=ctx.buffer_tokens,
             beam_size=None,
             patience=0.0,
             best_of=self.config.best_of,
@@ -158,7 +155,11 @@
         )
 
     def _deal_timestamp(
-        self, *, result, segment_duration
+        self,
+        *,
+        result,
+        segment_duration,
+        ctx: Context,
     ) -> Iterator[Union[ParsedChunk, int]]:
         tokens = torch.tensor(result.tokens)
         timestamp_tokens: torch.Tensor = tokens.ge(self.tokenizer.timestamp_begin)
@@ -182,9 +183,9 @@
                     sliced_tokens[-1].item() - self.tokenizer.timestamp_begin
                 )
                 chunk = self._get_chunk(
-                    start=self.timestamp
+                    start=ctx.timestamp
                     + start_timestamp_position * self.time_precision,
-                    end=self.timestamp + end_timestamp_position * self.time_precision,
+                    end=ctx.timestamp + end_timestamp_position * self.time_precision,
                     text_tokens=sliced_tokens[1:-1],
                     result=result,
                 )
@@ -195,8 +196,8 @@
                 tokens[last_slice - 1].item()
                 - self.tokenizer.timestamp_begin  # type:ignore
             )
-            self.buffer_tokens.extend(tokens[: last_slice + 1].tolist())
-            self.timestamp += last_timestamp_position0 * self.time_precision
+            ctx.buffer_tokens.extend(tokens[: last_slice + 1].tolist())
+            ctx.timestamp += last_timestamp_position0 * self.time_precision
             yield last_timestamp_position0
         else:
             duration = segment_duration
@@ -211,34 +212,35 @@
                 duration = last_timestamp_position * self.time_precision
             logger.debug(f"segment_duration: {segment_duration}, Duration: {duration}")
             chunk = self._get_chunk(
-                start=self.timestamp,
-                end=self.timestamp + duration,
+                start=ctx.timestamp,
+                end=ctx.timestamp + duration,
                 text_tokens=tokens,
                 result=result,
             )
             if chunk is not None:
                 yield chunk
-            self.timestamp += duration
+            ctx.timestamp += duration
 
         if result.temperature > 0.5:
             # do not feed the prompt tokens if a high temperature was used
-            del self.buffer_tokens
-            self.buffer_tokens = []
-        logger.debug(f"Length of buffer: {len(self.buffer_tokens)}")
+            del ctx.buffer_tokens
+            ctx.buffer_tokens = []
+        logger.debug(f"Length of buffer: {len(ctx.buffer_tokens)}")
 
     def transcribe(
         self,
         *,
         segment: np.ndarray,
+        ctx: Context,
     ) -> Iterator[ParsedChunk]:
         new_mel = log_mel_spectrogram(audio=segment).unsqueeze(0)
         logger.debug(f"Incoming new_mel.shape: {new_mel.shape}")
-        if self.buffer_mel is None:
+        if ctx.buffer_mel is None:
             mel = new_mel
         else:
-            logger.debug(f"buffer_mel.shape: {self.buffer_mel.shape}")
-            mel = torch.cat([self.buffer_mel, new_mel], dim=-1)
-            self.buffer_mel = None
+            logger.debug(f"buffer_mel.shape: {ctx.buffer_mel.shape}")
+            mel = torch.cat([ctx.buffer_mel, new_mel], dim=-1)
+            ctx.buffer_mel = None
         logger.debug(f"mel.shape: {mel.shape}")
 
         seek: int = 0
@@ -252,11 +254,12 @@
                 logger.warning("Padding is not expected while speaking")
 
             logger.debug(
-                f"seek={seek}, timestamp={self.timestamp}, "
+                f"seek={seek}, timestamp={ctx.timestamp}, "
                 f"mel.shape: {mel.shape}, segment.shape: {segment.shape}"
             )
             results = self._decode_with_fallback(
                 segment=segment,
+                ctx=ctx,
             )
             result = results[0]
             logger.debug(
@@ -278,7 +281,9 @@
             segment_duration = segment.shape[-1] * HOP_LENGTH / SAMPLE_RATE
             last_timestamp_position: Optional[int] = None
             for v in self._deal_timestamp(
-                result=result, segment_duration=segment_duration
+                result=result,
+                segment_duration=segment_duration,
+                ctx=ctx,
             ):
                 if isinstance(v, int):
                     last_timestamp_position = v
@@ -294,8 +299,9 @@
                 break
 
         if mel.shape[-1] - seek <= 0:
-            logger.debug(f"self.buffer_mel is None ({mel.shape}, {seek})")
+            logger.debug(f"ctx.buffer_mel is None ({mel.shape}, {seek})")
             return
-        self.buffer_mel = mel[:, :, seek:]
-        logger.debug(f"self.buffer_mel.shape: {self.buffer_mel.shape}")
+        ctx.buffer_mel = mel[:, :, seek:]
+        assert ctx.buffer_mel is not None
+        logger.debug(f"ctx.buffer_mel.shape: {ctx.buffer_mel.shape}")
         del mel
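Context is also what stitches segment boundaries together: transcribe() prepends ctx.buffer_mel to the new segment's mel frames, and whatever the decoder does not consume is stored back for the next call while ctx.timestamp advances past the consumed audio. An end-to-end sketch, assuming a working whisper install; model_name, language, and device are the WhisperConfig fields this diff reads, and the remaining fields are assumed to have defaults:

import numpy as np

from whispering.schema import Context, WhisperConfig
from whispering.transcriber import WhisperStreamingTranscriber

config = WhisperConfig(model_name="tiny", language="en", device="cpu")
wsp = WhisperStreamingTranscriber(config=config)
ctx = Context()

rng = np.random.default_rng(0)
for _ in range(2):
    # Stand-in for microphone input: 30 s of noise at 16 kHz.
    segment = rng.standard_normal(16000 * 30).astype(np.float32)
    for chunk in wsp.transcribe(segment=segment, ctx=ctx):
        print(f"{chunk.start:.2f}->{chunk.end:.2f}\t{chunk.text}")

# The clock has advanced past the consumed audio, and any unconsumed
# mel frames are waiting in ctx.buffer_mel for the next call.
print(ctx.timestamp)
print(None if ctx.buffer_mel is None else ctx.buffer_mel.shape)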