close properly

This commit is contained in:
ricwo 2022-11-10 09:26:50 +01:00
parent 59b1230af7
commit e4e8cad3da
2 changed files with 98 additions and 70 deletions

View file

@ -1,24 +1,39 @@
import sys
import wave
import numpy
import websockets
import argparse
import asyncio
import base64
import datetime
import json
import logging
import os
import time
from pathlib import Path
import coloredlogs
import soundfile as sf
import websockets
logging.basicConfig(level=logging.INFO)
# URL = "ws://ec2-3-71-72-35.eu-central-1.compute.amazonaws.com"
# URL = "ws://localhost:1111/ws"
URL = "ws://localhost:8000"
FNAME = sys.argv[1]
logger = logging.getLogger(__name__)
coloredlogs.install(
level=logging.INFO,
datefmt="%Y-%m-%d %H:%M:%S",
format="%(asctime)s %(levelname)s %(message)s",
)
parser = argparse.ArgumentParser()
parser.add_argument("-f", "--file", help="file to send")
parser.add_argument(
"-u",
"--url",
help="url to connect to",
default="ws://ec2-3-71-72-35.eu-central-1.compute.amazonaws.com",
)
parser.add_argument(
"-c",
"--chunk-size",
help="chunk size in bytes",
default=100000,
type=int,
)
initial_context = {
"context": {
@ -42,66 +57,37 @@ initial_context = {
}
}
# # read binary file in chunks
# def wave_generator(chunk_size=50000):
# with open(FNAME, "rb") as f:
# while True:
# data = f.read(chunk_size)
# if not data:
# break
# yield data
# def wave_generator():
# with wave.open(FNAME, "rb") as f:
# while True:
# data = f.readframes(100000)
# if not data:
# break
# yield data
CHUNK_SIZE = 100000
def wave_generator(chunk_size: int = CHUNK_SIZE):
for data in sf.blocks(FNAME, blocksize=chunk_size, dtype="float32"):
def wave_generator():
for data in sf.blocks(_tfile, blocksize=args.chunk_size, dtype="float32"):
yield data.tobytes()
# def wave_generator(chunk_size=60000):
# from scipy.io.wavfile import read
#
# a = read(FNAME)
# data = numpy.array(a[1], dtype=numpy.float32)
# print(f"Data: {data.shape}")
# for i in range(0, data.shape[0], chunk_size):
# print(f"From {i} to {i+chunk_size}")
# yield data[i : i + chunk_size]
async def send_receive():
print(f"Connecting websocket to url ${URL}")
logger.info(f"Connecting websocket to url ${args.url}")
wave_gen = wave_generator()
start_time = time.time()
async with websockets.connect(
URL,
args.url,
ping_interval=5,
ping_timeout=20,
) as _ws:
await asyncio.sleep(0.1)
# print("Receiving SessionBegins ...")
logger.info(f"Sending initial context")
await _ws.send(json.dumps(initial_context))
# session_begins = await _ws.recv()
# print(session_begins)
print("Sending messages ...")
logger.info("Sending messages ...")
async def send():
print("Gonna send")
i = 1
while True:
try:
# print("Trying to send")
segment = next(wave_gen)
print(f"Sending segment of length {len(segment)}")
if len(segment) == CHUNK_SIZE * 4:
logger.info(f"Sending segment {i} of length {len(segment)}")
if len(segment) == args.chunk_size * 4:
msg = segment
else:
@ -109,37 +95,68 @@ async def send_receive():
payload = base64.b64encode(segment).decode("utf-8")
msg = json.dumps({"last_message": payload})
i += 1
await _ws.send(msg)
except websockets.exceptions.ConnectionClosedError as e:
print(e)
logger.error(e)
assert e.code == 4008
break
except Exception as e:
print(f"Caught send exception {type(e)}")
print(e)
except StopIteration:
break
await asyncio.sleep(1)
except Exception as e:
logger.error(f"Caught send exception {type(e)}")
logger.error(e)
break
await asyncio.sleep(1e-5)
return True
async def receive():
i = 0
full_text = ""
while True:
try:
result_str = await _ws.recv()
# print(f"Result: {result_str}")
rjs = json.loads(result_str)
print(f"Result: '{rjs['text']}'")
if "close_connection" in rjs:
logger.info("Closing connection")
break
res = rjs["text"]
logger.info(f"Result {i}:\t{res.strip()}")
full_text += res
except websockets.exceptions.ConnectionClosedError as e:
print(e)
logger.error(e)
assert e.code == 4008
break
except Exception as e:
print(f"Caught exception {type(e)}")
except StopIteration:
break
except Exception as e:
logger.error(f"Caught exception {type(e)}")
break
i += 1
logger.info(f"Full text:\n---\n{full_text.strip()}\n---")
logger.info(f"Total time: {time.time() - start_time:.3f} s")
send_result, receive_result = await asyncio.gather(send(), receive())
def convert_to_wav(tfile_name):
os.system(f"ffmpeg -i {args.file} -acodec pcm_s16le -ac 1 -ar 16000 {tfile_name}")
logger.info(f"Converted file to {tfile_name}")
if __name__ == "__main__":
loop = asyncio.get_event_loop()
loop.run_until_complete(send_receive())
args = parser.parse_args()
_tfile = f"tmp_{args.file}.wav"
try:
convert_to_wav(_tfile)
loop = asyncio.get_event_loop()
loop.run_until_complete(send_receive())
finally:
Path(_tfile).unlink()

View file

@ -24,6 +24,7 @@ async def serve_with_websocket_main(websocket):
idx: int = 0
ctx: Optional[Context] = None
j = 0
while True:
logger.debug(f"Audio #: {idx}")
try:
@ -33,6 +34,10 @@ async def serve_with_websocket_main(websocket):
force_padding = False
logger.debug(f"j={j}, padding={force_padding}")
j += 1
if isinstance(message, str) and not ctx:
logger.debug(f"Got str: {message}")
d = json.loads(message)
@ -88,17 +93,23 @@ async def serve_with_websocket_main(websocket):
)
)
return
logger.debug(f"Have message of size {len(message)} data type {ctx.data_type}")
# bytes to np array
# audio = np.frombuffer(message, dtype=ctx.data_type)
audio = np.frombuffer(message, dtype=np.dtype(ctx.data_type)).astype(np.float32)
logger.warning(f"Have audio shape {audio.shape}")
k = 0
for chunk in g_wsp.transcribe(
audio=audio, ctx=ctx, force_padding=force_padding # type: ignore
):
await websocket.send(chunk.json())
cjs = json.loads(chunk.json())
logger.debug(f"k={k}, padding={force_padding}")
await websocket.send(json.dumps(cjs))
k += 1
if force_padding:
await websocket.send(json.dumps({"close_connection": force_padding}))
idx += 1