From dce9719fea20b7d7ab3c598548e8ad9e5f57e052 Mon Sep 17 00:00:00 2001 From: Yuta Hayashibe Date: Sat, 15 Oct 2022 13:23:00 +0900 Subject: [PATCH] Add protocol_version --- whispering/cli.py | 8 +++++++- whispering/schema.py | 6 +++++- whispering/serve.py | 24 +++++++++++++++++++++++- 3 files changed, 35 insertions(+), 3 deletions(-) diff --git a/whispering/cli.py b/whispering/cli.py index 2d2f529..1148751 100644 --- a/whispering/cli.py +++ b/whispering/cli.py @@ -16,7 +16,12 @@ from whisper.audio import N_FRAMES, SAMPLE_RATE from whisper.tokenizer import LANGUAGES, TO_LANGUAGE_CODE from whispering.pbar import ProgressBar -from whispering.schema import Context, StdoutWriter, WhisperConfig +from whispering.schema import ( + CURRENT_PROTOCOL_VERSION, + Context, + StdoutWriter, + WhisperConfig, +) from whispering.serve import serve_with_websocket from whispering.transcriber import WhisperStreamingTranscriber from whispering.websocket_client import run_websocket_client @@ -214,6 +219,7 @@ def get_wshiper(*, opts) -> WhisperStreamingTranscriber: def get_context(*, opts) -> Context: ctx = Context( + protocol_version=CURRENT_PROTOCOL_VERSION, beam_size=opts.beam_size, temperatures=opts.temperature, allow_padding=opts.allow_padding, diff --git a/whispering/schema.py b/whispering/schema.py index e805f73..ee0272f 100644 --- a/whispering/schema.py +++ b/whispering/schema.py @@ -1,7 +1,7 @@ #!/usr/bin/env python3 import sys -from typing import List, Optional +from typing import Final, List, Optional import numpy as np import torch @@ -24,7 +24,11 @@ class WhisperConfig(BaseModel): return values +CURRENT_PROTOCOL_VERSION: Final[int] = int("000_006_000") + + class Context(BaseModel, arbitrary_types_allowed=True): + protocol_version: int timestamp: float = 0.0 buffer_tokens: List[torch.Tensor] = [] buffer_mel: Optional[torch.Tensor] = None diff --git a/whispering/serve.py b/whispering/serve.py index 1c4bdad..8ca32ec 100644 --- a/whispering/serve.py +++ b/whispering/serve.py @@ -3,7 +3,7 @@ import asyncio import json from logging import getLogger -from typing import Optional +from typing import Final, Optional import numpy as np import websockets @@ -13,6 +13,9 @@ from whispering.transcriber import Context, WhisperStreamingTranscriber logger = getLogger(__name__) +MIN_PROTOCOL_VERSION: Final[int] = int("000_006_000") +MAX_PROTOCOL_VERSION: Final[int] = int("000_006_000") + async def serve_with_websocket_main(websocket): global g_wsp @@ -41,6 +44,25 @@ async def serve_with_websocket_main(websocket): ) ) return + + if ctx.protocol_version < MIN_PROTOCOL_VERSION: + await websocket.send( + json.dumps( + { + "error": f"protocol_version is older than {MIN_PROTOCOL_VERSION}" + } + ) + ) + elif ctx.protocol_version > MAX_PROTOCOL_VERSION: + await websocket.send( + json.dumps( + { + "error": f"protocol_version is newer than {MAX_PROTOCOL_VERSION}" + } + ) + ) + return + continue logger.debug(f"Message size: {len(message)}")