mirror of
https://github.com/shirayu/whispering.git
synced 2024-06-13 10:49:22 +00:00
close properly
This commit is contained in:
parent
59b1230af7
commit
e4e8cad3da
147
stream_file.py
147
stream_file.py
|
@ -1,24 +1,39 @@
|
|||
import sys
|
||||
import wave
|
||||
|
||||
import numpy
|
||||
import websockets
|
||||
import argparse
|
||||
import asyncio
|
||||
import base64
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import coloredlogs
|
||||
import soundfile as sf
|
||||
import websockets
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
# URL = "ws://ec2-3-71-72-35.eu-central-1.compute.amazonaws.com"
|
||||
|
||||
# URL = "ws://localhost:1111/ws"
|
||||
|
||||
URL = "ws://localhost:8000"
|
||||
|
||||
FNAME = sys.argv[1]
|
||||
logger = logging.getLogger(__name__)
|
||||
coloredlogs.install(
|
||||
level=logging.INFO,
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
format="%(asctime)s %(levelname)s %(message)s",
|
||||
)
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("-f", "--file", help="file to send")
|
||||
parser.add_argument(
|
||||
"-u",
|
||||
"--url",
|
||||
help="url to connect to",
|
||||
default="ws://ec2-3-71-72-35.eu-central-1.compute.amazonaws.com",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-c",
|
||||
"--chunk-size",
|
||||
help="chunk size in bytes",
|
||||
default=100000,
|
||||
type=int,
|
||||
)
|
||||
|
||||
initial_context = {
|
||||
"context": {
|
||||
|
@ -42,66 +57,37 @@ initial_context = {
|
|||
}
|
||||
}
|
||||
|
||||
# # read binary file in chunks
|
||||
# def wave_generator(chunk_size=50000):
|
||||
# with open(FNAME, "rb") as f:
|
||||
# while True:
|
||||
# data = f.read(chunk_size)
|
||||
# if not data:
|
||||
# break
|
||||
# yield data
|
||||
|
||||
|
||||
# def wave_generator():
|
||||
# with wave.open(FNAME, "rb") as f:
|
||||
# while True:
|
||||
# data = f.readframes(100000)
|
||||
# if not data:
|
||||
# break
|
||||
# yield data
|
||||
|
||||
CHUNK_SIZE = 100000
|
||||
|
||||
|
||||
def wave_generator(chunk_size: int = CHUNK_SIZE):
|
||||
for data in sf.blocks(FNAME, blocksize=chunk_size, dtype="float32"):
|
||||
def wave_generator():
|
||||
for data in sf.blocks(_tfile, blocksize=args.chunk_size, dtype="float32"):
|
||||
yield data.tobytes()
|
||||
|
||||
|
||||
# def wave_generator(chunk_size=60000):
|
||||
# from scipy.io.wavfile import read
|
||||
#
|
||||
# a = read(FNAME)
|
||||
# data = numpy.array(a[1], dtype=numpy.float32)
|
||||
# print(f"Data: {data.shape}")
|
||||
# for i in range(0, data.shape[0], chunk_size):
|
||||
# print(f"From {i} to {i+chunk_size}")
|
||||
# yield data[i : i + chunk_size]
|
||||
|
||||
|
||||
async def send_receive():
|
||||
print(f"Connecting websocket to url ${URL}")
|
||||
logger.info(f"Connecting websocket to url ${args.url}")
|
||||
wave_gen = wave_generator()
|
||||
start_time = time.time()
|
||||
async with websockets.connect(
|
||||
URL,
|
||||
args.url,
|
||||
ping_interval=5,
|
||||
ping_timeout=20,
|
||||
) as _ws:
|
||||
await asyncio.sleep(0.1)
|
||||
# print("Receiving SessionBegins ...")
|
||||
logger.info(f"Sending initial context")
|
||||
await _ws.send(json.dumps(initial_context))
|
||||
# session_begins = await _ws.recv()
|
||||
# print(session_begins)
|
||||
print("Sending messages ...")
|
||||
logger.info("Sending messages ...")
|
||||
|
||||
async def send():
|
||||
print("Gonna send")
|
||||
i = 1
|
||||
while True:
|
||||
try:
|
||||
# print("Trying to send")
|
||||
segment = next(wave_gen)
|
||||
print(f"Sending segment of length {len(segment)}")
|
||||
if len(segment) == CHUNK_SIZE * 4:
|
||||
logger.info(f"Sending segment {i} of length {len(segment)}")
|
||||
if len(segment) == args.chunk_size * 4:
|
||||
msg = segment
|
||||
|
||||
else:
|
||||
|
@ -109,37 +95,68 @@ async def send_receive():
|
|||
payload = base64.b64encode(segment).decode("utf-8")
|
||||
msg = json.dumps({"last_message": payload})
|
||||
|
||||
i += 1
|
||||
|
||||
await _ws.send(msg)
|
||||
except websockets.exceptions.ConnectionClosedError as e:
|
||||
print(e)
|
||||
logger.error(e)
|
||||
assert e.code == 4008
|
||||
break
|
||||
except Exception as e:
|
||||
print(f"Caught send exception {type(e)}")
|
||||
print(e)
|
||||
except StopIteration:
|
||||
break
|
||||
await asyncio.sleep(1)
|
||||
except Exception as e:
|
||||
logger.error(f"Caught send exception {type(e)}")
|
||||
logger.error(e)
|
||||
break
|
||||
await asyncio.sleep(1e-5)
|
||||
|
||||
return True
|
||||
|
||||
async def receive():
|
||||
i = 0
|
||||
full_text = ""
|
||||
while True:
|
||||
try:
|
||||
result_str = await _ws.recv()
|
||||
# print(f"Result: {result_str}")
|
||||
rjs = json.loads(result_str)
|
||||
print(f"Result: '{rjs['text']}'")
|
||||
if "close_connection" in rjs:
|
||||
logger.info("Closing connection")
|
||||
break
|
||||
|
||||
res = rjs["text"]
|
||||
logger.info(f"Result {i}:\t{res.strip()}")
|
||||
full_text += res
|
||||
except websockets.exceptions.ConnectionClosedError as e:
|
||||
print(e)
|
||||
logger.error(e)
|
||||
assert e.code == 4008
|
||||
break
|
||||
except Exception as e:
|
||||
print(f"Caught exception {type(e)}")
|
||||
except StopIteration:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Caught exception {type(e)}")
|
||||
break
|
||||
|
||||
i += 1
|
||||
logger.info(f"Full text:\n---\n{full_text.strip()}\n---")
|
||||
logger.info(f"Total time: {time.time() - start_time:.3f} s")
|
||||
|
||||
send_result, receive_result = await asyncio.gather(send(), receive())
|
||||
|
||||
|
||||
def convert_to_wav(tfile_name):
|
||||
os.system(f"ffmpeg -i {args.file} -acodec pcm_s16le -ac 1 -ar 16000 {tfile_name}")
|
||||
logger.info(f"Converted file to {tfile_name}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
loop = asyncio.get_event_loop()
|
||||
loop.run_until_complete(send_receive())
|
||||
args = parser.parse_args()
|
||||
|
||||
_tfile = f"tmp_{args.file}.wav"
|
||||
|
||||
try:
|
||||
convert_to_wav(_tfile)
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
loop.run_until_complete(send_receive())
|
||||
finally:
|
||||
Path(_tfile).unlink()
|
||||
|
|
|
@ -24,6 +24,7 @@ async def serve_with_websocket_main(websocket):
|
|||
idx: int = 0
|
||||
ctx: Optional[Context] = None
|
||||
|
||||
j = 0
|
||||
while True:
|
||||
logger.debug(f"Audio #: {idx}")
|
||||
try:
|
||||
|
@ -33,6 +34,10 @@ async def serve_with_websocket_main(websocket):
|
|||
|
||||
force_padding = False
|
||||
|
||||
logger.debug(f"j={j}, padding={force_padding}")
|
||||
|
||||
j += 1
|
||||
|
||||
if isinstance(message, str) and not ctx:
|
||||
logger.debug(f"Got str: {message}")
|
||||
d = json.loads(message)
|
||||
|
@ -88,17 +93,23 @@ async def serve_with_websocket_main(websocket):
|
|||
)
|
||||
)
|
||||
return
|
||||
logger.debug(f"Have message of size {len(message)} data type {ctx.data_type}")
|
||||
# bytes to np array
|
||||
# audio = np.frombuffer(message, dtype=ctx.data_type)
|
||||
|
||||
audio = np.frombuffer(message, dtype=np.dtype(ctx.data_type)).astype(np.float32)
|
||||
logger.warning(f"Have audio shape {audio.shape}")
|
||||
|
||||
k = 0
|
||||
|
||||
for chunk in g_wsp.transcribe(
|
||||
audio=audio, ctx=ctx, force_padding=force_padding # type: ignore
|
||||
):
|
||||
await websocket.send(chunk.json())
|
||||
cjs = json.loads(chunk.json())
|
||||
|
||||
logger.debug(f"k={k}, padding={force_padding}")
|
||||
await websocket.send(json.dumps(cjs))
|
||||
k += 1
|
||||
|
||||
if force_padding:
|
||||
await websocket.send(json.dumps({"close_connection": force_padding}))
|
||||
|
||||
idx += 1
|
||||
|
||||
|
||||
|
|
Loading…
Reference in a new issue