mirror of
https://github.com/shirayu/whispering.git
synced 2024-11-21 08:21:02 +00:00
Fix bugs
This commit is contained in:
parent
62b6d9a3b0
commit
40dd471f6b
2 changed files with 57 additions and 33 deletions
|
@ -1,15 +1,13 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
|
||||
import sys
|
||||
from unittest.mock import patch
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from whispering.cli import get_opts, is_valid_arg
|
||||
from whispering.cli import Mode, is_valid_arg
|
||||
|
||||
|
||||
class ArgExample(BaseModel):
|
||||
mode: Mode
|
||||
cmd: str
|
||||
ok: bool
|
||||
|
||||
|
@ -17,17 +15,28 @@ class ArgExample(BaseModel):
|
|||
def test_options():
|
||||
|
||||
exs = [
|
||||
ArgExample(cmd="--mode server --mic 0", ok=False),
|
||||
ArgExample(cmd="--mode server --mic 1", ok=False),
|
||||
ArgExample(cmd="--mode server --beam_size 3", ok=False),
|
||||
ArgExample(cmd="--mode server --temperature 0", ok=False),
|
||||
ArgExample(cmd="--mode server --num_block 3", ok=False),
|
||||
ArgExample(cmd="--mode mic --host 0.0.0.0", ok=False),
|
||||
ArgExample(cmd="--mode mic --port 8000", ok=False),
|
||||
ArgExample(mode=Mode.server, cmd="--mic 0", ok=False),
|
||||
ArgExample(mode=Mode.server, cmd="--mic 1", ok=False),
|
||||
ArgExample(
|
||||
mode=Mode.server,
|
||||
cmd="--host 0.0.0.0 --port 8000",
|
||||
ok=True,
|
||||
),
|
||||
ArgExample(
|
||||
mode=Mode.server,
|
||||
cmd="--language en --model tiny --host 0.0.0.0 --port 8000",
|
||||
ok=True,
|
||||
),
|
||||
ArgExample(mode=Mode.server, cmd="--beam_size 3", ok=False),
|
||||
ArgExample(mode=Mode.server, cmd="--temperature 0", ok=False),
|
||||
ArgExample(mode=Mode.server, cmd="--num_block 3", ok=False),
|
||||
ArgExample(mode=Mode.mic, cmd="--host 0.0.0.0", ok=False),
|
||||
ArgExample(mode=Mode.mic, cmd="--port 8000", ok=False),
|
||||
]
|
||||
|
||||
for ex in exs:
|
||||
with patch.object(sys, "argv", [""] + ex.cmd.split()):
|
||||
opts = get_opts()
|
||||
ok = is_valid_arg(opts)
|
||||
assert ok is ex.ok, f"{ex.cmd} should be invalid"
|
||||
ok = is_valid_arg(
|
||||
mode=ex.mode.value,
|
||||
args=ex.cmd.split(),
|
||||
)
|
||||
assert ok is ex.ok, f"{ex.cmd} should be {ex.ok}"
|
||||
|
|
|
@ -7,7 +7,7 @@ import sys
|
|||
from enum import Enum
|
||||
from logging import DEBUG, INFO, basicConfig, getLogger
|
||||
from pathlib import Path
|
||||
from typing import Iterator, Optional, Union
|
||||
from typing import Iterator, List, Optional, Union
|
||||
|
||||
import sounddevice as sd
|
||||
import torch
|
||||
|
@ -240,24 +240,36 @@ def show_devices():
|
|||
print(f"{i}: {device['name']}")
|
||||
|
||||
|
||||
def is_valid_arg(opts) -> bool:
|
||||
def is_valid_arg(
|
||||
*,
|
||||
args: List[str],
|
||||
mode: str,
|
||||
) -> bool:
|
||||
keys = []
|
||||
if opts.mode == Mode.server.value:
|
||||
keys = [
|
||||
"mic",
|
||||
"beam_size",
|
||||
"temperature",
|
||||
]
|
||||
elif opts.mode == Mode.mic.value:
|
||||
keys = [
|
||||
"host",
|
||||
"port",
|
||||
]
|
||||
if mode == Mode.server.value:
|
||||
keys = {
|
||||
"--mic",
|
||||
"--beam_size",
|
||||
"-b",
|
||||
"--temperature",
|
||||
"-t",
|
||||
"--num_block",
|
||||
"-n",
|
||||
"--vad",
|
||||
"--max_nospeech_skip",
|
||||
"--output",
|
||||
"--show-devices",
|
||||
"--no-progress",
|
||||
}
|
||||
elif mode == Mode.mic.value:
|
||||
keys = {
|
||||
"--host",
|
||||
"--port",
|
||||
}
|
||||
|
||||
for key in keys:
|
||||
_val = vars(opts).get(key)
|
||||
if _val is not None and _val is not False:
|
||||
sys.stderr.write(f"{key} is not accepted option for {opts.mode} mode\n")
|
||||
for arg in args:
|
||||
if arg in keys:
|
||||
sys.stderr.write(f"{arg} is not accepted option for {mode} mode\n")
|
||||
return False
|
||||
return True
|
||||
|
||||
|
@ -280,7 +292,10 @@ def main() -> None:
|
|||
):
|
||||
opts.mode = Mode.server.value
|
||||
|
||||
if not is_valid_arg(opts):
|
||||
if not is_valid_arg(
|
||||
args=sys.argv[1:],
|
||||
mode=opts.mode,
|
||||
):
|
||||
sys.exit(1)
|
||||
|
||||
if opts.mode == Mode.client.value:
|
||||
|
|
Loading…
Reference in a new issue