Merge branch 'feature/ctx_msg'

This commit is contained in:
Yuta Hayashibe 2022-10-02 22:17:03 +09:00
commit 8d32e607df
4 changed files with 74 additions and 20 deletions

View file

@ -57,15 +57,13 @@ Run with ``--host`` and ``--port``.
whispering --language en --model tiny --host 0.0.0.0 --port 8000 whispering --language en --model tiny --host 0.0.0.0 --port 8000
``` ```
You can set ``--allow-padding`` and other options.
### Client ### Client
```bash ```bash
whispering --host ADDRESS_OF_HOST --port 8000 --mode client whispering --host ADDRESS_OF_HOST --port 8000 --mode client
``` ```
You can set ``-n`` and other options. You can set ``-n``, ``--allow-padding`` and other options.
## Tips ## Tips

View file

@ -224,6 +224,19 @@ def show_devices():
print(f"{i}: {device['name']}") print(f"{i}: {device['name']}")
def check_invalid_arg(opts):
    """Abort the process if an option not accepted by the chosen mode was given.

    Only server mode currently rejects options: ``mic`` and
    ``allow_padding``. For any other mode the list of rejected options is
    empty and the function returns without doing anything.

    Args:
        opts: Parsed command-line options; read through ``vars(opts)`` and
            ``opts.mode``.

    Raises:
        SystemExit: With status 1, after writing a message to stderr, when a
            rejected option carries a value other than ``None`` or ``False``.
    """
    rejected = ()
    if opts.mode == Mode.server.value:
        rejected = ("mic", "allow_padding")

    given = vars(opts)
    for name in rejected:
        value = given.get(name)
        # Use identity checks rather than membership in {None, False}:
        # 0 == False in Python, so an integer option value of 0
        # (e.g. ``--mic 0``) would otherwise be mistaken for "not supplied".
        # Identity checks also never raise on unhashable values.
        if value is None or value is False:
            continue
        sys.stderr.write(f"{name} is not accepted option for {opts.mode} mode\n")
        sys.exit(1)
def main() -> None: def main() -> None:
opts = get_opts() opts = get_opts()
@ -242,13 +255,20 @@ def main() -> None:
): ):
opts.mode = Mode.server.value opts.mode = Mode.server.value
check_invalid_arg(opts)
if opts.mode == Mode.client.value: if opts.mode == Mode.client.value:
assert opts.language is None assert opts.language is None
assert opts.model is None assert opts.model is None
ctx: Context = get_context(opts=opts)
try: try:
asyncio.run( asyncio.run(
run_websocket_client( run_websocket_client(
opts=opts, sd_device=opts.mic,
num_block=opts.num_block,
host=opts.host,
port=opts.port,
no_progress=opts.no_progress,
ctx=ctx,
) )
) )
except KeyboardInterrupt: except KeyboardInterrupt:
@ -257,13 +277,11 @@ def main() -> None:
assert opts.language is not None assert opts.language is not None
assert opts.model is not None assert opts.model is not None
wsp = get_wshiper(opts=opts) wsp = get_wshiper(opts=opts)
ctx: Context = get_context(opts=opts)
asyncio.run( asyncio.run(
serve_with_websocket( serve_with_websocket(
wsp=wsp, wsp=wsp,
host=opts.host, host=opts.host,
port=opts.port, port=opts.port,
ctx=ctx,
) )
) )
else: else:

View file

@ -1,7 +1,9 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import asyncio import asyncio
import json
from logging import getLogger from logging import getLogger
from typing import Optional
import numpy as np import numpy as np
import websockets import websockets
@ -14,11 +16,8 @@ logger = getLogger(__name__)
async def serve_with_websocket_main(websocket): async def serve_with_websocket_main(websocket):
global g_wsp global g_wsp
global g_ctx
idx: int = 0 idx: int = 0
ctx: Context = g_ctx.copy( ctx: Optional[Context] = None
deep=True,
)
while True: while True:
logger.debug(f"Audio #: {idx}") logger.debug(f"Audio #: {idx}")
@ -29,10 +28,32 @@ async def serve_with_websocket_main(websocket):
if isinstance(message, str): if isinstance(message, str):
logger.debug(f"Got str: {message}") logger.debug(f"Got str: {message}")
d = json.loads(message)
v = d.get("context")
if v is not None:
ctx = Context.parse_obj(v)
else:
await websocket.send(
json.dumps(
{
"error": "unsupported message",
}
)
)
return
continue continue
logger.debug(f"Message size: {len(message)}") logger.debug(f"Message size: {len(message)}")
audio = np.frombuffer(message, dtype=np.float32) audio = np.frombuffer(message, dtype=np.float32)
if ctx is None:
await websocket.send(
json.dumps(
{
"error": "no context",
}
)
)
return
for chunk in g_wsp.transcribe( for chunk in g_wsp.transcribe(
audio=audio, # type: ignore audio=audio, # type: ignore
ctx=ctx, ctx=ctx,
@ -46,14 +67,11 @@ async def serve_with_websocket(
wsp: WhisperStreamingTranscriber, wsp: WhisperStreamingTranscriber,
host: str, host: str,
port: int, port: int,
ctx: Context,
): ):
logger.info(f"Serve at {host}:{port}") logger.info(f"Serve at {host}:{port}")
logger.info("Make secure with your responsibility!") logger.info("Make secure with your responsibility!")
global g_wsp global g_wsp
global g_ctx
g_wsp = wsp g_wsp = wsp
g_ctx = ctx
try: try:
async with websockets.serve( # type: ignore async with websockets.serve( # type: ignore

View file

@ -1,5 +1,7 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import asyncio import asyncio
import json
import sys
from logging import getLogger from logging import getLogger
from typing import Optional, Union from typing import Optional, Union
@ -8,6 +10,7 @@ import websockets
from whisper.audio import N_FRAMES, SAMPLE_RATE from whisper.audio import N_FRAMES, SAMPLE_RATE
from whispering.schema import ParsedChunk from whispering.schema import ParsedChunk
from whispering.transcriber import Context
logger = getLogger(__name__) logger = getLogger(__name__)
@ -24,6 +27,7 @@ async def transcribe_from_mic_and_send(
num_block: int, num_block: int,
host: str, host: str,
port: int, port: int,
ctx: Context,
) -> None: ) -> None:
uri = f"ws://{host}:{port}" uri = f"ws://{host}:{port}"
@ -36,6 +40,10 @@ async def transcribe_from_mic_and_send(
callback=sd_callback, callback=sd_callback,
): ):
async with websockets.connect(uri, max_size=999999999) as ws: # type:ignore async with websockets.connect(uri, max_size=999999999) as ws: # type:ignore
logger.debug("Sent context")
v: str = ctx.json()
await ws.send("""{"context": """ + v + """}""")
idx: int = 0 idx: int = 0
while True: while True:
@ -59,23 +67,35 @@ async def transcribe_from_mic_and_send(
while True: while True:
try: try:
c = await asyncio.wait_for(recv(), timeout=0.5) c = await asyncio.wait_for(recv(), timeout=0.5)
chunk = ParsedChunk.parse_raw(c) c_json = json.loads(c)
if (err := c_json.get("error")) is not None:
print(f"Error: {err}")
sys.exit(1)
chunk = ParsedChunk.parse_obj(c_json)
print(f"{chunk.start:.2f}->{chunk.end:.2f}\t{chunk.text}") print(f"{chunk.start:.2f}->{chunk.end:.2f}\t{chunk.text}")
except asyncio.TimeoutError: except asyncio.TimeoutError:
break break
idx += 1 idx += 1
async def run_websocket_client(
    *,
    sd_device: Optional[Union[int, str]],
    num_block: int,
    host: str,
    port: int,
    ctx: Context,
    no_progress: bool,
) -> None:
    """Set up the module-level loop/queue and run the microphone client.

    Stores the currently running event loop and a fresh ``asyncio.Queue`` in
    the module globals ``loop`` and ``q``, then awaits
    ``transcribe_from_mic_and_send`` with the connection settings and the
    transcription context.

    Args:
        sd_device: Input device forwarded as ``sd_device``.
        num_block: Forwarded unchanged as ``num_block``.
        host: Server host name or address.
        port: Server port.
        ctx: Transcription context forwarded to the sender coroutine.
        no_progress: Accepted for the caller's convenience; not used inside
            this function.
    """
    global q, loop

    # Published as globals — presumably so the audio callback, which runs
    # outside this coroutine, can reach them; confirm against sd_callback.
    loop = asyncio.get_running_loop()
    q = asyncio.Queue()

    forwarded = {
        "sd_device": sd_device,
        "num_block": num_block,
        "host": host,
        "port": port,
        "ctx": ctx,
    }
    await transcribe_from_mic_and_send(**forwarded)