whispering/whispering/serve.py

109 lines
2.9 KiB
Python
Raw Permalink Normal View History

2022-09-24 11:45:20 +00:00
#!/usr/bin/env python3
"""WebSocket front end for ``WhisperStreamingTranscriber``.

A client first sends a JSON text frame carrying a ``context`` object, then
binary audio frames; the server replies with transcription chunks as JSON.
"""
import asyncio
import json
from logging import getLogger
from typing import Final, Optional

import numpy as np
import websockets
from websockets.exceptions import ConnectionClosedOK

from whispering.schema import CURRENT_PROTOCOL_VERSION, Context
from whispering.transcriber import WhisperStreamingTranscriber

logger = getLogger(__name__)

# Inclusive range of client ``protocol_version`` values this server accepts.
# The lower bound is spelled as a digit-grouped *string* because a decimal
# literal with leading zeros (000_006_000) is a SyntaxError in Python 3;
# int() accepts the underscores and yields 6000.
MIN_PROTOCOL_VERSION: Final[int] = int("000_006_000")
MAX_PROTOCOL_VERSION: Final[int] = CURRENT_PROTOCOL_VERSION


async def _send_error(websocket, message: str) -> None:
    """Send ``{"error": message}`` to *websocket* as a JSON text frame."""
    await websocket.send(json.dumps({"error": message}))


async def serve_with_websocket_main(websocket):
    """Serve one WebSocket client until it disconnects or breaks protocol.

    Frames are handled as follows:

    * ``str`` frames must be JSON objects carrying a ``"context"`` key. Its
      value is parsed with ``Context.parse_obj`` and becomes the context for
      subsequent audio. Its ``protocol_version`` must lie within the inclusive
      range ``[MIN_PROTOCOL_VERSION, MAX_PROTOCOL_VERSION]``.
    * Any other frame is decoded with ``np.frombuffer`` using
      ``ctx.data_type``, cast to ``float32`` and passed to the module-level
      transcriber ``g_wsp`` (assigned by ``serve_with_websocket``); each chunk
      it yields is sent back as ``chunk.json()``.

    On a protocol violation (undecodable or context-less text frame, protocol
    version out of range, audio before any context) a JSON ``{"error": ...}``
    frame is sent and the handler returns. A clean close by the client
    (``ConnectionClosedOK`` raised by ``recv``) ends the loop normally.

    Args:
        websocket: connection object handed over by ``websockets.serve``.
    """
    global g_wsp  # read-only here; assigned in serve_with_websocket
    idx: int = 0  # count of audio frames processed, used only for logging
    ctx: Optional[Context] = None

    while True:
        logger.debug("Audio #: %s", idx)
        try:
            message = await websocket.recv()
        except ConnectionClosedOK:
            break

        if isinstance(message, str):
            logger.debug("Got str: %s", message)
            # Malformed JSON or a non-object payload is reported like any other
            # unsupported message instead of crashing the handler.
            try:
                payload = json.loads(message)
            except json.JSONDecodeError:
                payload = None
            raw_ctx = payload.get("context") if isinstance(payload, dict) else None
            if raw_ctx is None:
                await _send_error(websocket, "unsupported message")
                return
            ctx = Context.parse_obj(raw_ctx)

            if ctx.protocol_version < MIN_PROTOCOL_VERSION:
                await _send_error(
                    websocket,
                    f"protocol_version is older than {MIN_PROTOCOL_VERSION}",
                )
                # Stop, like the "newer" case: continuing would transcribe
                # audio under a context this server does not support.
                return
            if ctx.protocol_version > MAX_PROTOCOL_VERSION:
                await _send_error(
                    websocket,
                    f"protocol_version is newer than {MAX_PROTOCOL_VERSION}",
                )
                return
            continue

        logger.debug("Message size: %d", len(message))
        if ctx is None:
            await _send_error(websocket, "no context")
            return

        audio = np.frombuffer(message, dtype=np.dtype(ctx.data_type)).astype(
            np.float32
        )
        for chunk in g_wsp.transcribe(
            audio=audio,  # type: ignore
            ctx=ctx,
        ):
            await websocket.send(chunk.json())
        idx += 1


async def serve_with_websocket(
    *,
    wsp: WhisperStreamingTranscriber,
    host: str,
    port: int,
    max_size: Optional[int] = 999_999_999,
):
    """Run the WebSocket transcription server until the task is cancelled.

    Args:
        wsp: transcriber shared by every connection; stored in the
            module-level ``g_wsp`` that ``serve_with_websocket_main`` reads.
        host: interface to bind.
        port: TCP port to bind.
        max_size: maximum size in bytes of one incoming WebSocket message,
            forwarded as ``max_size`` to ``websockets.serve`` (``None``
            disables the limit there). Defaults to the value previously
            hard-coded here; lower it to bound per-message memory use.
    """
    logger.info(f"Serve at {host}:{port}")
    logger.info("Make secure with your responsibility!")
    global g_wsp
    g_wsp = wsp
    try:
        async with websockets.serve(  # type: ignore
            serve_with_websocket_main,
            host=host,
            port=port,
            max_size=max_size,
        ):
            # A bare Future never completes: keep serving until cancelled.
            await asyncio.Future()
    except KeyboardInterrupt:
        # Treat Ctrl-C surfacing inside this coroutine as a normal shutdown.
        pass