From 40dd471f6b2f4b6b212aff5c7be7afec4ff4b3e5 Mon Sep 17 00:00:00 2001 From: Yuta Hayashibe Date: Mon, 17 Oct 2022 21:29:12 +0900 Subject: [PATCH] Fix bugs --- tests/test_cli.py | 39 ++++++++++++++++++++++-------------- whispering/cli.py | 51 ++++++++++++++++++++++++++++++----------------- 2 files changed, 57 insertions(+), 33 deletions(-) diff --git a/tests/test_cli.py b/tests/test_cli.py index 5504845..a4f3ac7 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,15 +1,13 @@ #!/usr/bin/env python3 -import sys -from unittest.mock import patch - from pydantic import BaseModel -from whispering.cli import get_opts, is_valid_arg +from whispering.cli import Mode, is_valid_arg class ArgExample(BaseModel): + mode: Mode cmd: str ok: bool @@ -17,17 +15,28 @@ class ArgExample(BaseModel): def test_options(): exs = [ - ArgExample(cmd="--mode server --mic 0", ok=False), - ArgExample(cmd="--mode server --mic 1", ok=False), - ArgExample(cmd="--mode server --beam_size 3", ok=False), - ArgExample(cmd="--mode server --temperature 0", ok=False), - ArgExample(cmd="--mode server --num_block 3", ok=False), - ArgExample(cmd="--mode mic --host 0.0.0.0", ok=False), - ArgExample(cmd="--mode mic --port 8000", ok=False), + ArgExample(mode=Mode.server, cmd="--mic 0", ok=False), + ArgExample(mode=Mode.server, cmd="--mic 1", ok=False), + ArgExample( + mode=Mode.server, + cmd="--host 0.0.0.0 --port 8000", + ok=True, + ), + ArgExample( + mode=Mode.server, + cmd="--language en --model tiny --host 0.0.0.0 --port 8000", + ok=True, + ), + ArgExample(mode=Mode.server, cmd="--beam_size 3", ok=False), + ArgExample(mode=Mode.server, cmd="--temperature 0", ok=False), + ArgExample(mode=Mode.server, cmd="--num_block 3", ok=False), + ArgExample(mode=Mode.mic, cmd="--host 0.0.0.0", ok=False), + ArgExample(mode=Mode.mic, cmd="--port 8000", ok=False), ] for ex in exs: - with patch.object(sys, "argv", [""] + ex.cmd.split()): - opts = get_opts() - ok = is_valid_arg(opts) - assert ok is ex.ok, f"{ex.cmd} should be invalid" + ok = is_valid_arg( + mode=ex.mode.value, + args=ex.cmd.split(), + ) + assert ok is ex.ok, f"{ex.cmd} should be {ex.ok}" diff --git a/whispering/cli.py b/whispering/cli.py index 106cabe..63bb070 100644 --- a/whispering/cli.py +++ b/whispering/cli.py @@ -7,7 +7,7 @@ import sys from enum import Enum from logging import DEBUG, INFO, basicConfig, getLogger from pathlib import Path -from typing import Iterator, Optional, Union +from typing import Iterator, List, Optional, Union import sounddevice as sd import torch @@ -240,24 +240,36 @@ def show_devices(): print(f"{i}: {device['name']}") -def is_valid_arg(opts) -> bool: +def is_valid_arg( + *, + args: List[str], + mode: str, +) -> bool: keys = [] - if opts.mode == Mode.server.value: - keys = [ - "mic", - "beam_size", - "temperature", - ] - elif opts.mode == Mode.mic.value: - keys = [ - "host", - "port", - ] + if mode == Mode.server.value: + keys = { + "--mic", + "--beam_size", + "-b", + "--temperature", + "-t", + "--num_block", + "-n", + "--vad", + "--max_nospeech_skip", + "--output", + "--show-devices", + "--no-progress", + } + elif mode == Mode.mic.value: + keys = { + "--host", + "--port", + } - for key in keys: - _val = vars(opts).get(key) - if _val is not None and _val is not False: - sys.stderr.write(f"{key} is not accepted option for {opts.mode} mode\n") + for arg in args: + if arg in keys: + sys.stderr.write(f"{arg} is not accepted option for {mode} mode\n") return False return True @@ -280,7 +292,10 @@ def main() -> None: ): opts.mode = Mode.server.value - if not is_valid_arg(opts): + if not is_valid_arg( + args=sys.argv[1:], + mode=opts.mode, + ): sys.exit(1) if opts.mode == Mode.client.value: