diff --git a/whispering/cli.py b/whispering/cli.py index 4183860..33e3e13 100644 --- a/whispering/cli.py +++ b/whispering/cli.py @@ -224,6 +224,19 @@ def show_devices(): print(f"{i}: {device['name']}") +def check_invalid_arg(opts): + ngs = [] + if opts.mode == Mode.server.value: + ngs = [ + "mic", + "allow_padding", + ] + for ng in ngs: + if vars(opts).get(ng) not in {None, False}: + sys.stderr.write(f"{ng} is not accepted option for {opts.mode} mode\n") + sys.exit(1) + + def main() -> None: opts = get_opts() @@ -242,13 +255,20 @@ def main() -> None: ): opts.mode = Mode.server.value + check_invalid_arg(opts) if opts.mode == Mode.client.value: assert opts.language is None assert opts.model is None + ctx: Context = get_context(opts=opts) try: asyncio.run( run_websocket_client( - opts=opts, + sd_device=opts.mic, + num_block=opts.num_block, + host=opts.host, + port=opts.port, + no_progress=opts.no_progress, + ctx=ctx, ) ) except KeyboardInterrupt: @@ -257,13 +277,11 @@ def main() -> None: assert opts.language is not None assert opts.model is not None wsp = get_wshiper(opts=opts) - ctx: Context = get_context(opts=opts) asyncio.run( serve_with_websocket( wsp=wsp, host=opts.host, port=opts.port, - ctx=ctx, ) ) else: diff --git a/whispering/serve.py b/whispering/serve.py index c0d2e16..1399a3c 100644 --- a/whispering/serve.py +++ b/whispering/serve.py @@ -1,7 +1,9 @@ #!/usr/bin/env python3 import asyncio +import json from logging import getLogger +from typing import Optional import numpy as np import websockets @@ -14,11 +16,8 @@ logger = getLogger(__name__) async def serve_with_websocket_main(websocket): global g_wsp - global g_ctx idx: int = 0 - ctx: Context = g_ctx.copy( - deep=True, - ) + ctx: Optional[Context] = None while True: logger.debug(f"Audio #: {idx}") @@ -33,6 +32,15 @@ async def serve_with_websocket_main(websocket): logger.debug(f"Message size: {len(message)}") audio = np.frombuffer(message, dtype=np.float32) + if ctx is None: + await websocket.send( + json.dumps( + { + "error": "no context", + } + ) + ) + return for chunk in g_wsp.transcribe( audio=audio, # type: ignore ctx=ctx, @@ -46,14 +54,11 @@ async def serve_with_websocket( wsp: WhisperStreamingTranscriber, host: str, port: int, - ctx: Context, ): logger.info(f"Serve at {host}:{port}") logger.info("Make secure with your responsibility!") global g_wsp - global g_ctx g_wsp = wsp - g_ctx = ctx try: async with websockets.serve( # type: ignore diff --git a/whispering/websocket_client.py b/whispering/websocket_client.py index 26540ac..f59de38 100644 --- a/whispering/websocket_client.py +++ b/whispering/websocket_client.py @@ -8,6 +8,7 @@ import websockets from whisper.audio import N_FRAMES, SAMPLE_RATE from whispering.schema import ParsedChunk +from whispering.transcriber import Context logger = getLogger(__name__) @@ -24,6 +25,7 @@ async def transcribe_from_mic_and_send( num_block: int, host: str, port: int, + ctx: Context, ) -> None: uri = f"ws://{host}:{port}" @@ -67,15 +69,24 @@ async def transcribe_from_mic_and_send( idx += 1 -async def run_websocket_client(*, opts) -> None: +async def run_websocket_client( + *, + sd_device: Optional[Union[int, str]], + num_block: int, + host: str, + port: int, + ctx: Context, + no_progress: bool, +) -> None: global q global loop loop = asyncio.get_running_loop() q = asyncio.Queue() await transcribe_from_mic_and_send( - sd_device=opts.mic, - num_block=opts.num_block, - host=opts.host, - port=opts.port, + sd_device=sd_device, + num_block=num_block, + host=host, + port=port, + ctx=ctx, )