diff --git a/README.md b/README.md index 22ae88f..60b5748 100644 --- a/README.md +++ b/README.md @@ -54,16 +54,14 @@ poetry run whisper_streaming --language en --model tiny --host 0.0.0.0 --port 80 ``` You can set ``--allow-padding`` and other options. -(``-n`` for hosts makes no sense) ### Client ```bash -poetry run python -m whisper_streaming.websocket_client --host ADDRESS_OF_HOST --port 8000 -n 20 +poetry run whisper_streaming --model tiny --host ADDRESS_OF_HOST --port 8000 --mode client ``` You can set ``-n`` and other options. -(``--allow-padding`` for clients makes no sense) ## Tips diff --git a/whisper_streaming/cli.py b/whisper_streaming/cli.py index e4ad406..5af0db2 100644 --- a/whisper_streaming/cli.py +++ b/whisper_streaming/cli.py @@ -15,6 +15,7 @@ from whisper.tokenizer import LANGUAGES, TO_LANGUAGE_CODE from whisper_streaming.schema import WhisperConfig from whisper_streaming.serve import serve_with_websocket from whisper_streaming.transcriber import WhisperStreamingTranscriber +from whisper_streaming.websocket_client import run_websocket_client logger = getLogger(__name__) @@ -58,13 +59,11 @@ def get_opts() -> argparse.Namespace: default=None, choices=sorted(LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]), - required=True, ) parser.add_argument( "--model", type=str, choices=available_models(), - required=True, ) parser.add_argument( "--device", @@ -113,10 +112,29 @@ def get_opts() -> argparse.Namespace: "--allow-padding", action="store_true", ) + parser.add_argument( + "--mode", + choices=["client"], + ) return parser.parse_args() +def get_wshiper(*, opts): + config = WhisperConfig( + model_name=opts.model, + language=opts.language, + device=opts.device, + beam_size=opts.beam_size, + temperatures=opts.temperature, + allow_padding=opts.allow_padding, + ) + + logger.debug(f"WhisperConfig: {config}") + wsp = WhisperStreamingTranscriber(config=config) + return wsp + + def main() -> None: opts = get_opts() basicConfig( @@ -135,26 +153,30 @@ def main() -> None: except Exception: pass - config = WhisperConfig( - model_name=opts.model, - language=opts.language, - device=opts.device, - beam_size=opts.beam_size, - temperatures=opts.temperature, - allow_padding=opts.allow_padding, - ) - - logger.debug(f"WhisperConfig: {config}") - wsp = WhisperStreamingTranscriber(config=config) if opts.host is not None and opts.port is not None: - asyncio.run( - serve_with_websocket( - wsp=wsp, - host=opts.host, - port=opts.port, + if opts.mode == "client": + assert opts.language is None + assert opts.model is None + asyncio.run( + run_websocket_client( + opts=opts, + ) + ) + else: + assert opts.language is not None + assert opts.model is not None + wsp = get_wshiper(opts=opts) + asyncio.run( + serve_with_websocket( + wsp=wsp, + host=opts.host, + port=opts.port, + ) ) - ) else: + assert opts.language is not None + assert opts.model is not None + wsp = get_wshiper(opts=opts) transcribe_from_mic( wsp=wsp, sd_device=opts.mic, diff --git a/whisper_streaming/websocket_client.py b/whisper_streaming/websocket_client.py index 4857c61..2eac7f8 100644 --- a/whisper_streaming/websocket_client.py +++ b/whisper_streaming/websocket_client.py @@ -1,7 +1,6 @@ #!/usr/bin/env python3 -import argparse import asyncio -from logging import DEBUG, INFO, basicConfig, getLogger +from logging import getLogger from typing import Optional, Union import sounddevice as sd @@ -68,44 +67,7 @@ async def transcribe_from_mic_and_send( idx += 1 -def get_opts() -> argparse.Namespace: - parser = argparse.ArgumentParser() - parser.add_argument( - "--host", - required=True, - help="host of websocker server", - ) - parser.add_argument( - "--port", - type=int, - required=True, - help="Port number of websocker server", - ) - - parser.add_argument( - "--mic", - ) - parser.add_argument( - "--num_block", - "-n", - type=int, - default=160, - help="Number of operation unit", - ) - parser.add_argument( - "--debug", - action="store_true", - ) - - return parser.parse_args() - - -async def main() -> None: - opts = get_opts() - basicConfig( - level=DEBUG if opts.debug else INFO, - format="[%(asctime)s] %(module)s.%(funcName)s:%(lineno)d %(levelname)s -> %(message)s", - ) +async def run_websocket_client(*, opts) -> None: global q global loop loop = asyncio.get_running_loop() @@ -117,7 +79,3 @@ async def main() -> None: host=opts.host, port=opts.port, ) - - -if __name__ == "__main__": - asyncio.run(main())