diff --git a/tests/test_cli.py b/tests/test_cli.py index d0b40e3..5504845 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -4,23 +4,30 @@ import sys from unittest.mock import patch +from pydantic import BaseModel + from whispering.cli import get_opts, is_valid_arg +class ArgExample(BaseModel): + cmd: str + ok: bool + + def test_options(): - invalid_args = [ - "--mode server --mic 0", - "--mode server --mic 1", - "--mode server --beam_size 3", - "--mode server --temperature 0", - "--mode server --num_block 3", - "--mode mic --host 0.0.0.0", - "--mode mic --port 8000", + exs = [ + ArgExample(cmd="--mode server --mic 0", ok=False), + ArgExample(cmd="--mode server --mic 1", ok=False), + ArgExample(cmd="--mode server --beam_size 3", ok=False), + ArgExample(cmd="--mode server --temperature 0", ok=False), + ArgExample(cmd="--mode server --num_block 3", ok=False), + ArgExample(cmd="--mode mic --host 0.0.0.0", ok=False), + ArgExample(cmd="--mode mic --port 8000", ok=False), ] - for invalid_arg in invalid_args: - with patch.object(sys, "argv", [""] + invalid_arg.split()): + for ex in exs: + with patch.object(sys, "argv", [""] + ex.cmd.split()): opts = get_opts() ok = is_valid_arg(opts) - assert ok is False, f"{invalid_arg} should be invalid" + assert ok is ex.ok, f"{ex.cmd} should be invalid"