Updated Context

This commit is contained in:
Yuta Hayashibe 2022-09-29 20:43:49 +09:00
parent 5abd0db07a
commit 88cda28b83
4 changed files with 51 additions and 36 deletions

View file

@ -25,6 +25,7 @@ def transcribe_from_mic(
wsp: WhisperStreamingTranscriber,
sd_device: Optional[Union[int, str]],
num_block: int,
ctx: Context,
) -> None:
q = queue.Queue()
@ -33,7 +34,6 @@ def transcribe_from_mic(
logger.warning(status)
q.put(indata.ravel())
ctx: Context = Context()
logger.info("Ready to transcribe")
with sd.InputStream(
samplerate=SAMPLE_RATE,
@ -136,14 +136,11 @@ def get_opts() -> argparse.Namespace:
return opts
def get_wshiper(*, opts):
def get_wshiper(*, opts) -> WhisperStreamingTranscriber:
config = WhisperConfig(
model_name=opts.model,
language=opts.language,
device=opts.device,
beam_size=opts.beam_size,
temperatures=opts.temperature,
allow_padding=opts.allow_padding,
)
logger.debug(f"WhisperConfig: {config}")
@ -151,6 +148,16 @@ def get_wshiper(*, opts):
return wsp
def get_context(*, opts) -> Context:
    """Build the per-session decoding ``Context`` from parsed CLI options.

    Pulls the decoding-related knobs (beam size, temperature schedule,
    padding policy) out of the argparse namespace and packs them into a
    fresh ``Context`` model.
    """
    params = {
        "beam_size": opts.beam_size,
        "temperatures": opts.temperature,
        "allow_padding": opts.allow_padding,
    }
    context = Context(**params)
    logger.debug(f"Context: {context}")
    return context
def show_devices():
devices = sd.query_devices()
for i, device in enumerate(devices):
@ -182,21 +189,25 @@ def main() -> None:
assert opts.language is not None
assert opts.model is not None
wsp = get_wshiper(opts=opts)
ctx: Context = get_context(opts=opts)
asyncio.run(
serve_with_websocket(
wsp=wsp,
host=opts.host,
port=opts.port,
ctx=ctx,
)
)
else:
assert opts.language is not None
assert opts.model is not None
wsp = get_wshiper(opts=opts)
ctx: Context = get_context(opts=opts)
transcribe_from_mic(
wsp=wsp,
sd_device=opts.mic,
num_block=opts.num_block,
ctx=ctx,
)

View file

@ -10,19 +10,7 @@ class WhisperConfig(BaseModel):
model_name: str
device: str
language: str
allow_padding: bool = False
temperatures: List[float]
fp16: bool = True
compression_ratio_threshold: Optional[float] = 2.4
logprob_threshold: Optional[float] = -1.0
no_captions_threshold: Optional[float] = 0.6
best_of: int = 5
beam_size: Optional[int] = None
no_speech_threshold: Optional[float] = 0.6
logprob_threshold: Optional[float] = -1.0
compression_ratio_threshold: Optional[float] = 2.4
buffer_threshold: Optional[float] = 0.5
@root_validator
def validate_model_name(cls, values):
@ -39,6 +27,19 @@ class Context(BaseModel, arbitrary_types_allowed=True):
buffer_tokens: List[torch.Tensor] = []
buffer_mel: Optional[torch.Tensor] = None
temperatures: List[float]
allow_padding: bool = False
patience: Optional[float] = None
compression_ratio_threshold: Optional[float] = 2.4
logprob_threshold: Optional[float] = -1.0
no_captions_threshold: Optional[float] = 0.6
best_of: int = 5
beam_size: Optional[int] = None
no_speech_threshold: Optional[float] = 0.6
logprob_threshold: Optional[float] = -1.0
compression_ratio_threshold: Optional[float] = 2.4
buffer_threshold: Optional[float] = 0.5
class ParsedChunk(BaseModel):
start: float

View file

@ -13,8 +13,8 @@ logger = getLogger(__name__)
async def serve_with_websocket_main(websocket):
global g_wsp
global g_ctx
idx: int = 0
ctx: Context = Context()
while True:
logger.debug(f"Segment #: {idx}")
@ -26,7 +26,7 @@ async def serve_with_websocket_main(websocket):
logger.debug(f"Message size: {len(message)}")
segment = np.frombuffer(message, dtype=np.float32)
for chunk in g_wsp.transcribe(segment=segment, ctx=ctx):
for chunk in g_wsp.transcribe(segment=segment, ctx=g_ctx):
await websocket.send(chunk.json())
idx += 1
@ -36,11 +36,14 @@ async def serve_with_websocket(
wsp: WhisperStreamingTranscriber,
host: str,
port: int,
ctx: Context,
):
logger.info(f"Serve at {host}:{port}")
logger.info("Make secure with your responsibility!")
global g_wsp
global g_ctx
g_wsp = wsp
g_ctx = ctx
try:
async with websockets.serve( # type: ignore

View file

@ -85,25 +85,25 @@ class WhisperStreamingTranscriber:
segment: np.ndarray,
ctx: Context,
) -> List[DecodingResult]:
assert len(self.config.temperatures) >= 1
t = self.config.temperatures[0]
assert len(ctx.temperatures) >= 1
t = ctx.temperatures[0]
logger.debug(f"temperature: {t}")
_decode_options1: DecodingOptions = self._get_decoding_options(
t=t,
prompt=ctx.buffer_tokens,
beam_size=self.config.beam_size,
beam_size=ctx.beam_size,
patience=None,
best_of=None,
)
results: List[DecodingResult] = self.model.decode(segment, _decode_options1) # type: ignore
for t in self.config.temperatures[1:]:
for t in ctx.temperatures[1:]:
needs_fallback = [
self.config.compression_ratio_threshold is not None
and result.compression_ratio > self.config.compression_ratio_threshold
or self.config.logprob_threshold is not None
and result.avg_logprob < self.config.logprob_threshold
ctx.compression_ratio_threshold is not None
and result.compression_ratio > ctx.compression_ratio_threshold
or ctx.logprob_threshold is not None
and result.avg_logprob < ctx.logprob_threshold
for result in results
]
if any(needs_fallback):
@ -114,8 +114,8 @@ class WhisperStreamingTranscriber:
t=t,
prompt=ctx.buffer_tokens,
beam_size=None,
patience=0.0,
best_of=self.config.best_of,
patience=ctx.patience,
best_of=ctx.best_of,
)
retries: List[DecodingResult] = self.model.decode(
segment[needs_fallback], _decode_options2 # type: ignore
@ -221,7 +221,7 @@ class WhisperStreamingTranscriber:
yield chunk
ctx.timestamp += duration
if result.temperature > self.config.buffer_threshold:
if result.temperature > ctx.buffer_threshold:
# do not feed the prompt tokens if a high temperature was used
del ctx.buffer_tokens
ctx.buffer_tokens = []
@ -250,7 +250,7 @@ class WhisperStreamingTranscriber:
.to(self.model.device) # type: ignore
.to(self.dtype)
)
if not self.config.allow_padding and segment.shape[-1] > mel.shape[-1]:
if not ctx.allow_padding and segment.shape[-1] > mel.shape[-1]:
logger.warning("Padding is not expected while speaking")
logger.debug(
@ -267,10 +267,10 @@ class WhisperStreamingTranscriber:
f"avg_logprob={result.avg_logprob:.2f}"
)
if self.config.no_speech_threshold is not None:
if (result.no_speech_prob > self.config.no_speech_threshold) and not (
self.config.logprob_threshold is not None
and result.avg_logprob > self.config.logprob_threshold
if ctx.no_speech_threshold is not None:
if (result.no_speech_prob > ctx.no_speech_threshold) and not (
ctx.logprob_threshold is not None
and result.avg_logprob > ctx.logprob_threshold
):
seek += segment.shape[-1]
logger.debug(
@ -295,7 +295,7 @@ class WhisperStreamingTranscriber:
seek += last_timestamp_position * self.input_stride
logger.debug(f"new seek={seek}, mel.shape: {mel.shape}")
if (not self.config.allow_padding) and (mel.shape[-1] - seek < N_FRAMES):
if (not ctx.allow_padding) and (mel.shape[-1] - seek < N_FRAMES):
break
if mel.shape[-1] - seek <= 0: