This commit is contained in:
Yuta Hayashibe 2022-10-17 21:29:12 +09:00
parent 62b6d9a3b0
commit 40dd471f6b
2 changed files with 57 additions and 33 deletions

View file

@ -1,15 +1,13 @@
#!/usr/bin/env python3
import sys
from unittest.mock import patch
from pydantic import BaseModel
from whispering.cli import get_opts, is_valid_arg
from whispering.cli import Mode, is_valid_arg
class ArgExample(BaseModel):
mode: Mode
cmd: str
ok: bool
@ -17,17 +15,28 @@ class ArgExample(BaseModel):
def test_options():
exs = [
ArgExample(cmd="--mode server --mic 0", ok=False),
ArgExample(cmd="--mode server --mic 1", ok=False),
ArgExample(cmd="--mode server --beam_size 3", ok=False),
ArgExample(cmd="--mode server --temperature 0", ok=False),
ArgExample(cmd="--mode server --num_block 3", ok=False),
ArgExample(cmd="--mode mic --host 0.0.0.0", ok=False),
ArgExample(cmd="--mode mic --port 8000", ok=False),
ArgExample(mode=Mode.server, cmd="--mic 0", ok=False),
ArgExample(mode=Mode.server, cmd="--mic 1", ok=False),
ArgExample(
mode=Mode.server,
cmd="--host 0.0.0.0 --port 8000",
ok=True,
),
ArgExample(
mode=Mode.server,
cmd="--language en --model tiny --host 0.0.0.0 --port 8000",
ok=True,
),
ArgExample(mode=Mode.server, cmd="--beam_size 3", ok=False),
ArgExample(mode=Mode.server, cmd="--temperature 0", ok=False),
ArgExample(mode=Mode.server, cmd="--num_block 3", ok=False),
ArgExample(mode=Mode.mic, cmd="--host 0.0.0.0", ok=False),
ArgExample(mode=Mode.mic, cmd="--port 8000", ok=False),
]
for ex in exs:
with patch.object(sys, "argv", [""] + ex.cmd.split()):
opts = get_opts()
ok = is_valid_arg(opts)
assert ok is ex.ok, f"{ex.cmd} should be invalid"
ok = is_valid_arg(
mode=ex.mode.value,
args=ex.cmd.split(),
)
assert ok is ex.ok, f"{ex.cmd} should be {ex.ok}"

View file

@ -7,7 +7,7 @@ import sys
from enum import Enum
from logging import DEBUG, INFO, basicConfig, getLogger
from pathlib import Path
from typing import Iterator, Optional, Union
from typing import Iterator, List, Optional, Union
import sounddevice as sd
import torch
@ -240,24 +240,36 @@ def show_devices():
print(f"{i}: {device['name']}")
def is_valid_arg(opts) -> bool:
def is_valid_arg(
*,
args: List[str],
mode: str,
) -> bool:
keys = []
if opts.mode == Mode.server.value:
keys = [
"mic",
"beam_size",
"temperature",
]
elif opts.mode == Mode.mic.value:
keys = [
"host",
"port",
]
if mode == Mode.server.value:
keys = {
"--mic",
"--beam_size",
"-b",
"--temperature",
"-t",
"--num_block",
"-n",
"--vad",
"--max_nospeech_skip",
"--output",
"--show-devices",
"--no-progress",
}
elif mode == Mode.mic.value:
keys = {
"--host",
"--port",
}
for key in keys:
_val = vars(opts).get(key)
if _val is not None and _val is not False:
sys.stderr.write(f"{key} is not accepted option for {opts.mode} mode\n")
for arg in args:
if arg in keys:
sys.stderr.write(f"{arg} is not accepted option for {mode} mode\n")
return False
return True
@ -280,7 +292,10 @@ def main() -> None:
):
opts.mode = Mode.server.value
if not is_valid_arg(opts):
if not is_valid_arg(
args=sys.argv[1:],
mode=opts.mode,
):
sys.exit(1)
if opts.mode == Mode.client.value: