Mirror of https://github.com/shirayu/whispering.git (synced 2025-03-28 04:55:28 +00:00)
Add Context to manage context

Commit: f1d85762fc
Parent: f940fef3b3
4 changed files with 51 additions and 36 deletions
whispering/cli.py

@@ -12,7 +12,7 @@ from whisper import available_models
 from whisper.audio import N_FRAMES, SAMPLE_RATE
 from whisper.tokenizer import LANGUAGES, TO_LANGUAGE_CODE
 
-from whispering.schema import WhisperConfig
+from whispering.schema import Context, WhisperConfig
 from whispering.serve import serve_with_websocket
 from whispering.transcriber import WhisperStreamingTranscriber
 from whispering.websocket_client import run_websocket_client
@@ -33,6 +33,7 @@ def transcribe_from_mic(
             logger.warning(status)
         q.put(indata.ravel())
 
+    ctx: Context = Context()
     logger.info("Ready to transcribe")
     with sd.InputStream(
         samplerate=SAMPLE_RATE,
@@ -46,7 +47,7 @@ def transcribe_from_mic(
         while True:
             logger.debug(f"Segment #: {idx}, The rest of queue: {q.qsize()}")
             segment = q.get()
-            for chunk in wsp.transcribe(segment=segment):
+            for chunk in wsp.transcribe(segment=segment, ctx=ctx):
                 print(f"{chunk.start:.2f}->{chunk.end:.2f}\t{chunk.text}")
             idx += 1
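
The CLI now owns the per-stream state: it builds one Context per microphone session and threads it through every transcribe() call. A minimal sketch of this calling convention follows; StubTranscriber is a hypothetical stand-in for WhisperStreamingTranscriber, and the Context here is simplified (the real one also buffers tokens and mel frames as tensors).

from typing import Iterator, List

import numpy as np
from pydantic import BaseModel


class Context(BaseModel):
    # Simplified stand-in; the real Context also carries mel buffers.
    timestamp: float = 0.0
    buffer_tokens: List[int] = []


class StubTranscriber:
    # Hypothetical stand-in for WhisperStreamingTranscriber.
    def transcribe(self, *, segment: np.ndarray, ctx: Context) -> Iterator[str]:
        # All mutable state lives in ctx, so the transcriber stays reusable.
        duration = segment.shape[-1] / 16000  # assuming 16 kHz audio
        yield f"{ctx.timestamp:.2f}->{ctx.timestamp + duration:.2f}\t..."
        ctx.timestamp += duration


wsp = StubTranscriber()
ctx = Context()  # one Context per audio stream, created by the caller
for _ in range(3):
    segment = np.zeros(16000, dtype=np.float32)  # stand-in for q.get()
    for line in wsp.transcribe(segment=segment, ctx=ctx):
        print(line)
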
whispering/schema.py

@@ -2,6 +2,7 @@
 
 from typing import List, Optional
 
+import torch
 from pydantic import BaseModel
 
 
@@ -23,6 +24,12 @@ class WhisperConfig(BaseModel):
     compression_ratio_threshold: Optional[float] = 2.4
 
 
+class Context(BaseModel, arbitrary_types_allowed=True):
+    timestamp: float = 0.0
+    buffer_tokens: List[torch.Tensor] = []
+    buffer_mel: Optional[torch.Tensor] = None
+
+
 class ParsedChunk(BaseModel):
     start: float
     end: float
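
By default pydantic rejects field types it has no validator for, and torch.Tensor is one of them; arbitrary_types_allowed=True relaxes validation for such fields to a plain isinstance check. A minimal sketch of the new model in isolation, mirroring the fields in the diff:

from typing import List, Optional

import torch
from pydantic import BaseModel


class Context(BaseModel, arbitrary_types_allowed=True):
    timestamp: float = 0.0                     # seconds consumed so far
    buffer_tokens: List[torch.Tensor] = []     # prompt tokens carried over
    buffer_mel: Optional[torch.Tensor] = None  # leftover mel frames


ctx = Context(buffer_mel=torch.zeros(1, 80, 10))  # accepted via isinstance
print(ctx.timestamp, ctx.buffer_mel.shape)        # 0.0 torch.Size([1, 80, 10])

Note that pydantic copies field defaults per instance, so each stream gets its own buffer_tokens list rather than sharing one mutable default.
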
whispering/serve.py

@@ -6,7 +6,7 @@ from logging import getLogger
 
 import numpy as np
 import websockets
 
-from whispering.transcriber import WhisperStreamingTranscriber
+from whispering.transcriber import Context, WhisperStreamingTranscriber
 
 logger = getLogger(__name__)
@@ -14,6 +14,7 @@ logger = getLogger(__name__)
 async def serve_with_websocket_main(websocket):
     global g_wsp
     idx: int = 0
+    ctx: Context = Context()
 
     while True:
         logger.debug(f"Segment #: {idx}")
@@ -25,7 +26,7 @@ async def serve_with_websocket_main(websocket):
 
         logger.debug(f"Message size: {len(message)}")
         segment = np.frombuffer(message, dtype=np.float32)
-        for chunk in g_wsp.transcribe(segment=segment):
+        for chunk in g_wsp.transcribe(segment=segment, ctx=ctx):
             await websocket.send(chunk.json())
         idx += 1
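
Because each websocket connection now builds its own Context, concurrent clients can no longer corrupt each other's timestamps or prompt buffers even though the transcriber g_wsp stays global. A tiny illustration of that isolation, using the simplified Context from the earlier sketch:

from pydantic import BaseModel


class Context(BaseModel):
    timestamp: float = 0.0  # simplified; the real Context also holds buffers


ctx_a = Context()  # built inside client A's handler
ctx_b = Context()  # built inside client B's handler

ctx_a.timestamp += 30.0        # client A advances through its stream
assert ctx_b.timestamp == 0.0  # client B's clock is untouched
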
whispering/transcriber.py

@@ -1,7 +1,7 @@
 #!/usr/bin/env python3
 
 from logging import getLogger
-from typing import Iterator, List, Optional, Union
+from typing import Final, Iterator, List, Optional, Union
 
 import numpy as np
 import torch
@@ -17,7 +17,7 @@ from whisper.decoding import DecodingOptions, DecodingResult
 from whisper.tokenizer import get_tokenizer
 from whisper.utils import exact_div
 
-from whispering.schema import ParsedChunk, WhisperConfig
+from whispering.schema import Context, ParsedChunk, WhisperConfig
 
 logger = getLogger(__name__)
 
@@ -37,25 +37,21 @@ class WhisperStreamingTranscriber:
         self.fp16 = False
 
     def __init__(self, *, config: WhisperConfig):
-        self.config: WhisperConfig = config
-        self.model: Whisper = load_model(config.model_name, device=config.device)
+        self.config: Final[WhisperConfig] = config
+        self.model: Final[Whisper] = load_model(config.model_name, device=config.device)
         self.tokenizer = get_tokenizer(
             self.model.is_multilingual,
             language=config.language,
             task="transcribe",
         )
         self._set_dtype(config.fp16)
-        self.timestamp: float = 0.0
-        self.input_stride = exact_div(
+        self.input_stride: Final[int] = exact_div(
             N_FRAMES, self.model.dims.n_audio_ctx
         )  # mel frames per output token: 2
-        self.time_precision = (
+        self.time_precision: Final[float] = (
             self.input_stride * HOP_LENGTH / SAMPLE_RATE
         )  # time per output token: 0.02 (seconds)
-
-        self.buffer_tokens = []
-        self.buffer_mel = None
 
     def _get_decoding_options(
         self,
         *,
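
The attributes that survive in __init__ are annotated with typing.Final, which asks static checkers to treat them as assign-once; CPython does not enforce this at runtime, so it is documentation plus tooling. A small example of the checker contract, with generic names not taken from the repository:

from typing import Final


class Engine:
    def __init__(self, *, name: str) -> None:
        self.name: Final[str] = name  # assign-once by contract


e = Engine(name="base")
# e.name = "large"  # mypy would reject this: cannot assign to a Final attribute
print(e.name)

Together with moving timestamp and the buffers into Context, this makes the transcriber effectively immutable after construction.
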
@@ -87,6 +83,7 @@ class WhisperStreamingTranscriber:
         self,
         *,
         segment: np.ndarray,
+        ctx: Context,
     ) -> List[DecodingResult]:
         assert len(self.config.temperatures) >= 1
         t = self.config.temperatures[0]
@@ -94,7 +91,7 @@
 
         _decode_options1: DecodingOptions = self._get_decoding_options(
             t=t,
-            prompt=self.buffer_tokens,
+            prompt=ctx.buffer_tokens,
             beam_size=self.config.beam_size,
             patience=0.0,
             best_of=None,
@@ -115,7 +112,7 @@
         )
         _decode_options2: DecodingOptions = self._get_decoding_options(
             t=t,
-            prompt=self.buffer_tokens,
+            prompt=ctx.buffer_tokens,
             beam_size=None,
             patience=0.0,
             best_of=self.config.best_of,
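
Both decoding passes now source their prompt from the caller's context. In openai-whisper, DecodingOptions.prompt may be a list of token ids, and the decoder conditions on them so wording from earlier segments influences the next window. A hedged sketch of the construction; the token ids below are made up for illustration:

from whisper.decoding import DecodingOptions

buffer_tokens = [1396, 264, 1002]  # hypothetical ids kept from earlier segments
options = DecodingOptions(
    task="transcribe",
    prompt=buffer_tokens,  # condition the decoder on prior context
)
print(options.prompt)
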
@@ -158,7 +155,11 @@
         )
 
     def _deal_timestamp(
-        self, *, result, segment_duration
+        self,
+        *,
+        result,
+        segment_duration,
+        ctx: Context,
     ) -> Iterator[Union[ParsedChunk, int]]:
         tokens = torch.tensor(result.tokens)
         timestamp_tokens: torch.Tensor = tokens.ge(self.tokenizer.timestamp_begin)
@@ -182,9 +183,9 @@
                     sliced_tokens[-1].item() - self.tokenizer.timestamp_begin
                 )
                 chunk = self._get_chunk(
-                    start=self.timestamp
+                    start=ctx.timestamp
                     + start_timestamp_position * self.time_precision,
-                    end=self.timestamp + end_timestamp_position * self.time_precision,
+                    end=ctx.timestamp + end_timestamp_position * self.time_precision,
                     text_tokens=sliced_tokens[1:-1],
                     result=result,
                 )
@@ -195,8 +196,8 @@
                 tokens[last_slice - 1].item()
                 - self.tokenizer.timestamp_begin  # type:ignore
             )
-            self.buffer_tokens.extend(tokens[: last_slice + 1].tolist())
-            self.timestamp += last_timestamp_position0 * self.time_precision
+            ctx.buffer_tokens.extend(tokens[: last_slice + 1].tolist())
+            ctx.timestamp += last_timestamp_position0 * self.time_precision
             yield last_timestamp_position0
         else:
             duration = segment_duration
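
The stream clock advances by whole token positions: with whisper's constants, one output token covers input_stride * HOP_LENGTH / SAMPLE_RATE = 2 * 160 / 16000 = 0.02 seconds, matching the comments in the diff. Plain arithmetic for a hypothetical decode:

HOP_LENGTH = 160     # audio samples per mel frame (whisper.audio)
SAMPLE_RATE = 16000  # Hz (whisper.audio)
input_stride = 2     # mel frames per output token, as the diff's comment notes

time_precision = input_stride * HOP_LENGTH / SAMPLE_RATE  # 0.02 s per token
last_timestamp_position0 = 150                            # hypothetical value
print(last_timestamp_position0 * time_precision)          # ctx.timestamp advances 3.0 s
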
@@ -211,34 +212,35 @@
                 duration = last_timestamp_position * self.time_precision
             logger.debug(f"segment_duration: {segment_duration}, Duration: {duration}")
             chunk = self._get_chunk(
-                start=self.timestamp,
-                end=self.timestamp + duration,
+                start=ctx.timestamp,
+                end=ctx.timestamp + duration,
                 text_tokens=tokens,
                 result=result,
             )
             if chunk is not None:
                 yield chunk
-            self.timestamp += duration
+            ctx.timestamp += duration
 
         if result.temperature > 0.5:
             # do not feed the prompt tokens if a high temperature was used
-            del self.buffer_tokens
-            self.buffer_tokens = []
-        logger.debug(f"Length of buffer: {len(self.buffer_tokens)}")
+            del ctx.buffer_tokens
+            ctx.buffer_tokens = []
+        logger.debug(f"Length of buffer: {len(ctx.buffer_tokens)}")
 
     def transcribe(
         self,
         *,
         segment: np.ndarray,
+        ctx: Context,
     ) -> Iterator[ParsedChunk]:
         new_mel = log_mel_spectrogram(audio=segment).unsqueeze(0)
         logger.debug(f"Incoming new_mel.shape: {new_mel.shape}")
-        if self.buffer_mel is None:
+        if ctx.buffer_mel is None:
             mel = new_mel
         else:
-            logger.debug(f"buffer_mel.shape: {self.buffer_mel.shape}")
-            mel = torch.cat([self.buffer_mel, new_mel], dim=-1)
-            self.buffer_mel = None
+            logger.debug(f"buffer_mel.shape: {ctx.buffer_mel.shape}")
+            mel = torch.cat([ctx.buffer_mel, new_mel], dim=-1)
+            ctx.buffer_mel = None
         logger.debug(f"mel.shape: {mel.shape}")
 
         seek: int = 0
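
Any mel frames that the previous call did not consume are prepended to the new spectrogram, so audio spanning a window boundary is decoded with its left context intact. A minimal shape-level sketch of the carry-over, using whisper's 80 mel bins and arbitrary frame counts:

import torch

buffer_mel = torch.zeros(1, 80, 500)  # frames left unconsumed by the last call
new_mel = torch.zeros(1, 80, 3000)    # log-mel spectrogram of the new segment

mel = torch.cat([buffer_mel, new_mel], dim=-1)  # prepend along the time axis
print(mel.shape)                                # torch.Size([1, 80, 3500])
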
@@ -252,11 +254,12 @@
                 logger.warning("Padding is not expected while speaking")
 
             logger.debug(
-                f"seek={seek}, timestamp={self.timestamp}, "
+                f"seek={seek}, timestamp={ctx.timestamp}, "
                 f"mel.shape: {mel.shape}, segment.shape: {segment.shape}"
             )
             results = self._decode_with_fallback(
                 segment=segment,
+                ctx=ctx,
             )
             result = results[0]
             logger.debug(
@@ -278,7 +281,9 @@
             segment_duration = segment.shape[-1] * HOP_LENGTH / SAMPLE_RATE
             last_timestamp_position: Optional[int] = None
             for v in self._deal_timestamp(
-                result=result, segment_duration=segment_duration
+                result=result,
+                segment_duration=segment_duration,
+                ctx=ctx,
             ):
                 if isinstance(v, int):
                     last_timestamp_position = v
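
_deal_timestamp yields two kinds of values through a single iterator: ParsedChunk results for the caller and, at most once, an int telling transcribe how far the decode progressed; the isinstance check separates them. A generic sketch of the pattern, with stand-in types rather than the repository's:

from typing import Iterator, Union


def deal() -> Iterator[Union[str, int]]:
    yield "chunk one"  # stands in for a ParsedChunk
    yield "chunk two"
    yield 150          # stands in for last_timestamp_position0


last_position = None
for v in deal():
    if isinstance(v, int):
        last_position = v
        continue
    print(v)           # chunks are forwarded to the caller
print("tokens consumed up to position", last_position)
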
@@ -294,8 +299,9 @@
                 break
 
         if mel.shape[-1] - seek <= 0:
-            logger.debug(f"self.buffer_mel is None ({mel.shape}, {seek})")
+            logger.debug(f"ctx.buffer_mel is None ({mel.shape}, {seek})")
             return
-        self.buffer_mel = mel[:, :, seek:]
-        logger.debug(f"self.buffer_mel.shape: {self.buffer_mel.shape}")
+        ctx.buffer_mel = mel[:, :, seek:]
+        assert ctx.buffer_mel is not None
+        logger.debug(f"ctx.buffer_mel.shape: {ctx.buffer_mel.shape}")
         del mel