This commit is contained in:
Yuta Hayashibe 2022-10-17 21:29:12 +09:00
parent 62b6d9a3b0
commit 40dd471f6b
2 changed files with 57 additions and 33 deletions

View file

@ -1,15 +1,13 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import sys
from unittest.mock import patch
from pydantic import BaseModel from pydantic import BaseModel
from whispering.cli import get_opts, is_valid_arg from whispering.cli import Mode, is_valid_arg
class ArgExample(BaseModel): class ArgExample(BaseModel):
mode: Mode
cmd: str cmd: str
ok: bool ok: bool
@ -17,17 +15,28 @@ class ArgExample(BaseModel):
def test_options(): def test_options():
exs = [ exs = [
ArgExample(cmd="--mode server --mic 0", ok=False), ArgExample(mode=Mode.server, cmd="--mic 0", ok=False),
ArgExample(cmd="--mode server --mic 1", ok=False), ArgExample(mode=Mode.server, cmd="--mic 1", ok=False),
ArgExample(cmd="--mode server --beam_size 3", ok=False), ArgExample(
ArgExample(cmd="--mode server --temperature 0", ok=False), mode=Mode.server,
ArgExample(cmd="--mode server --num_block 3", ok=False), cmd="--host 0.0.0.0 --port 8000",
ArgExample(cmd="--mode mic --host 0.0.0.0", ok=False), ok=True,
ArgExample(cmd="--mode mic --port 8000", ok=False), ),
ArgExample(
mode=Mode.server,
cmd="--language en --model tiny --host 0.0.0.0 --port 8000",
ok=True,
),
ArgExample(mode=Mode.server, cmd="--beam_size 3", ok=False),
ArgExample(mode=Mode.server, cmd="--temperature 0", ok=False),
ArgExample(mode=Mode.server, cmd="--num_block 3", ok=False),
ArgExample(mode=Mode.mic, cmd="--host 0.0.0.0", ok=False),
ArgExample(mode=Mode.mic, cmd="--port 8000", ok=False),
] ]
for ex in exs: for ex in exs:
with patch.object(sys, "argv", [""] + ex.cmd.split()): ok = is_valid_arg(
opts = get_opts() mode=ex.mode.value,
ok = is_valid_arg(opts) args=ex.cmd.split(),
assert ok is ex.ok, f"{ex.cmd} should be invalid" )
assert ok is ex.ok, f"{ex.cmd} should be {ex.ok}"

View file

@ -7,7 +7,7 @@ import sys
from enum import Enum from enum import Enum
from logging import DEBUG, INFO, basicConfig, getLogger from logging import DEBUG, INFO, basicConfig, getLogger
from pathlib import Path from pathlib import Path
from typing import Iterator, Optional, Union from typing import Iterator, List, Optional, Union
import sounddevice as sd import sounddevice as sd
import torch import torch
@ -240,24 +240,36 @@ def show_devices():
print(f"{i}: {device['name']}") print(f"{i}: {device['name']}")
def is_valid_arg(opts) -> bool: def is_valid_arg(
*,
args: List[str],
mode: str,
) -> bool:
keys = [] keys = []
if opts.mode == Mode.server.value: if mode == Mode.server.value:
keys = [ keys = {
"mic", "--mic",
"beam_size", "--beam_size",
"temperature", "-b",
] "--temperature",
elif opts.mode == Mode.mic.value: "-t",
keys = [ "--num_block",
"host", "-n",
"port", "--vad",
] "--max_nospeech_skip",
"--output",
"--show-devices",
"--no-progress",
}
elif mode == Mode.mic.value:
keys = {
"--host",
"--port",
}
for key in keys: for arg in args:
_val = vars(opts).get(key) if arg in keys:
if _val is not None and _val is not False: sys.stderr.write(f"{arg} is not accepted option for {mode} mode\n")
sys.stderr.write(f"{key} is not accepted option for {opts.mode} mode\n")
return False return False
return True return True
@ -280,7 +292,10 @@ def main() -> None:
): ):
opts.mode = Mode.server.value opts.mode = Mode.server.value
if not is_valid_arg(opts): if not is_valid_arg(
args=sys.argv[1:],
mode=opts.mode,
):
sys.exit(1) sys.exit(1)
if opts.mode == Mode.client.value: if opts.mode == Mode.client.value: