mirror of
https://github.com/shirayu/whispering.git
synced 2025-02-16 10:35:16 +00:00
Removed ctx from server
This commit is contained in:
parent
36c547ad50
commit
3f60122e49
3 changed files with 49 additions and 15 deletions
|
@ -224,6 +224,19 @@ def show_devices():
|
|||
print(f"{i}: {device['name']}")
|
||||
|
||||
|
||||
def check_invalid_arg(opts):
|
||||
ngs = []
|
||||
if opts.mode == Mode.server.value:
|
||||
ngs = [
|
||||
"mic",
|
||||
"allow_padding",
|
||||
]
|
||||
for ng in ngs:
|
||||
if vars(opts).get(ng) not in {None, False}:
|
||||
sys.stderr.write(f"{ng} is not accepted option for {opts.mode} mode\n")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
opts = get_opts()
|
||||
|
||||
|
@ -242,13 +255,20 @@ def main() -> None:
|
|||
):
|
||||
opts.mode = Mode.server.value
|
||||
|
||||
check_invalid_arg(opts)
|
||||
if opts.mode == Mode.client.value:
|
||||
assert opts.language is None
|
||||
assert opts.model is None
|
||||
ctx: Context = get_context(opts=opts)
|
||||
try:
|
||||
asyncio.run(
|
||||
run_websocket_client(
|
||||
opts=opts,
|
||||
sd_device=opts.mic,
|
||||
num_block=opts.num_block,
|
||||
host=opts.host,
|
||||
port=opts.port,
|
||||
no_progress=opts.no_progress,
|
||||
ctx=ctx,
|
||||
)
|
||||
)
|
||||
except KeyboardInterrupt:
|
||||
|
@ -257,13 +277,11 @@ def main() -> None:
|
|||
assert opts.language is not None
|
||||
assert opts.model is not None
|
||||
wsp = get_wshiper(opts=opts)
|
||||
ctx: Context = get_context(opts=opts)
|
||||
asyncio.run(
|
||||
serve_with_websocket(
|
||||
wsp=wsp,
|
||||
host=opts.host,
|
||||
port=opts.port,
|
||||
ctx=ctx,
|
||||
)
|
||||
)
|
||||
else:
|
||||
|
|
|
@ -1,7 +1,9 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from logging import getLogger
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import websockets
|
||||
|
@ -14,11 +16,8 @@ logger = getLogger(__name__)
|
|||
|
||||
async def serve_with_websocket_main(websocket):
|
||||
global g_wsp
|
||||
global g_ctx
|
||||
idx: int = 0
|
||||
ctx: Context = g_ctx.copy(
|
||||
deep=True,
|
||||
)
|
||||
ctx: Optional[Context] = None
|
||||
|
||||
while True:
|
||||
logger.debug(f"Audio #: {idx}")
|
||||
|
@ -33,6 +32,15 @@ async def serve_with_websocket_main(websocket):
|
|||
|
||||
logger.debug(f"Message size: {len(message)}")
|
||||
audio = np.frombuffer(message, dtype=np.float32)
|
||||
if ctx is None:
|
||||
await websocket.send(
|
||||
json.dumps(
|
||||
{
|
||||
"error": "no context",
|
||||
}
|
||||
)
|
||||
)
|
||||
return
|
||||
for chunk in g_wsp.transcribe(
|
||||
audio=audio, # type: ignore
|
||||
ctx=ctx,
|
||||
|
@ -46,14 +54,11 @@ async def serve_with_websocket(
|
|||
wsp: WhisperStreamingTranscriber,
|
||||
host: str,
|
||||
port: int,
|
||||
ctx: Context,
|
||||
):
|
||||
logger.info(f"Serve at {host}:{port}")
|
||||
logger.info("Make secure with your responsibility!")
|
||||
global g_wsp
|
||||
global g_ctx
|
||||
g_wsp = wsp
|
||||
g_ctx = ctx
|
||||
|
||||
try:
|
||||
async with websockets.serve( # type: ignore
|
||||
|
|
|
@ -8,6 +8,7 @@ import websockets
|
|||
from whisper.audio import N_FRAMES, SAMPLE_RATE
|
||||
|
||||
from whispering.schema import ParsedChunk
|
||||
from whispering.transcriber import Context
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
@ -24,6 +25,7 @@ async def transcribe_from_mic_and_send(
|
|||
num_block: int,
|
||||
host: str,
|
||||
port: int,
|
||||
ctx: Context,
|
||||
) -> None:
|
||||
uri = f"ws://{host}:{port}"
|
||||
|
||||
|
@ -67,15 +69,24 @@ async def transcribe_from_mic_and_send(
|
|||
idx += 1
|
||||
|
||||
|
||||
async def run_websocket_client(*, opts) -> None:
|
||||
async def run_websocket_client(
|
||||
*,
|
||||
sd_device: Optional[Union[int, str]],
|
||||
num_block: int,
|
||||
host: str,
|
||||
port: int,
|
||||
ctx: Context,
|
||||
no_progress: bool,
|
||||
) -> None:
|
||||
global q
|
||||
global loop
|
||||
loop = asyncio.get_running_loop()
|
||||
q = asyncio.Queue()
|
||||
|
||||
await transcribe_from_mic_and_send(
|
||||
sd_device=opts.mic,
|
||||
num_block=opts.num_block,
|
||||
host=opts.host,
|
||||
port=opts.port,
|
||||
sd_device=sd_device,
|
||||
num_block=num_block,
|
||||
host=host,
|
||||
port=port,
|
||||
ctx=ctx,
|
||||
)
|
||||
|
|
Loading…
Reference in a new issue