From 82c3910d177e3c41ce5dff4a5094fdf34ca4137f Mon Sep 17 00:00:00 2001 From: Yuta Hayashibe Date: Sat, 24 Sep 2022 20:45:20 +0900 Subject: [PATCH] Add websocket server --- poetry.lock | 60 +++++++++++++++++++++++++++++++++++++- pyproject.toml | 1 + whisper_streaming/cli.py | 38 +++++++++++++++++++----- whisper_streaming/serve.py | 54 ++++++++++++++++++++++++++++++++++ 4 files changed, 144 insertions(+), 9 deletions(-) create mode 100644 whisper_streaming/serve.py diff --git a/poetry.lock b/poetry.lock index 9d72f21..f54efa7 100644 --- a/poetry.lock +++ b/poetry.lock @@ -476,6 +476,14 @@ brotli = ["brotli (>=1.0.9)", "brotlicffi (>=0.8.0)", "brotlipy (>=0.6.0)"] secure = ["certifi", "cryptography (>=1.3.4)", "idna (>=2.0.0)", "ipaddress", "pyOpenSSL (>=0.14)", "urllib3-secure-extra"] socks = ["PySocks (>=1.5.6,!=1.5.7,<2.0)"] +[[package]] +name = "websockets" +version = "10.3" +description = "An implementation of the WebSocket Protocol (RFC 6455 & 7692)" +category = "main" +optional = false +python-versions = ">=3.7" + [[package]] name = "whisper" version = "1.0" @@ -505,7 +513,7 @@ resolved_reference = "8cf36f3508c9acd341a45eb2364239a3d81458b9" [metadata] lock-version = "1.1" python-versions = ">=3.8,<4.0" -content-hash = "0f0a8569bdbf956e2a0a648725e65a33e51c2a38ba8d54df8fdb5c317a64c7d6" +content-hash = "a743450981e24e9b1a5cb5f3ffec7e346f54b87176a63d8c0dcc5cd0befdff62" [metadata.files] black = [ @@ -971,4 +979,54 @@ urllib3 = [ {file = "urllib3-1.26.12-py2.py3-none-any.whl", hash = "sha256:b930dd878d5a8afb066a637fbb35144fe7901e3b209d1cd4f524bd0e9deee997"}, {file = "urllib3-1.26.12.tar.gz", hash = "sha256:3fa96cf423e6987997fc326ae8df396db2a8b7c667747d47ddd8ecba91f4a74e"}, ] +websockets = [ + {file = "websockets-10.3-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:661f641b44ed315556a2fa630239adfd77bd1b11cb0b9d96ed8ad90b0b1e4978"}, + {file = "websockets-10.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:b529fdfa881b69fe563dbd98acce84f3e5a67df13de415e143ef053ff006d500"}, + {file = "websockets-10.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:f351c7d7d92f67c0609329ab2735eee0426a03022771b00102816a72715bb00b"}, + {file = "websockets-10.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:379e03422178436af4f3abe0aa8f401aa77ae2487843738542a75faf44a31f0c"}, + {file = "websockets-10.3-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:e904c0381c014b914136c492c8fa711ca4cced4e9b3d110e5e7d436d0fc289e8"}, + {file = "websockets-10.3-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:e7e6f2d6fd48422071cc8a6f8542016f350b79cc782752de531577d35e9bd677"}, + {file = "websockets-10.3-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:b9c77f0d1436ea4b4dc089ed8335fa141e6a251a92f75f675056dac4ab47a71e"}, + {file = "websockets-10.3-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:e6fa05a680e35d0fcc1470cb070b10e6fe247af54768f488ed93542e71339d6f"}, + {file = "websockets-10.3-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:2f94fa3ae454a63ea3a19f73b95deeebc9f02ba2d5617ca16f0bbdae375cda47"}, + {file = "websockets-10.3-cp310-cp310-win32.whl", hash = "sha256:6ed1d6f791eabfd9808afea1e068f5e59418e55721db8b7f3bfc39dc831c42ae"}, + {file = "websockets-10.3-cp310-cp310-win_amd64.whl", hash = "sha256:347974105bbd4ea068106ec65e8e8ebd86f28c19e529d115d89bd8cc5cda3079"}, + {file = "websockets-10.3-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:fab7c640815812ed5f10fbee7abbf58788d602046b7bb3af9b1ac753a6d5e916"}, + {file = "websockets-10.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:994cdb1942a7a4c2e10098d9162948c9e7b235df755de91ca33f6e0481366fdb"}, + {file = "websockets-10.3-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:aad5e300ab32036eb3fdc350ad30877210e2f51bceaca83fb7fef4d2b6c72b79"}, + {file = "websockets-10.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:e49ea4c1a9543d2bd8a747ff24411509c29e4bdcde05b5b0895e2120cb1a761d"}, + {file = "websockets-10.3-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:6ea6b300a6bdd782e49922d690e11c3669828fe36fc2471408c58b93b5535a98"}, + {file = "websockets-10.3-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:ef5ce841e102278c1c2e98f043db99d6755b1c58bde475516aef3a008ed7f28e"}, + {file = "websockets-10.3-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:d1655a6fc7aecd333b079d00fb3c8132d18988e47f19740c69303bf02e9883c6"}, + {file = "websockets-10.3-cp37-cp37m-win32.whl", hash = "sha256:83e5ca0d5b743cde3d29fda74ccab37bdd0911f25bd4cdf09ff8b51b7b4f2fa1"}, + {file = "websockets-10.3-cp37-cp37m-win_amd64.whl", hash = "sha256:da4377904a3379f0c1b75a965fff23b28315bcd516d27f99a803720dfebd94d4"}, + {file = "websockets-10.3-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:a1e15b230c3613e8ea82c9fc6941b2093e8eb939dd794c02754d33980ba81e36"}, + {file = "websockets-10.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:31564a67c3e4005f27815634343df688b25705cccb22bc1db621c781ddc64c69"}, + {file = "websockets-10.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:c8d1d14aa0f600b5be363077b621b1b4d1eb3fbf90af83f9281cda668e6ff7fd"}, + {file = "websockets-10.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8fbd7d77f8aba46d43245e86dd91a8970eac4fb74c473f8e30e9c07581f852b2"}, + {file = "websockets-10.3-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:210aad7fdd381c52e58777560860c7e6110b6174488ef1d4b681c08b68bf7f8c"}, + {file = "websockets-10.3-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:6075fd24df23133c1b078e08a9b04a3bc40b31a8def4ee0b9f2c8865acce913e"}, + {file = "websockets-10.3-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:7f6d96fdb0975044fdd7953b35d003b03f9e2bcf85f2d2cf86285ece53e9f991"}, + {file = "websockets-10.3-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:c7250848ce69559756ad0086a37b82c986cd33c2d344ab87fea596c5ac6d9442"}, + {file = "websockets-10.3-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:28dd20b938a57c3124028680dc1600c197294da5db4292c76a0b48efb3ed7f76"}, + {file = "websockets-10.3-cp38-cp38-win32.whl", hash = "sha256:54c000abeaff6d8771a4e2cef40900919908ea7b6b6a30eae72752607c6db559"}, + {file = "websockets-10.3-cp38-cp38-win_amd64.whl", hash = "sha256:7ab36e17af592eec5747c68ef2722a74c1a4a70f3772bc661079baf4ae30e40d"}, + {file = "websockets-10.3-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:a141de3d5a92188234afa61653ed0bbd2dde46ad47b15c3042ffb89548e77094"}, + {file = "websockets-10.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:97bc9d41e69a7521a358f9b8e44871f6cdeb42af31815c17aed36372d4eec667"}, + {file = "websockets-10.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:d6353ba89cfc657a3f5beabb3b69be226adbb5c6c7a66398e17809b0ce3c4731"}, + {file = "websockets-10.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ec2b0ab7edc8cd4b0eb428b38ed89079bdc20c6bdb5f889d353011038caac2f9"}, + {file = "websockets-10.3-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:85506b3328a9e083cc0a0fb3ba27e33c8db78341b3eb12eb72e8afd166c36680"}, + {file = "websockets-10.3-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:8af75085b4bc0b5c40c4a3c0e113fa95e84c60f4ed6786cbb675aeb1ee128247"}, + {file = "websockets-10.3-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:07cdc0a5b2549bcfbadb585ad8471ebdc7bdf91e32e34ae3889001c1c106a6af"}, + {file = "websockets-10.3-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:5b936bf552e4f6357f5727579072ff1e1324717902127ffe60c92d29b67b7be3"}, + {file = "websockets-10.3-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:e4e08305bfd76ba8edab08dcc6496f40674f44eb9d5e23153efa0a35750337e8"}, + {file = "websockets-10.3-cp39-cp39-win32.whl", hash = "sha256:bb621ec2dbbbe8df78a27dbd9dd7919f9b7d32a73fafcb4d9252fc4637343582"}, + {file = "websockets-10.3-cp39-cp39-win_amd64.whl", hash = "sha256:51695d3b199cd03098ae5b42833006a0f43dc5418d3102972addc593a783bc02"}, + {file = "websockets-10.3-pp37-pypy37_pp73-macosx_10_9_x86_64.whl", hash = "sha256:907e8247480f287aa9bbc9391bd6de23c906d48af54c8c421df84655eef66af7"}, + {file = "websockets-10.3-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8b1359aba0ff810d5830d5ab8e2c4a02bebf98a60aa0124fb29aa78cfdb8031f"}, + {file = "websockets-10.3-pp37-pypy37_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:93d5ea0b5da8d66d868b32c614d2b52d14304444e39e13a59566d4acb8d6e2e4"}, + {file = "websockets-10.3-pp37-pypy37_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:7934e055fd5cd9dee60f11d16c8d79c4567315824bacb1246d0208a47eca9755"}, + {file = "websockets-10.3-pp37-pypy37_pp73-win_amd64.whl", hash = "sha256:3eda1cb7e9da1b22588cefff09f0951771d6ee9fa8dbe66f5ae04cc5f26b2b55"}, + {file = "websockets-10.3.tar.gz", hash = "sha256:fc06cc8073c8e87072138ba1e431300e2d408f054b27047d047b549455066ff4"}, +] whisper = [] diff --git a/pyproject.toml b/pyproject.toml index 80abe9e..7d1c035 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,6 +11,7 @@ python = ">=3.8,<4.0" whisper = {git = "https://github.com/openai/whisper.git", rev = '8cf36f3508c9acd341a45eb2364239a3d81458b9'} sounddevice = "^0.4.5" pydantic = "^1.10.2" +websockets = "^10.3" [tool.poetry.group.dev.dependencies] diff --git a/whisper_streaming/cli.py b/whisper_streaming/cli.py index d40e44f..6464102 100644 --- a/whisper_streaming/cli.py +++ b/whisper_streaming/cli.py @@ -1,6 +1,7 @@ #!/usr/bin/env python3 import argparse +import asyncio import queue from logging import DEBUG, INFO, basicConfig, getLogger from typing import Optional, Union @@ -12,6 +13,7 @@ from whisper.audio import N_FRAMES, SAMPLE_RATE from whisper.tokenizer import LANGUAGES, TO_LANGUAGE_CODE from whisper_streaming.schema import WhisperConfig +from whisper_streaming.serve import serve_with_websocket from whisper_streaming.transcriber import WhisperStreamingTranscriber logger = getLogger(__name__) @@ -19,12 +21,10 @@ logger = getLogger(__name__) def transcribe_from_mic( *, - config: WhisperConfig, + wsp: WhisperStreamingTranscriber, sd_device: Optional[Union[int, str]], num_block: int, ) -> None: - logger.debug(f"WhisperConfig: {config}") - wsp = WhisperStreamingTranscriber(config=config) q = queue.Queue() def sd_callback(indata, frames, time, status): @@ -99,6 +99,16 @@ def get_opts() -> argparse.Namespace: "--debug", action="store_true", ) + parser.add_argument( + "--host", + default="0.0.0.0", + help="host of websocker server", + ) + parser.add_argument( + "--port", + type=int, + help="Port number of websocker server", + ) return parser.parse_args() @@ -128,11 +138,23 @@ def main() -> None: beam_size=opts.beam_size, temperatures=opts.temperature, ) - transcribe_from_mic( - config=config, - sd_device=opts.mic, - num_block=opts.num_block, - ) + + logger.debug(f"WhisperConfig: {config}") + wsp = WhisperStreamingTranscriber(config=config) + if opts.host is not None and opts.port is not None: + asyncio.run( + serve_with_websocket( + wsp=wsp, + host=opts.host, + port=opts.port, + ) + ) + else: + transcribe_from_mic( + wsp=wsp, + sd_device=opts.mic, + num_block=opts.num_block, + ) if __name__ == "__main__": diff --git a/whisper_streaming/serve.py b/whisper_streaming/serve.py new file mode 100644 index 0000000..e2c9217 --- /dev/null +++ b/whisper_streaming/serve.py @@ -0,0 +1,54 @@ +#!/usr/bin/env python3 + +import asyncio +import json +from logging import getLogger + +import numpy as np +import websockets + +from whisper_streaming.transcriber import WhisperStreamingTranscriber + +logger = getLogger(__name__) + + +async def serve_with_websocket_main(websocket): + global g_wsp + idx: int = 0 + + while True: + logger.debug(f"Segment #: {idx}") + message = await websocket.recv() + + if isinstance(message, str): + logger.debug(f"Got str: {message}") + continue + + logger.debug(f"Message size: {len(message)}") + segment = np.frombuffer(message, dtype=np.float32) + chunks = [chunk.dict() for chunk in g_wsp.transcribe(segment=segment)] + await websocket.send(json.dumps(chunks)) + idx += 1 + + +async def serve_with_websocket( + *, + wsp: WhisperStreamingTranscriber, + host: str, + port: int, +): + logger.info(f"Serve at {host}:{port}") + logger.info("Make secure with your responsibility!") + global g_wsp + g_wsp = wsp + + try: + async with websockets.serve( # type: ignore + serve_with_websocket_main, + host=host, + port=port, + max_size=999999999, + ): + await asyncio.Future() + except KeyboardInterrupt: + pass