Merge websocket client to CLI

This commit is contained in:
Yuta Hayashibe 2022-09-25 04:13:28 +09:00
parent b9e8ace9af
commit ee4e1bf29b
3 changed files with 44 additions and 66 deletions

View file

@ -54,16 +54,14 @@ poetry run whisper_streaming --language en --model tiny --host 0.0.0.0 --port 80
``` ```
You can set ``--allow-padding`` and other options. You can set ``--allow-padding`` and other options.
(``-n`` for hosts makes no sense)
### Client ### Client
```bash ```bash
poetry run python -m whisper_streaming.websocket_client --host ADDRESS_OF_HOST --port 8000 -n 20 poetry run whisper_streaming --model tiny --host ADDRESS_OF_HOST --port 8000 --mode client
``` ```
You can set ``-n`` and other options. You can set ``-n`` and other options.
(``--allow-padding`` for clients makes no sense)
## Tips ## Tips

View file

@ -15,6 +15,7 @@ from whisper.tokenizer import LANGUAGES, TO_LANGUAGE_CODE
from whisper_streaming.schema import WhisperConfig from whisper_streaming.schema import WhisperConfig
from whisper_streaming.serve import serve_with_websocket from whisper_streaming.serve import serve_with_websocket
from whisper_streaming.transcriber import WhisperStreamingTranscriber from whisper_streaming.transcriber import WhisperStreamingTranscriber
from whisper_streaming.websocket_client import run_websocket_client
logger = getLogger(__name__) logger = getLogger(__name__)
@ -58,13 +59,11 @@ def get_opts() -> argparse.Namespace:
default=None, default=None,
choices=sorted(LANGUAGES.keys()) choices=sorted(LANGUAGES.keys())
+ sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]), + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]),
required=True,
) )
parser.add_argument( parser.add_argument(
"--model", "--model",
type=str, type=str,
choices=available_models(), choices=available_models(),
required=True,
) )
parser.add_argument( parser.add_argument(
"--device", "--device",
@ -113,10 +112,29 @@ def get_opts() -> argparse.Namespace:
"--allow-padding", "--allow-padding",
action="store_true", action="store_true",
) )
parser.add_argument(
"--mode",
choices=["client"],
)
return parser.parse_args() return parser.parse_args()
def get_wshiper(*, opts):
config = WhisperConfig(
model_name=opts.model,
language=opts.language,
device=opts.device,
beam_size=opts.beam_size,
temperatures=opts.temperature,
allow_padding=opts.allow_padding,
)
logger.debug(f"WhisperConfig: {config}")
wsp = WhisperStreamingTranscriber(config=config)
return wsp
def main() -> None: def main() -> None:
opts = get_opts() opts = get_opts()
basicConfig( basicConfig(
@ -135,18 +153,19 @@ def main() -> None:
except Exception: except Exception:
pass pass
config = WhisperConfig(
model_name=opts.model,
language=opts.language,
device=opts.device,
beam_size=opts.beam_size,
temperatures=opts.temperature,
allow_padding=opts.allow_padding,
)
logger.debug(f"WhisperConfig: {config}")
wsp = WhisperStreamingTranscriber(config=config)
if opts.host is not None and opts.port is not None: if opts.host is not None and opts.port is not None:
if opts.mode == "client":
assert opts.language is None
assert opts.model is None
asyncio.run(
run_websocket_client(
opts=opts,
)
)
else:
assert opts.language is not None
assert opts.model is not None
wsp = get_wshiper(opts=opts)
asyncio.run( asyncio.run(
serve_with_websocket( serve_with_websocket(
wsp=wsp, wsp=wsp,
@ -155,6 +174,9 @@ def main() -> None:
) )
) )
else: else:
assert opts.language is not None
assert opts.model is not None
wsp = get_wshiper(opts=opts)
transcribe_from_mic( transcribe_from_mic(
wsp=wsp, wsp=wsp,
sd_device=opts.mic, sd_device=opts.mic,

View file

@ -1,7 +1,6 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import argparse
import asyncio import asyncio
from logging import DEBUG, INFO, basicConfig, getLogger from logging import getLogger
from typing import Optional, Union from typing import Optional, Union
import sounddevice as sd import sounddevice as sd
@ -68,44 +67,7 @@ async def transcribe_from_mic_and_send(
idx += 1 idx += 1
def get_opts() -> argparse.Namespace: async def run_websocket_client(*, opts) -> None:
parser = argparse.ArgumentParser()
parser.add_argument(
"--host",
required=True,
help="host of websocker server",
)
parser.add_argument(
"--port",
type=int,
required=True,
help="Port number of websocker server",
)
parser.add_argument(
"--mic",
)
parser.add_argument(
"--num_block",
"-n",
type=int,
default=160,
help="Number of operation unit",
)
parser.add_argument(
"--debug",
action="store_true",
)
return parser.parse_args()
async def main() -> None:
opts = get_opts()
basicConfig(
level=DEBUG if opts.debug else INFO,
format="[%(asctime)s] %(module)s.%(funcName)s:%(lineno)d %(levelname)s -> %(message)s",
)
global q global q
global loop global loop
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
@ -117,7 +79,3 @@ async def main() -> None:
host=opts.host, host=opts.host,
port=opts.port, port=opts.port,
) )
if __name__ == "__main__":
asyncio.run(main())