diff --git a/whispering/cli.py b/whispering/cli.py index 7b68f14..8c1641a 100644 --- a/whispering/cli.py +++ b/whispering/cli.py @@ -4,6 +4,7 @@ import argparse import asyncio import queue import sys +from enum import Enum from logging import DEBUG, INFO, basicConfig, getLogger from typing import Optional, Union @@ -22,6 +23,15 @@ from whispering.websocket_client import run_websocket_client logger = getLogger(__name__) +class Mode(Enum): + client = "client" + server = "server" + mic = "mic" + + def __str__(self): + return self.value + + def transcribe_from_mic( *, wsp: WhisperStreamingTranscriber, @@ -151,7 +161,7 @@ def get_opts() -> argparse.Namespace: ) group_ctx.add_argument( "--mode", - choices=["client"], + choices=[v.value for v in Mode], ) group_misc = parser.add_argument_group("Other options") @@ -225,31 +235,33 @@ def main() -> None: if opts.show_devices: return show_devices() - if opts.host is not None and opts.port is not None: - if opts.mode == "client": - assert opts.language is None - assert opts.model is None - try: - asyncio.run( - run_websocket_client( - opts=opts, - ) - ) - except KeyboardInterrupt: - pass - else: - assert opts.language is not None - assert opts.model is not None - wsp = get_wshiper(opts=opts) - ctx: Context = get_context(opts=opts) + if opts.host is not None and opts.port is not None and opts.mode != Mode.client: + opts.mode = Mode.server + + if opts.mode == Mode.client: + assert opts.language is None + assert opts.model is None + try: asyncio.run( - serve_with_websocket( - wsp=wsp, - host=opts.host, - port=opts.port, - ctx=ctx, + run_websocket_client( + opts=opts, ) ) + except KeyboardInterrupt: + pass + elif opts.mode == Mode.server: + assert opts.language is not None + assert opts.model is not None + wsp = get_wshiper(opts=opts) + ctx: Context = get_context(opts=opts) + asyncio.run( + serve_with_websocket( + wsp=wsp, + host=opts.host, + port=opts.port, + ctx=ctx, + ) + ) else: assert opts.language is not None assert opts.model is not None