Removed ctx from server

This commit is contained in:
Yuta Hayashibe 2022-10-02 21:59:02 +09:00
parent 36c547ad50
commit 3f60122e49
3 changed files with 49 additions and 15 deletions

View file

@ -224,6 +224,19 @@ def show_devices():
print(f"{i}: {device['name']}")
def check_invalid_arg(opts):
ngs = []
if opts.mode == Mode.server.value:
ngs = [
"mic",
"allow_padding",
]
for ng in ngs:
if vars(opts).get(ng) not in {None, False}:
sys.stderr.write(f"{ng} is not accepted option for {opts.mode} mode\n")
sys.exit(1)
def main() -> None:
opts = get_opts()
@ -242,13 +255,20 @@ def main() -> None:
):
opts.mode = Mode.server.value
check_invalid_arg(opts)
if opts.mode == Mode.client.value:
assert opts.language is None
assert opts.model is None
ctx: Context = get_context(opts=opts)
try:
asyncio.run(
run_websocket_client(
opts=opts,
sd_device=opts.mic,
num_block=opts.num_block,
host=opts.host,
port=opts.port,
no_progress=opts.no_progress,
ctx=ctx,
)
)
except KeyboardInterrupt:
@ -257,13 +277,11 @@ def main() -> None:
assert opts.language is not None
assert opts.model is not None
wsp = get_wshiper(opts=opts)
ctx: Context = get_context(opts=opts)
asyncio.run(
serve_with_websocket(
wsp=wsp,
host=opts.host,
port=opts.port,
ctx=ctx,
)
)
else:

View file

@ -1,7 +1,9 @@
#!/usr/bin/env python3
import asyncio
import json
from logging import getLogger
from typing import Optional
import numpy as np
import websockets
@ -14,11 +16,8 @@ logger = getLogger(__name__)
async def serve_with_websocket_main(websocket):
global g_wsp
global g_ctx
idx: int = 0
ctx: Context = g_ctx.copy(
deep=True,
)
ctx: Optional[Context] = None
while True:
logger.debug(f"Audio #: {idx}")
@ -33,6 +32,15 @@ async def serve_with_websocket_main(websocket):
logger.debug(f"Message size: {len(message)}")
audio = np.frombuffer(message, dtype=np.float32)
if ctx is None:
await websocket.send(
json.dumps(
{
"error": "no context",
}
)
)
return
for chunk in g_wsp.transcribe(
audio=audio, # type: ignore
ctx=ctx,
@ -46,14 +54,11 @@ async def serve_with_websocket(
wsp: WhisperStreamingTranscriber,
host: str,
port: int,
ctx: Context,
):
logger.info(f"Serve at {host}:{port}")
logger.info("Make secure with your responsibility!")
global g_wsp
global g_ctx
g_wsp = wsp
g_ctx = ctx
try:
async with websockets.serve( # type: ignore

View file

@ -8,6 +8,7 @@ import websockets
from whisper.audio import N_FRAMES, SAMPLE_RATE
from whispering.schema import ParsedChunk
from whispering.transcriber import Context
logger = getLogger(__name__)
@ -24,6 +25,7 @@ async def transcribe_from_mic_and_send(
num_block: int,
host: str,
port: int,
ctx: Context,
) -> None:
uri = f"ws://{host}:{port}"
@ -67,15 +69,24 @@ async def transcribe_from_mic_and_send(
idx += 1
async def run_websocket_client(*, opts) -> None:
async def run_websocket_client(
*,
sd_device: Optional[Union[int, str]],
num_block: int,
host: str,
port: int,
ctx: Context,
no_progress: bool,
) -> None:
global q
global loop
loop = asyncio.get_running_loop()
q = asyncio.Queue()
await transcribe_from_mic_and_send(
sd_device=opts.mic,
num_block=opts.num_block,
host=opts.host,
port=opts.port,
sd_device=sd_device,
num_block=num_block,
host=host,
port=port,
ctx=ctx,
)