Add Context to manage context

This commit is contained in:
Yuta Hayashibe 2022-09-29 20:14:56 +09:00
parent f940fef3b3
commit f1d85762fc
4 changed files with 51 additions and 36 deletions

View file

@ -12,7 +12,7 @@ from whisper import available_models
from whisper.audio import N_FRAMES, SAMPLE_RATE from whisper.audio import N_FRAMES, SAMPLE_RATE
from whisper.tokenizer import LANGUAGES, TO_LANGUAGE_CODE from whisper.tokenizer import LANGUAGES, TO_LANGUAGE_CODE
from whispering.schema import WhisperConfig from whispering.schema import Context, WhisperConfig
from whispering.serve import serve_with_websocket from whispering.serve import serve_with_websocket
from whispering.transcriber import WhisperStreamingTranscriber from whispering.transcriber import WhisperStreamingTranscriber
from whispering.websocket_client import run_websocket_client from whispering.websocket_client import run_websocket_client
@ -33,6 +33,7 @@ def transcribe_from_mic(
logger.warning(status) logger.warning(status)
q.put(indata.ravel()) q.put(indata.ravel())
ctx: Context = Context()
logger.info("Ready to transcribe") logger.info("Ready to transcribe")
with sd.InputStream( with sd.InputStream(
samplerate=SAMPLE_RATE, samplerate=SAMPLE_RATE,
@ -46,7 +47,7 @@ def transcribe_from_mic(
while True: while True:
logger.debug(f"Segment #: {idx}, The rest of queue: {q.qsize()}") logger.debug(f"Segment #: {idx}, The rest of queue: {q.qsize()}")
segment = q.get() segment = q.get()
for chunk in wsp.transcribe(segment=segment): for chunk in wsp.transcribe(segment=segment, ctx=ctx):
print(f"{chunk.start:.2f}->{chunk.end:.2f}\t{chunk.text}") print(f"{chunk.start:.2f}->{chunk.end:.2f}\t{chunk.text}")
idx += 1 idx += 1

View file

@ -2,6 +2,7 @@
from typing import List, Optional from typing import List, Optional
import torch
from pydantic import BaseModel from pydantic import BaseModel
@ -23,6 +24,12 @@ class WhisperConfig(BaseModel):
compression_ratio_threshold: Optional[float] = 2.4 compression_ratio_threshold: Optional[float] = 2.4
class Context(BaseModel, arbitrary_types_allowed=True):
    """Mutable per-stream transcription state passed to the transcriber.

    One instance is created per audio stream (one per microphone session,
    one per websocket connection) and handed to
    ``WhisperStreamingTranscriber.transcribe`` on every call, so the
    transcriber object itself does not carry stream-specific state.

    ``arbitrary_types_allowed`` is required because ``torch.Tensor`` is not
    a type pydantic knows how to validate.
    """

    # Time offset (seconds) of the audio consumed so far; added to the
    # start/end of every emitted chunk and advanced after each decode.
    timestamp: float = 0.0
    # Token ids of previously decoded text, fed back as the decoding prompt.
    # These are plain ints (the transcriber extends this list with
    # ``tensor.tolist()``), not tensors. It is reset to an empty list when a
    # high-temperature result was used.
    # NOTE(review): the bare ``[]`` default relies on pydantic copying field
    # defaults per instance -- confirm for the pinned pydantic version.
    buffer_tokens: List[int] = []
    # Mel-spectrogram frames left over from the previous call; concatenated
    # in front of the next segment's mel along the last dimension.
    # ``None`` when nothing is carried over.
    buffer_mel: Optional[torch.Tensor] = None
class ParsedChunk(BaseModel): class ParsedChunk(BaseModel):
start: float start: float
end: float end: float

View file

@ -6,7 +6,7 @@ from logging import getLogger
import numpy as np import numpy as np
import websockets import websockets
from whispering.transcriber import WhisperStreamingTranscriber from whispering.transcriber import Context, WhisperStreamingTranscriber
logger = getLogger(__name__) logger = getLogger(__name__)
@ -14,6 +14,7 @@ logger = getLogger(__name__)
async def serve_with_websocket_main(websocket): async def serve_with_websocket_main(websocket):
global g_wsp global g_wsp
idx: int = 0 idx: int = 0
ctx: Context = Context()
while True: while True:
logger.debug(f"Segment #: {idx}") logger.debug(f"Segment #: {idx}")
@ -25,7 +26,7 @@ async def serve_with_websocket_main(websocket):
logger.debug(f"Message size: {len(message)}") logger.debug(f"Message size: {len(message)}")
segment = np.frombuffer(message, dtype=np.float32) segment = np.frombuffer(message, dtype=np.float32)
for chunk in g_wsp.transcribe(segment=segment): for chunk in g_wsp.transcribe(segment=segment, ctx=ctx):
await websocket.send(chunk.json()) await websocket.send(chunk.json())
idx += 1 idx += 1

View file

@ -1,7 +1,7 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
from logging import getLogger from logging import getLogger
from typing import Iterator, List, Optional, Union from typing import Final, Iterator, List, Optional, Union
import numpy as np import numpy as np
import torch import torch
@ -17,7 +17,7 @@ from whisper.decoding import DecodingOptions, DecodingResult
from whisper.tokenizer import get_tokenizer from whisper.tokenizer import get_tokenizer
from whisper.utils import exact_div from whisper.utils import exact_div
from whispering.schema import ParsedChunk, WhisperConfig from whispering.schema import Context, ParsedChunk, WhisperConfig
logger = getLogger(__name__) logger = getLogger(__name__)
@ -37,25 +37,21 @@ class WhisperStreamingTranscriber:
self.fp16 = False self.fp16 = False
def __init__(self, *, config: WhisperConfig): def __init__(self, *, config: WhisperConfig):
self.config: WhisperConfig = config self.config: Final[WhisperConfig] = config
self.model: Whisper = load_model(config.model_name, device=config.device) self.model: Final[Whisper] = load_model(config.model_name, device=config.device)
self.tokenizer = get_tokenizer( self.tokenizer = get_tokenizer(
self.model.is_multilingual, self.model.is_multilingual,
language=config.language, language=config.language,
task="transcribe", task="transcribe",
) )
self._set_dtype(config.fp16) self._set_dtype(config.fp16)
self.timestamp: float = 0.0 self.input_stride: Final[int] = exact_div(
self.input_stride = exact_div(
N_FRAMES, self.model.dims.n_audio_ctx N_FRAMES, self.model.dims.n_audio_ctx
) # mel frames per output token: 2 ) # mel frames per output token: 2
self.time_precision = ( self.time_precision: Final[float] = (
self.input_stride * HOP_LENGTH / SAMPLE_RATE self.input_stride * HOP_LENGTH / SAMPLE_RATE
) # time per output token: 0.02 (seconds) ) # time per output token: 0.02 (seconds)
self.buffer_tokens = []
self.buffer_mel = None
def _get_decoding_options( def _get_decoding_options(
self, self,
*, *,
@ -87,6 +83,7 @@ class WhisperStreamingTranscriber:
self, self,
*, *,
segment: np.ndarray, segment: np.ndarray,
ctx: Context,
) -> List[DecodingResult]: ) -> List[DecodingResult]:
assert len(self.config.temperatures) >= 1 assert len(self.config.temperatures) >= 1
t = self.config.temperatures[0] t = self.config.temperatures[0]
@ -94,7 +91,7 @@ class WhisperStreamingTranscriber:
_decode_options1: DecodingOptions = self._get_decoding_options( _decode_options1: DecodingOptions = self._get_decoding_options(
t=t, t=t,
prompt=self.buffer_tokens, prompt=ctx.buffer_tokens,
beam_size=self.config.beam_size, beam_size=self.config.beam_size,
patience=0.0, patience=0.0,
best_of=None, best_of=None,
@ -115,7 +112,7 @@ class WhisperStreamingTranscriber:
) )
_decode_options2: DecodingOptions = self._get_decoding_options( _decode_options2: DecodingOptions = self._get_decoding_options(
t=t, t=t,
prompt=self.buffer_tokens, prompt=ctx.buffer_tokens,
beam_size=None, beam_size=None,
patience=0.0, patience=0.0,
best_of=self.config.best_of, best_of=self.config.best_of,
@ -158,7 +155,11 @@ class WhisperStreamingTranscriber:
) )
def _deal_timestamp( def _deal_timestamp(
self, *, result, segment_duration self,
*,
result,
segment_duration,
ctx: Context,
) -> Iterator[Union[ParsedChunk, int]]: ) -> Iterator[Union[ParsedChunk, int]]:
tokens = torch.tensor(result.tokens) tokens = torch.tensor(result.tokens)
timestamp_tokens: torch.Tensor = tokens.ge(self.tokenizer.timestamp_begin) timestamp_tokens: torch.Tensor = tokens.ge(self.tokenizer.timestamp_begin)
@ -182,9 +183,9 @@ class WhisperStreamingTranscriber:
sliced_tokens[-1].item() - self.tokenizer.timestamp_begin sliced_tokens[-1].item() - self.tokenizer.timestamp_begin
) )
chunk = self._get_chunk( chunk = self._get_chunk(
start=self.timestamp start=ctx.timestamp
+ start_timestamp_position * self.time_precision, + start_timestamp_position * self.time_precision,
end=self.timestamp + end_timestamp_position * self.time_precision, end=ctx.timestamp + end_timestamp_position * self.time_precision,
text_tokens=sliced_tokens[1:-1], text_tokens=sliced_tokens[1:-1],
result=result, result=result,
) )
@ -195,8 +196,8 @@ class WhisperStreamingTranscriber:
tokens[last_slice - 1].item() tokens[last_slice - 1].item()
- self.tokenizer.timestamp_begin # type:ignore - self.tokenizer.timestamp_begin # type:ignore
) )
self.buffer_tokens.extend(tokens[: last_slice + 1].tolist()) ctx.buffer_tokens.extend(tokens[: last_slice + 1].tolist())
self.timestamp += last_timestamp_position0 * self.time_precision ctx.timestamp += last_timestamp_position0 * self.time_precision
yield last_timestamp_position0 yield last_timestamp_position0
else: else:
duration = segment_duration duration = segment_duration
@ -211,34 +212,35 @@ class WhisperStreamingTranscriber:
duration = last_timestamp_position * self.time_precision duration = last_timestamp_position * self.time_precision
logger.debug(f"segment_duration: {segment_duration}, Duration: {duration}") logger.debug(f"segment_duration: {segment_duration}, Duration: {duration}")
chunk = self._get_chunk( chunk = self._get_chunk(
start=self.timestamp, start=ctx.timestamp,
end=self.timestamp + duration, end=ctx.timestamp + duration,
text_tokens=tokens, text_tokens=tokens,
result=result, result=result,
) )
if chunk is not None: if chunk is not None:
yield chunk yield chunk
self.timestamp += duration ctx.timestamp += duration
if result.temperature > 0.5: if result.temperature > 0.5:
# do not feed the prompt tokens if a high temperature was used # do not feed the prompt tokens if a high temperature was used
del self.buffer_tokens del ctx.buffer_tokens
self.buffer_tokens = [] ctx.buffer_tokens = []
logger.debug(f"Length of buffer: {len(self.buffer_tokens)}") logger.debug(f"Length of buffer: {len(ctx.buffer_tokens)}")
def transcribe( def transcribe(
self, self,
*, *,
segment: np.ndarray, segment: np.ndarray,
ctx: Context,
) -> Iterator[ParsedChunk]: ) -> Iterator[ParsedChunk]:
new_mel = log_mel_spectrogram(audio=segment).unsqueeze(0) new_mel = log_mel_spectrogram(audio=segment).unsqueeze(0)
logger.debug(f"Incoming new_mel.shape: {new_mel.shape}") logger.debug(f"Incoming new_mel.shape: {new_mel.shape}")
if self.buffer_mel is None: if ctx.buffer_mel is None:
mel = new_mel mel = new_mel
else: else:
logger.debug(f"buffer_mel.shape: {self.buffer_mel.shape}") logger.debug(f"buffer_mel.shape: {ctx.buffer_mel.shape}")
mel = torch.cat([self.buffer_mel, new_mel], dim=-1) mel = torch.cat([ctx.buffer_mel, new_mel], dim=-1)
self.buffer_mel = None ctx.buffer_mel = None
logger.debug(f"mel.shape: {mel.shape}") logger.debug(f"mel.shape: {mel.shape}")
seek: int = 0 seek: int = 0
@ -252,11 +254,12 @@ class WhisperStreamingTranscriber:
logger.warning("Padding is not expected while speaking") logger.warning("Padding is not expected while speaking")
logger.debug( logger.debug(
f"seek={seek}, timestamp={self.timestamp}, " f"seek={seek}, timestamp={ctx.timestamp}, "
f"mel.shape: {mel.shape}, segment.shape: {segment.shape}" f"mel.shape: {mel.shape}, segment.shape: {segment.shape}"
) )
results = self._decode_with_fallback( results = self._decode_with_fallback(
segment=segment, segment=segment,
ctx=ctx,
) )
result = results[0] result = results[0]
logger.debug( logger.debug(
@ -278,7 +281,9 @@ class WhisperStreamingTranscriber:
segment_duration = segment.shape[-1] * HOP_LENGTH / SAMPLE_RATE segment_duration = segment.shape[-1] * HOP_LENGTH / SAMPLE_RATE
last_timestamp_position: Optional[int] = None last_timestamp_position: Optional[int] = None
for v in self._deal_timestamp( for v in self._deal_timestamp(
result=result, segment_duration=segment_duration result=result,
segment_duration=segment_duration,
ctx=ctx,
): ):
if isinstance(v, int): if isinstance(v, int):
last_timestamp_position = v last_timestamp_position = v
@ -294,8 +299,9 @@ class WhisperStreamingTranscriber:
break break
if mel.shape[-1] - seek <= 0: if mel.shape[-1] - seek <= 0:
logger.debug(f"self.buffer_mel is None ({mel.shape}, {seek})") logger.debug(f"ctx.buffer_mel is None ({mel.shape}, {seek})")
return return
self.buffer_mel = mel[:, :, seek:] ctx.buffer_mel = mel[:, :, seek:]
logger.debug(f"self.buffer_mel.shape: {self.buffer_mel.shape}") assert ctx.buffer_mel is not None
logger.debug(f"ctx.buffer_mel.shape: {ctx.buffer_mel.shape}")
del mel del mel