mirror of
https://github.com/shirayu/whispering.git
synced 2025-04-26 18:44:42 +00:00
Add Context to manage context
This commit is contained in:
parent
f940fef3b3
commit
f1d85762fc
4 changed files with 51 additions and 36 deletions
|
@ -12,7 +12,7 @@ from whisper import available_models
|
||||||
from whisper.audio import N_FRAMES, SAMPLE_RATE
|
from whisper.audio import N_FRAMES, SAMPLE_RATE
|
||||||
from whisper.tokenizer import LANGUAGES, TO_LANGUAGE_CODE
|
from whisper.tokenizer import LANGUAGES, TO_LANGUAGE_CODE
|
||||||
|
|
||||||
from whispering.schema import WhisperConfig
|
from whispering.schema import Context, WhisperConfig
|
||||||
from whispering.serve import serve_with_websocket
|
from whispering.serve import serve_with_websocket
|
||||||
from whispering.transcriber import WhisperStreamingTranscriber
|
from whispering.transcriber import WhisperStreamingTranscriber
|
||||||
from whispering.websocket_client import run_websocket_client
|
from whispering.websocket_client import run_websocket_client
|
||||||
|
@ -33,6 +33,7 @@ def transcribe_from_mic(
|
||||||
logger.warning(status)
|
logger.warning(status)
|
||||||
q.put(indata.ravel())
|
q.put(indata.ravel())
|
||||||
|
|
||||||
|
ctx: Context = Context()
|
||||||
logger.info("Ready to transcribe")
|
logger.info("Ready to transcribe")
|
||||||
with sd.InputStream(
|
with sd.InputStream(
|
||||||
samplerate=SAMPLE_RATE,
|
samplerate=SAMPLE_RATE,
|
||||||
|
@ -46,7 +47,7 @@ def transcribe_from_mic(
|
||||||
while True:
|
while True:
|
||||||
logger.debug(f"Segment #: {idx}, The rest of queue: {q.qsize()}")
|
logger.debug(f"Segment #: {idx}, The rest of queue: {q.qsize()}")
|
||||||
segment = q.get()
|
segment = q.get()
|
||||||
for chunk in wsp.transcribe(segment=segment):
|
for chunk in wsp.transcribe(segment=segment, ctx=ctx):
|
||||||
print(f"{chunk.start:.2f}->{chunk.end:.2f}\t{chunk.text}")
|
print(f"{chunk.start:.2f}->{chunk.end:.2f}\t{chunk.text}")
|
||||||
idx += 1
|
idx += 1
|
||||||
|
|
||||||
|
|
|
@ -2,6 +2,7 @@
|
||||||
|
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
@ -23,6 +24,12 @@ class WhisperConfig(BaseModel):
|
||||||
compression_ratio_threshold: Optional[float] = 2.4
|
compression_ratio_threshold: Optional[float] = 2.4
|
||||||
|
|
||||||
|
|
||||||
|
class Context(BaseModel, arbitrary_types_allowed=True):
    """Mutable per-stream transcription state.

    One instance is created per audio source (microphone session or
    websocket connection) and passed to
    ``WhisperStreamingTranscriber.transcribe`` on every call, so the
    transcriber object itself carries no per-stream state.

    ``arbitrary_types_allowed`` is required because ``torch.Tensor`` is not
    a type pydantic knows how to validate.
    """

    # Offset in seconds added to chunk start/end times; the transcriber
    # advances it after each decoded span.
    timestamp: float = 0.0

    # Token ids of previously decoded text, fed back as the decoding prompt.
    # The transcriber extends this with ``Tensor.tolist()`` output, i.e. plain
    # ints (not tensors), and resets it after a high-temperature decode.
    buffer_tokens: List[int] = []

    # Mel frames left over from the previous call (``mel[:, :, seek:]``);
    # concatenated in front of the next incoming segment. ``None`` when
    # nothing is pending.
    buffer_mel: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
|
|
||||||
class ParsedChunk(BaseModel):
|
class ParsedChunk(BaseModel):
|
||||||
start: float
|
start: float
|
||||||
end: float
|
end: float
|
||||||
|
|
|
@ -6,7 +6,7 @@ from logging import getLogger
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import websockets
|
import websockets
|
||||||
|
|
||||||
from whispering.transcriber import WhisperStreamingTranscriber
|
from whispering.transcriber import Context, WhisperStreamingTranscriber
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
@ -14,6 +14,7 @@ logger = getLogger(__name__)
|
||||||
async def serve_with_websocket_main(websocket):
|
async def serve_with_websocket_main(websocket):
|
||||||
global g_wsp
|
global g_wsp
|
||||||
idx: int = 0
|
idx: int = 0
|
||||||
|
ctx: Context = Context()
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
logger.debug(f"Segment #: {idx}")
|
logger.debug(f"Segment #: {idx}")
|
||||||
|
@ -25,7 +26,7 @@ async def serve_with_websocket_main(websocket):
|
||||||
|
|
||||||
logger.debug(f"Message size: {len(message)}")
|
logger.debug(f"Message size: {len(message)}")
|
||||||
segment = np.frombuffer(message, dtype=np.float32)
|
segment = np.frombuffer(message, dtype=np.float32)
|
||||||
for chunk in g_wsp.transcribe(segment=segment):
|
for chunk in g_wsp.transcribe(segment=segment, ctx=ctx):
|
||||||
await websocket.send(chunk.json())
|
await websocket.send(chunk.json())
|
||||||
idx += 1
|
idx += 1
|
||||||
|
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from typing import Iterator, List, Optional, Union
|
from typing import Final, Iterator, List, Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
@ -17,7 +17,7 @@ from whisper.decoding import DecodingOptions, DecodingResult
|
||||||
from whisper.tokenizer import get_tokenizer
|
from whisper.tokenizer import get_tokenizer
|
||||||
from whisper.utils import exact_div
|
from whisper.utils import exact_div
|
||||||
|
|
||||||
from whispering.schema import ParsedChunk, WhisperConfig
|
from whispering.schema import Context, ParsedChunk, WhisperConfig
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
@ -37,25 +37,21 @@ class WhisperStreamingTranscriber:
|
||||||
self.fp16 = False
|
self.fp16 = False
|
||||||
|
|
||||||
def __init__(self, *, config: WhisperConfig):
|
def __init__(self, *, config: WhisperConfig):
|
||||||
self.config: WhisperConfig = config
|
self.config: Final[WhisperConfig] = config
|
||||||
self.model: Whisper = load_model(config.model_name, device=config.device)
|
self.model: Final[Whisper] = load_model(config.model_name, device=config.device)
|
||||||
self.tokenizer = get_tokenizer(
|
self.tokenizer = get_tokenizer(
|
||||||
self.model.is_multilingual,
|
self.model.is_multilingual,
|
||||||
language=config.language,
|
language=config.language,
|
||||||
task="transcribe",
|
task="transcribe",
|
||||||
)
|
)
|
||||||
self._set_dtype(config.fp16)
|
self._set_dtype(config.fp16)
|
||||||
self.timestamp: float = 0.0
|
self.input_stride: Final[int] = exact_div(
|
||||||
self.input_stride = exact_div(
|
|
||||||
N_FRAMES, self.model.dims.n_audio_ctx
|
N_FRAMES, self.model.dims.n_audio_ctx
|
||||||
) # mel frames per output token: 2
|
) # mel frames per output token: 2
|
||||||
self.time_precision = (
|
self.time_precision: Final[float] = (
|
||||||
self.input_stride * HOP_LENGTH / SAMPLE_RATE
|
self.input_stride * HOP_LENGTH / SAMPLE_RATE
|
||||||
) # time per output token: 0.02 (seconds)
|
) # time per output token: 0.02 (seconds)
|
||||||
|
|
||||||
self.buffer_tokens = []
|
|
||||||
self.buffer_mel = None
|
|
||||||
|
|
||||||
def _get_decoding_options(
|
def _get_decoding_options(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
|
@ -87,6 +83,7 @@ class WhisperStreamingTranscriber:
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
segment: np.ndarray,
|
segment: np.ndarray,
|
||||||
|
ctx: Context,
|
||||||
) -> List[DecodingResult]:
|
) -> List[DecodingResult]:
|
||||||
assert len(self.config.temperatures) >= 1
|
assert len(self.config.temperatures) >= 1
|
||||||
t = self.config.temperatures[0]
|
t = self.config.temperatures[0]
|
||||||
|
@ -94,7 +91,7 @@ class WhisperStreamingTranscriber:
|
||||||
|
|
||||||
_decode_options1: DecodingOptions = self._get_decoding_options(
|
_decode_options1: DecodingOptions = self._get_decoding_options(
|
||||||
t=t,
|
t=t,
|
||||||
prompt=self.buffer_tokens,
|
prompt=ctx.buffer_tokens,
|
||||||
beam_size=self.config.beam_size,
|
beam_size=self.config.beam_size,
|
||||||
patience=0.0,
|
patience=0.0,
|
||||||
best_of=None,
|
best_of=None,
|
||||||
|
@ -115,7 +112,7 @@ class WhisperStreamingTranscriber:
|
||||||
)
|
)
|
||||||
_decode_options2: DecodingOptions = self._get_decoding_options(
|
_decode_options2: DecodingOptions = self._get_decoding_options(
|
||||||
t=t,
|
t=t,
|
||||||
prompt=self.buffer_tokens,
|
prompt=ctx.buffer_tokens,
|
||||||
beam_size=None,
|
beam_size=None,
|
||||||
patience=0.0,
|
patience=0.0,
|
||||||
best_of=self.config.best_of,
|
best_of=self.config.best_of,
|
||||||
|
@ -158,7 +155,11 @@ class WhisperStreamingTranscriber:
|
||||||
)
|
)
|
||||||
|
|
||||||
def _deal_timestamp(
|
def _deal_timestamp(
|
||||||
self, *, result, segment_duration
|
self,
|
||||||
|
*,
|
||||||
|
result,
|
||||||
|
segment_duration,
|
||||||
|
ctx: Context,
|
||||||
) -> Iterator[Union[ParsedChunk, int]]:
|
) -> Iterator[Union[ParsedChunk, int]]:
|
||||||
tokens = torch.tensor(result.tokens)
|
tokens = torch.tensor(result.tokens)
|
||||||
timestamp_tokens: torch.Tensor = tokens.ge(self.tokenizer.timestamp_begin)
|
timestamp_tokens: torch.Tensor = tokens.ge(self.tokenizer.timestamp_begin)
|
||||||
|
@ -182,9 +183,9 @@ class WhisperStreamingTranscriber:
|
||||||
sliced_tokens[-1].item() - self.tokenizer.timestamp_begin
|
sliced_tokens[-1].item() - self.tokenizer.timestamp_begin
|
||||||
)
|
)
|
||||||
chunk = self._get_chunk(
|
chunk = self._get_chunk(
|
||||||
start=self.timestamp
|
start=ctx.timestamp
|
||||||
+ start_timestamp_position * self.time_precision,
|
+ start_timestamp_position * self.time_precision,
|
||||||
end=self.timestamp + end_timestamp_position * self.time_precision,
|
end=ctx.timestamp + end_timestamp_position * self.time_precision,
|
||||||
text_tokens=sliced_tokens[1:-1],
|
text_tokens=sliced_tokens[1:-1],
|
||||||
result=result,
|
result=result,
|
||||||
)
|
)
|
||||||
|
@ -195,8 +196,8 @@ class WhisperStreamingTranscriber:
|
||||||
tokens[last_slice - 1].item()
|
tokens[last_slice - 1].item()
|
||||||
- self.tokenizer.timestamp_begin # type:ignore
|
- self.tokenizer.timestamp_begin # type:ignore
|
||||||
)
|
)
|
||||||
self.buffer_tokens.extend(tokens[: last_slice + 1].tolist())
|
ctx.buffer_tokens.extend(tokens[: last_slice + 1].tolist())
|
||||||
self.timestamp += last_timestamp_position0 * self.time_precision
|
ctx.timestamp += last_timestamp_position0 * self.time_precision
|
||||||
yield last_timestamp_position0
|
yield last_timestamp_position0
|
||||||
else:
|
else:
|
||||||
duration = segment_duration
|
duration = segment_duration
|
||||||
|
@ -211,34 +212,35 @@ class WhisperStreamingTranscriber:
|
||||||
duration = last_timestamp_position * self.time_precision
|
duration = last_timestamp_position * self.time_precision
|
||||||
logger.debug(f"segment_duration: {segment_duration}, Duration: {duration}")
|
logger.debug(f"segment_duration: {segment_duration}, Duration: {duration}")
|
||||||
chunk = self._get_chunk(
|
chunk = self._get_chunk(
|
||||||
start=self.timestamp,
|
start=ctx.timestamp,
|
||||||
end=self.timestamp + duration,
|
end=ctx.timestamp + duration,
|
||||||
text_tokens=tokens,
|
text_tokens=tokens,
|
||||||
result=result,
|
result=result,
|
||||||
)
|
)
|
||||||
if chunk is not None:
|
if chunk is not None:
|
||||||
yield chunk
|
yield chunk
|
||||||
self.timestamp += duration
|
ctx.timestamp += duration
|
||||||
|
|
||||||
if result.temperature > 0.5:
|
if result.temperature > 0.5:
|
||||||
# do not feed the prompt tokens if a high temperature was used
|
# do not feed the prompt tokens if a high temperature was used
|
||||||
del self.buffer_tokens
|
del ctx.buffer_tokens
|
||||||
self.buffer_tokens = []
|
ctx.buffer_tokens = []
|
||||||
logger.debug(f"Length of buffer: {len(self.buffer_tokens)}")
|
logger.debug(f"Length of buffer: {len(ctx.buffer_tokens)}")
|
||||||
|
|
||||||
def transcribe(
|
def transcribe(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
segment: np.ndarray,
|
segment: np.ndarray,
|
||||||
|
ctx: Context,
|
||||||
) -> Iterator[ParsedChunk]:
|
) -> Iterator[ParsedChunk]:
|
||||||
new_mel = log_mel_spectrogram(audio=segment).unsqueeze(0)
|
new_mel = log_mel_spectrogram(audio=segment).unsqueeze(0)
|
||||||
logger.debug(f"Incoming new_mel.shape: {new_mel.shape}")
|
logger.debug(f"Incoming new_mel.shape: {new_mel.shape}")
|
||||||
if self.buffer_mel is None:
|
if ctx.buffer_mel is None:
|
||||||
mel = new_mel
|
mel = new_mel
|
||||||
else:
|
else:
|
||||||
logger.debug(f"buffer_mel.shape: {self.buffer_mel.shape}")
|
logger.debug(f"buffer_mel.shape: {ctx.buffer_mel.shape}")
|
||||||
mel = torch.cat([self.buffer_mel, new_mel], dim=-1)
|
mel = torch.cat([ctx.buffer_mel, new_mel], dim=-1)
|
||||||
self.buffer_mel = None
|
ctx.buffer_mel = None
|
||||||
logger.debug(f"mel.shape: {mel.shape}")
|
logger.debug(f"mel.shape: {mel.shape}")
|
||||||
|
|
||||||
seek: int = 0
|
seek: int = 0
|
||||||
|
@ -252,11 +254,12 @@ class WhisperStreamingTranscriber:
|
||||||
logger.warning("Padding is not expected while speaking")
|
logger.warning("Padding is not expected while speaking")
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"seek={seek}, timestamp={self.timestamp}, "
|
f"seek={seek}, timestamp={ctx.timestamp}, "
|
||||||
f"mel.shape: {mel.shape}, segment.shape: {segment.shape}"
|
f"mel.shape: {mel.shape}, segment.shape: {segment.shape}"
|
||||||
)
|
)
|
||||||
results = self._decode_with_fallback(
|
results = self._decode_with_fallback(
|
||||||
segment=segment,
|
segment=segment,
|
||||||
|
ctx=ctx,
|
||||||
)
|
)
|
||||||
result = results[0]
|
result = results[0]
|
||||||
logger.debug(
|
logger.debug(
|
||||||
|
@ -278,7 +281,9 @@ class WhisperStreamingTranscriber:
|
||||||
segment_duration = segment.shape[-1] * HOP_LENGTH / SAMPLE_RATE
|
segment_duration = segment.shape[-1] * HOP_LENGTH / SAMPLE_RATE
|
||||||
last_timestamp_position: Optional[int] = None
|
last_timestamp_position: Optional[int] = None
|
||||||
for v in self._deal_timestamp(
|
for v in self._deal_timestamp(
|
||||||
result=result, segment_duration=segment_duration
|
result=result,
|
||||||
|
segment_duration=segment_duration,
|
||||||
|
ctx=ctx,
|
||||||
):
|
):
|
||||||
if isinstance(v, int):
|
if isinstance(v, int):
|
||||||
last_timestamp_position = v
|
last_timestamp_position = v
|
||||||
|
@ -294,8 +299,9 @@ class WhisperStreamingTranscriber:
|
||||||
break
|
break
|
||||||
|
|
||||||
if mel.shape[-1] - seek <= 0:
|
if mel.shape[-1] - seek <= 0:
|
||||||
logger.debug(f"self.buffer_mel is None ({mel.shape}, {seek})")
|
logger.debug(f"ctx.buffer_mel is None ({mel.shape}, {seek})")
|
||||||
return
|
return
|
||||||
self.buffer_mel = mel[:, :, seek:]
|
ctx.buffer_mel = mel[:, :, seek:]
|
||||||
logger.debug(f"self.buffer_mel.shape: {self.buffer_mel.shape}")
|
assert ctx.buffer_mel is not None
|
||||||
|
logger.debug(f"ctx.buffer_mel.shape: {ctx.buffer_mel.shape}")
|
||||||
del mel
|
del mel
|
||||||
|
|
Loading…
Reference in a new issue