Removed ctx from server

This commit is contained in:
Yuta Hayashibe 2022-10-02 21:59:02 +09:00
parent 36c547ad50
commit 3f60122e49
3 changed files with 49 additions and 15 deletions

View file

@ -224,6 +224,19 @@ def show_devices():
print(f"{i}: {device['name']}") print(f"{i}: {device['name']}")
def check_invalid_arg(opts):
ngs = []
if opts.mode == Mode.server.value:
ngs = [
"mic",
"allow_padding",
]
for ng in ngs:
if vars(opts).get(ng) not in {None, False}:
sys.stderr.write(f"{ng} is not accepted option for {opts.mode} mode\n")
sys.exit(1)
def main() -> None: def main() -> None:
opts = get_opts() opts = get_opts()
@ -242,13 +255,20 @@ def main() -> None:
): ):
opts.mode = Mode.server.value opts.mode = Mode.server.value
check_invalid_arg(opts)
if opts.mode == Mode.client.value: if opts.mode == Mode.client.value:
assert opts.language is None assert opts.language is None
assert opts.model is None assert opts.model is None
ctx: Context = get_context(opts=opts)
try: try:
asyncio.run( asyncio.run(
run_websocket_client( run_websocket_client(
opts=opts, sd_device=opts.mic,
num_block=opts.num_block,
host=opts.host,
port=opts.port,
no_progress=opts.no_progress,
ctx=ctx,
) )
) )
except KeyboardInterrupt: except KeyboardInterrupt:
@ -257,13 +277,11 @@ def main() -> None:
assert opts.language is not None assert opts.language is not None
assert opts.model is not None assert opts.model is not None
wsp = get_wshiper(opts=opts) wsp = get_wshiper(opts=opts)
ctx: Context = get_context(opts=opts)
asyncio.run( asyncio.run(
serve_with_websocket( serve_with_websocket(
wsp=wsp, wsp=wsp,
host=opts.host, host=opts.host,
port=opts.port, port=opts.port,
ctx=ctx,
) )
) )
else: else:

View file

@ -1,7 +1,9 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import asyncio import asyncio
import json
from logging import getLogger from logging import getLogger
from typing import Optional
import numpy as np import numpy as np
import websockets import websockets
@ -14,11 +16,8 @@ logger = getLogger(__name__)
async def serve_with_websocket_main(websocket): async def serve_with_websocket_main(websocket):
global g_wsp global g_wsp
global g_ctx
idx: int = 0 idx: int = 0
ctx: Context = g_ctx.copy( ctx: Optional[Context] = None
deep=True,
)
while True: while True:
logger.debug(f"Audio #: {idx}") logger.debug(f"Audio #: {idx}")
@ -33,6 +32,15 @@ async def serve_with_websocket_main(websocket):
logger.debug(f"Message size: {len(message)}") logger.debug(f"Message size: {len(message)}")
audio = np.frombuffer(message, dtype=np.float32) audio = np.frombuffer(message, dtype=np.float32)
if ctx is None:
await websocket.send(
json.dumps(
{
"error": "no context",
}
)
)
return
for chunk in g_wsp.transcribe( for chunk in g_wsp.transcribe(
audio=audio, # type: ignore audio=audio, # type: ignore
ctx=ctx, ctx=ctx,
@ -46,14 +54,11 @@ async def serve_with_websocket(
wsp: WhisperStreamingTranscriber, wsp: WhisperStreamingTranscriber,
host: str, host: str,
port: int, port: int,
ctx: Context,
): ):
logger.info(f"Serve at {host}:{port}") logger.info(f"Serve at {host}:{port}")
logger.info("Make secure with your responsibility!") logger.info("Make secure with your responsibility!")
global g_wsp global g_wsp
global g_ctx
g_wsp = wsp g_wsp = wsp
g_ctx = ctx
try: try:
async with websockets.serve( # type: ignore async with websockets.serve( # type: ignore

View file

@ -8,6 +8,7 @@ import websockets
from whisper.audio import N_FRAMES, SAMPLE_RATE from whisper.audio import N_FRAMES, SAMPLE_RATE
from whispering.schema import ParsedChunk from whispering.schema import ParsedChunk
from whispering.transcriber import Context
logger = getLogger(__name__) logger = getLogger(__name__)
@ -24,6 +25,7 @@ async def transcribe_from_mic_and_send(
num_block: int, num_block: int,
host: str, host: str,
port: int, port: int,
ctx: Context,
) -> None: ) -> None:
uri = f"ws://{host}:{port}" uri = f"ws://{host}:{port}"
@ -67,15 +69,24 @@ async def transcribe_from_mic_and_send(
idx += 1 idx += 1
async def run_websocket_client(*, opts) -> None: async def run_websocket_client(
*,
sd_device: Optional[Union[int, str]],
num_block: int,
host: str,
port: int,
ctx: Context,
no_progress: bool,
) -> None:
global q global q
global loop global loop
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
q = asyncio.Queue() q = asyncio.Queue()
await transcribe_from_mic_and_send( await transcribe_from_mic_and_send(
sd_device=opts.mic, sd_device=sd_device,
num_block=opts.num_block, num_block=num_block,
host=opts.host, host=host,
port=opts.port, port=port,
ctx=ctx,
) )