mirror of
https://github.com/shirayu/whispering.git
synced 2024-11-22 00:41:02 +00:00
Fix bugs
This commit is contained in:
parent
62b6d9a3b0
commit
40dd471f6b
2 changed files with 57 additions and 33 deletions
|
@ -1,15 +1,13 @@
|
||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
|
|
||||||
import sys
|
|
||||||
from unittest.mock import patch
|
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from whispering.cli import get_opts, is_valid_arg
|
from whispering.cli import Mode, is_valid_arg
|
||||||
|
|
||||||
|
|
||||||
class ArgExample(BaseModel):
|
class ArgExample(BaseModel):
|
||||||
|
mode: Mode
|
||||||
cmd: str
|
cmd: str
|
||||||
ok: bool
|
ok: bool
|
||||||
|
|
||||||
|
@ -17,17 +15,28 @@ class ArgExample(BaseModel):
|
||||||
def test_options():
|
def test_options():
|
||||||
|
|
||||||
exs = [
|
exs = [
|
||||||
ArgExample(cmd="--mode server --mic 0", ok=False),
|
ArgExample(mode=Mode.server, cmd="--mic 0", ok=False),
|
||||||
ArgExample(cmd="--mode server --mic 1", ok=False),
|
ArgExample(mode=Mode.server, cmd="--mic 1", ok=False),
|
||||||
ArgExample(cmd="--mode server --beam_size 3", ok=False),
|
ArgExample(
|
||||||
ArgExample(cmd="--mode server --temperature 0", ok=False),
|
mode=Mode.server,
|
||||||
ArgExample(cmd="--mode server --num_block 3", ok=False),
|
cmd="--host 0.0.0.0 --port 8000",
|
||||||
ArgExample(cmd="--mode mic --host 0.0.0.0", ok=False),
|
ok=True,
|
||||||
ArgExample(cmd="--mode mic --port 8000", ok=False),
|
),
|
||||||
|
ArgExample(
|
||||||
|
mode=Mode.server,
|
||||||
|
cmd="--language en --model tiny --host 0.0.0.0 --port 8000",
|
||||||
|
ok=True,
|
||||||
|
),
|
||||||
|
ArgExample(mode=Mode.server, cmd="--beam_size 3", ok=False),
|
||||||
|
ArgExample(mode=Mode.server, cmd="--temperature 0", ok=False),
|
||||||
|
ArgExample(mode=Mode.server, cmd="--num_block 3", ok=False),
|
||||||
|
ArgExample(mode=Mode.mic, cmd="--host 0.0.0.0", ok=False),
|
||||||
|
ArgExample(mode=Mode.mic, cmd="--port 8000", ok=False),
|
||||||
]
|
]
|
||||||
|
|
||||||
for ex in exs:
|
for ex in exs:
|
||||||
with patch.object(sys, "argv", [""] + ex.cmd.split()):
|
ok = is_valid_arg(
|
||||||
opts = get_opts()
|
mode=ex.mode.value,
|
||||||
ok = is_valid_arg(opts)
|
args=ex.cmd.split(),
|
||||||
assert ok is ex.ok, f"{ex.cmd} should be invalid"
|
)
|
||||||
|
assert ok is ex.ok, f"{ex.cmd} should be {ex.ok}"
|
||||||
|
|
|
@ -7,7 +7,7 @@ import sys
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from logging import DEBUG, INFO, basicConfig, getLogger
|
from logging import DEBUG, INFO, basicConfig, getLogger
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Iterator, Optional, Union
|
from typing import Iterator, List, Optional, Union
|
||||||
|
|
||||||
import sounddevice as sd
|
import sounddevice as sd
|
||||||
import torch
|
import torch
|
||||||
|
@ -240,24 +240,36 @@ def show_devices():
|
||||||
print(f"{i}: {device['name']}")
|
print(f"{i}: {device['name']}")
|
||||||
|
|
||||||
|
|
||||||
def is_valid_arg(opts) -> bool:
|
def is_valid_arg(
|
||||||
|
*,
|
||||||
|
args: List[str],
|
||||||
|
mode: str,
|
||||||
|
) -> bool:
|
||||||
keys = []
|
keys = []
|
||||||
if opts.mode == Mode.server.value:
|
if mode == Mode.server.value:
|
||||||
keys = [
|
keys = {
|
||||||
"mic",
|
"--mic",
|
||||||
"beam_size",
|
"--beam_size",
|
||||||
"temperature",
|
"-b",
|
||||||
]
|
"--temperature",
|
||||||
elif opts.mode == Mode.mic.value:
|
"-t",
|
||||||
keys = [
|
"--num_block",
|
||||||
"host",
|
"-n",
|
||||||
"port",
|
"--vad",
|
||||||
]
|
"--max_nospeech_skip",
|
||||||
|
"--output",
|
||||||
|
"--show-devices",
|
||||||
|
"--no-progress",
|
||||||
|
}
|
||||||
|
elif mode == Mode.mic.value:
|
||||||
|
keys = {
|
||||||
|
"--host",
|
||||||
|
"--port",
|
||||||
|
}
|
||||||
|
|
||||||
for key in keys:
|
for arg in args:
|
||||||
_val = vars(opts).get(key)
|
if arg in keys:
|
||||||
if _val is not None and _val is not False:
|
sys.stderr.write(f"{arg} is not accepted option for {mode} mode\n")
|
||||||
sys.stderr.write(f"{key} is not accepted option for {opts.mode} mode\n")
|
|
||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
@ -280,7 +292,10 @@ def main() -> None:
|
||||||
):
|
):
|
||||||
opts.mode = Mode.server.value
|
opts.mode = Mode.server.value
|
||||||
|
|
||||||
if not is_valid_arg(opts):
|
if not is_valid_arg(
|
||||||
|
args=sys.argv[1:],
|
||||||
|
mode=opts.mode,
|
||||||
|
):
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
if opts.mode == Mode.client.value:
|
if opts.mode == Mode.client.value:
|
||||||
|
|
Loading…
Reference in a new issue