Mirror of https://github.com/shirayu/whispering.git (synced 2025-02-16 10:35:16 +00:00)
Updated Context
commit 88cda28b83
parent 5abd0db07a
4 changed files with 51 additions and 36 deletions

@@ -25,6 +25,7 @@ def transcribe_from_mic(
     wsp: WhisperStreamingTranscriber,
     sd_device: Optional[Union[int, str]],
     num_block: int,
+    ctx: Context,
 ) -> None:
     q = queue.Queue()

@@ -33,7 +34,6 @@ def transcribe_from_mic(
         logger.warning(status)
         q.put(indata.ravel())

-    ctx: Context = Context()
     logger.info("Ready to transcribe")
     with sd.InputStream(
         samplerate=SAMPLE_RATE,

@@ -136,14 +136,11 @@ def get_opts() -> argparse.Namespace:
     return opts


-def get_wshiper(*, opts):
+def get_wshiper(*, opts) -> WhisperStreamingTranscriber:
     config = WhisperConfig(
         model_name=opts.model,
         language=opts.language,
         device=opts.device,
-        beam_size=opts.beam_size,
-        temperatures=opts.temperature,
-        allow_padding=opts.allow_padding,
     )

     logger.debug(f"WhisperConfig: {config}")

@@ -151,6 +148,16 @@ def get_wshiper(*, opts):
     return wsp


+def get_context(*, opts) -> Context:
+    ctx = Context(
+        beam_size=opts.beam_size,
+        temperatures=opts.temperature,
+        allow_padding=opts.allow_padding,
+    )
+    logger.debug(f"Context: {ctx}")
+    return ctx
+
+
 def show_devices():
     devices = sd.query_devices()
     for i, device in enumerate(devices):

@@ -182,21 +189,25 @@ def main() -> None:
         assert opts.language is not None
         assert opts.model is not None
         wsp = get_wshiper(opts=opts)
+        ctx: Context = get_context(opts=opts)
         asyncio.run(
             serve_with_websocket(
                 wsp=wsp,
                 host=opts.host,
                 port=opts.port,
+                ctx=ctx,
             )
         )
     else:
         assert opts.language is not None
         assert opts.model is not None
         wsp = get_wshiper(opts=opts)
+        ctx: Context = get_context(opts=opts)
         transcribe_from_mic(
             wsp=wsp,
             sd_device=opts.mic,
             num_block=opts.num_block,
+            ctx=ctx,
         )
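
Taken together, the CLI changes split configuration in two: model-level settings stay in WhisperConfig via get_wshiper(), while the decode-time settings (beam_size, temperatures, allow_padding) now travel in a Context built by the new get_context(). A minimal sketch of the resulting wiring, using SimpleNamespace as a stand-in for the argparse.Namespace the real CLI builds; the option values here are illustrative, not the project's defaults:

from types import SimpleNamespace

opts = SimpleNamespace(
    model="base",
    language="en",
    device="cpu",
    beam_size=5,
    temperature=[0.0, 0.2, 0.4, 0.6, 0.8, 1.0],
    allow_padding=False,
    mic=None,       # sounddevice treats None as the default input device
    num_block=160,  # placeholder value, not a documented default
)

wsp = get_wshiper(opts=opts)  # builds WhisperConfig internally
ctx = get_context(opts=opts)  # builds the per-session Context
transcribe_from_mic(wsp=wsp, sd_device=opts.mic, num_block=opts.num_block, ctx=ctx)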

@@ -10,19 +10,7 @@ class WhisperConfig(BaseModel):
     model_name: str
     device: str
     language: str
-
-    allow_padding: bool = False
-    temperatures: List[float]
     fp16: bool = True
-    compression_ratio_threshold: Optional[float] = 2.4
-    logprob_threshold: Optional[float] = -1.0
-    no_captions_threshold: Optional[float] = 0.6
-    best_of: int = 5
-    beam_size: Optional[int] = None
-    no_speech_threshold: Optional[float] = 0.6
-    logprob_threshold: Optional[float] = -1.0
-    compression_ratio_threshold: Optional[float] = 2.4
-    buffer_threshold: Optional[float] = 0.5

     @root_validator
     def validate_model_name(cls, values):

@@ -39,6 +27,19 @@ class Context(BaseModel, arbitrary_types_allowed=True):
     buffer_tokens: List[torch.Tensor] = []
     buffer_mel: Optional[torch.Tensor] = None

+    temperatures: List[float]
+    allow_padding: bool = False
+    patience: Optional[float] = None
+    compression_ratio_threshold: Optional[float] = 2.4
+    logprob_threshold: Optional[float] = -1.0
+    no_captions_threshold: Optional[float] = 0.6
+    best_of: int = 5
+    beam_size: Optional[int] = None
+    no_speech_threshold: Optional[float] = 0.6
+    logprob_threshold: Optional[float] = -1.0
+    compression_ratio_threshold: Optional[float] = 2.4
+    buffer_threshold: Optional[float] = 0.5
+

 class ParsedChunk(BaseModel):
     start: float
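
With the move, Context owns every decode-time knob, and temperatures is the only new field without a default (the old Context was constructed bare, so its prior fields all carry defaults). A quick construction sketch; the asserted values are the defaults visible in the hunk above. One quirk is carried over verbatim from WhisperConfig: logprob_threshold and compression_ratio_threshold are declared twice with identical values, and since a later binding in a Python class body simply wins, the model behaves as if each were declared once:

ctx = Context(temperatures=[0.0, 0.2, 0.4, 0.6, 0.8, 1.0])
assert ctx.allow_padding is False
assert ctx.beam_size is None
assert ctx.buffer_threshold == 0.5
assert ctx.buffer_tokens == []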

@@ -13,8 +13,8 @@ logger = getLogger(__name__)

 async def serve_with_websocket_main(websocket):
     global g_wsp
+    global g_ctx
     idx: int = 0
-    ctx: Context = Context()

     while True:
         logger.debug(f"Segment #: {idx}")

@@ -26,7 +26,7 @@ async def serve_with_websocket_main(websocket):

         logger.debug(f"Message size: {len(message)}")
         segment = np.frombuffer(message, dtype=np.float32)
-        for chunk in g_wsp.transcribe(segment=segment, ctx=ctx):
+        for chunk in g_wsp.transcribe(segment=segment, ctx=g_ctx):
             await websocket.send(chunk.json())
             idx += 1

@@ -36,11 +36,14 @@ async def serve_with_websocket(
     wsp: WhisperStreamingTranscriber,
     host: str,
     port: int,
+    ctx: Context,
 ):
     logger.info(f"Serve at {host}:{port}")
     logger.info("Make secure with your responsibility!")
     global g_wsp
+    global g_ctx
     g_wsp = wsp
+    g_ctx = ctx

     try:
         async with websockets.serve(  # type: ignore
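
The serving change implies this wire protocol: each binary websocket message is one raw float32 PCM buffer (the server decodes it with np.frombuffer), and the server pushes one JSON-serialized ParsedChunk whenever the transcriber yields one, with no fixed one-reply-per-message pairing. A minimal client sketch under those assumptions; the URI, pacing, and chunk size are placeholders:

import asyncio

import numpy as np
import websockets

SAMPLE_RATE = 16000  # Whisper's expected sampling rate


async def print_chunks(ws) -> None:
    # Chunks arrive whenever the transcriber emits one, so read
    # independently of the send loop.
    async for message in ws:
        print(message)


async def stream_audio(uri: str, audio: np.ndarray) -> None:
    async with websockets.connect(uri) as ws:
        receiver = asyncio.create_task(print_chunks(ws))
        for start in range(0, len(audio), SAMPLE_RATE):
            block = audio[start : start + SAMPLE_RATE].astype(np.float32)
            await ws.send(block.tobytes())
            await asyncio.sleep(1.0)  # pace roughly like live capture
        receiver.cancel()


asyncio.run(stream_audio("ws://127.0.0.1:8000", np.zeros(SAMPLE_RATE * 3, np.float32)))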

@@ -85,25 +85,25 @@ class WhisperStreamingTranscriber:
         segment: np.ndarray,
         ctx: Context,
     ) -> List[DecodingResult]:
-        assert len(self.config.temperatures) >= 1
-        t = self.config.temperatures[0]
+        assert len(ctx.temperatures) >= 1
+        t = ctx.temperatures[0]
         logger.debug(f"temperature: {t}")

         _decode_options1: DecodingOptions = self._get_decoding_options(
             t=t,
             prompt=ctx.buffer_tokens,
-            beam_size=self.config.beam_size,
+            beam_size=ctx.beam_size,
             patience=None,
             best_of=None,
         )
         results: List[DecodingResult] = self.model.decode(segment, _decode_options1)  # type: ignore

-        for t in self.config.temperatures[1:]:
+        for t in ctx.temperatures[1:]:
             needs_fallback = [
-                self.config.compression_ratio_threshold is not None
-                and result.compression_ratio > self.config.compression_ratio_threshold
-                or self.config.logprob_threshold is not None
-                and result.avg_logprob < self.config.logprob_threshold
+                ctx.compression_ratio_threshold is not None
+                and result.compression_ratio > ctx.compression_ratio_threshold
+                or ctx.logprob_threshold is not None
+                and result.avg_logprob < ctx.logprob_threshold
                 for result in results
             ]
             if any(needs_fallback):

@@ -114,8 +114,8 @@ class WhisperStreamingTranscriber:
                 t=t,
                 prompt=ctx.buffer_tokens,
                 beam_size=None,
-                patience=0.0,
-                best_of=self.config.best_of,
+                patience=ctx.patience,
+                best_of=ctx.best_of,
             )
             retries: List[DecodingResult] = self.model.decode(
                 segment[needs_fallback], _decode_options2  # type: ignore
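
These two hunks reroute Whisper's temperature-fallback scheme through ctx: decode everything at the lowest temperature first, then retry only the results that look degenerate at each successively higher temperature. Stripped of the class plumbing, the control flow is roughly this sketch; decode_fn stands in for self.model.decode with the options built above, and segments is assumed to be an indexable batch:

def decode_with_fallback(decode_fn, segments, ctx):
    # First pass at the lowest temperature.
    results = decode_fn(segments, temperature=ctx.temperatures[0])
    for t in ctx.temperatures[1:]:
        # A result fails if it is too repetitive (high compression ratio)
        # or too improbable (low average log-probability).
        needs_fallback = [
            (ctx.compression_ratio_threshold is not None
             and r.compression_ratio > ctx.compression_ratio_threshold)
            or (ctx.logprob_threshold is not None
                and r.avg_logprob < ctx.logprob_threshold)
            for r in results
        ]
        if not any(needs_fallback):
            break
        # Retry only the failing rows of the batch at temperature t.
        retries = iter(decode_fn(segments[needs_fallback], temperature=t))
        results = [next(retries) if bad else r
                   for r, bad in zip(results, needs_fallback)]
    return results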

@@ -221,7 +221,7 @@ class WhisperStreamingTranscriber:
                 yield chunk
             ctx.timestamp += duration

-            if result.temperature > self.config.buffer_threshold:
+            if result.temperature > ctx.buffer_threshold:
                 # do not feed the prompt tokens if a high temperature was used
                 del ctx.buffer_tokens
                 ctx.buffer_tokens = []
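
Moving buffer_threshold onto ctx makes the prompt-reset rule per-session: tokens from earlier windows are kept as conditioning for the next one only when the accepted result came from a low temperature. A sketch of the rule; the accumulation branch is hypothetical, shown only to make the contrast explicit:

if result.temperature > ctx.buffer_threshold:
    # High temperature means the text is unreliable; do not condition
    # the next window on it.
    ctx.buffer_tokens = []
else:
    ctx.buffer_tokens.extend(chunk_tokens)  # hypothetical accumulation step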

@@ -250,7 +250,7 @@ class WhisperStreamingTranscriber:
             .to(self.model.device)  # type: ignore
             .to(self.dtype)
         )
-        if not self.config.allow_padding and segment.shape[-1] > mel.shape[-1]:
+        if not ctx.allow_padding and segment.shape[-1] > mel.shape[-1]:
             logger.warning("Padding is not expected while speaking")

         logger.debug(

@@ -267,10 +267,10 @@ class WhisperStreamingTranscriber:
                 f"avg_logprob={result.avg_logprob:.2f}"
             )

-            if self.config.no_speech_threshold is not None:
-                if (result.no_speech_prob > self.config.no_speech_threshold) and not (
-                    self.config.logprob_threshold is not None
-                    and result.avg_logprob > self.config.logprob_threshold
+            if ctx.no_speech_threshold is not None:
+                if (result.no_speech_prob > ctx.no_speech_threshold) and not (
+                    ctx.logprob_threshold is not None
+                    and result.avg_logprob > ctx.logprob_threshold
                 ):
                     seek += segment.shape[-1]
                     logger.debug(

@@ -295,7 +295,7 @@ class WhisperStreamingTranscriber:
                 seek += last_timestamp_position * self.input_stride
             logger.debug(f"new seek={seek}, mel.shape: {mel.shape}")

-            if (not self.config.allow_padding) and (mel.shape[-1] - seek < N_FRAMES):
+            if (not ctx.allow_padding) and (mel.shape[-1] - seek < N_FRAMES):
                 break

             if mel.shape[-1] - seek <= 0:
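
The last two hunks put the streaming buffer rule under ctx.allow_padding: when padding is disallowed, decoding stops as soon as fewer than one full Whisper window (N_FRAMES, i.e. 3000 mel frames or 30 seconds) remains past seek, leaving the tail buffered for the next call. A standalone sketch of that rule:

def next_window(mel, seek: int, allow_padding: bool, n_frames: int = 3000):
    # n_frames mirrors whisper's N_FRAMES (one 30-second window).
    if mel.shape[-1] - seek <= 0:
        return None  # buffer exhausted
    if (not allow_padding) and (mel.shape[-1] - seek < n_frames):
        return None  # partial window: wait for more audio rather than pad
    return mel[..., seek : seek + n_frames]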