Add protocol_version

This commit is contained in:
Yuta Hayashibe 2022-10-15 13:23:00 +09:00
parent 3d293c868c
commit dce9719fea
3 changed files with 35 additions and 3 deletions

View file

@ -16,7 +16,12 @@ from whisper.audio import N_FRAMES, SAMPLE_RATE
from whisper.tokenizer import LANGUAGES, TO_LANGUAGE_CODE
from whispering.pbar import ProgressBar
from whispering.schema import Context, StdoutWriter, WhisperConfig
from whispering.schema import (
CURRENT_PROTOCOL_VERSION,
Context,
StdoutWriter,
WhisperConfig,
)
from whispering.serve import serve_with_websocket
from whispering.transcriber import WhisperStreamingTranscriber
from whispering.websocket_client import run_websocket_client
@ -214,6 +219,7 @@ def get_wshiper(*, opts) -> WhisperStreamingTranscriber:
def get_context(*, opts) -> Context:
ctx = Context(
protocol_version=CURRENT_PROTOCOL_VERSION,
beam_size=opts.beam_size,
temperatures=opts.temperature,
allow_padding=opts.allow_padding,

View file

@ -1,7 +1,7 @@
#!/usr/bin/env python3
import sys
from typing import List, Optional
from typing import Final, List, Optional
import numpy as np
import torch
@ -24,7 +24,11 @@ class WhisperConfig(BaseModel):
return values
CURRENT_PROTOCOL_VERSION: Final[int] = int("000_006_000")
class Context(BaseModel, arbitrary_types_allowed=True):
protocol_version: int
timestamp: float = 0.0
buffer_tokens: List[torch.Tensor] = []
buffer_mel: Optional[torch.Tensor] = None

View file

@ -3,7 +3,7 @@
import asyncio
import json
from logging import getLogger
from typing import Optional
from typing import Final, Optional
import numpy as np
import websockets
@ -13,6 +13,9 @@ from whispering.transcriber import Context, WhisperStreamingTranscriber
logger = getLogger(__name__)
MIN_PROTOCOL_VERSION: Final[int] = int("000_006_000")
MAX_PROTOCOL_VERSION: Final[int] = int("000_006_000")
async def serve_with_websocket_main(websocket):
global g_wsp
@ -41,6 +44,25 @@ async def serve_with_websocket_main(websocket):
)
)
return
if ctx.protocol_version < MIN_PROTOCOL_VERSION:
await websocket.send(
json.dumps(
{
"error": f"protocol_version is older than {MIN_PROTOCOL_VERSION}"
}
)
)
elif ctx.protocol_version > MAX_PROTOCOL_VERSION:
await websocket.send(
json.dumps(
{
"error": f"protocol_version is newer than {MAX_PROTOCOL_VERSION}"
}
)
)
return
continue
logger.debug(f"Message size: {len(message)}")