Add protocol_version

This commit is contained in:
Yuta Hayashibe 2022-10-15 13:23:00 +09:00
parent 3d293c868c
commit dce9719fea
3 changed files with 35 additions and 3 deletions

View file

@ -16,7 +16,12 @@ from whisper.audio import N_FRAMES, SAMPLE_RATE
from whisper.tokenizer import LANGUAGES, TO_LANGUAGE_CODE from whisper.tokenizer import LANGUAGES, TO_LANGUAGE_CODE
from whispering.pbar import ProgressBar from whispering.pbar import ProgressBar
from whispering.schema import Context, StdoutWriter, WhisperConfig from whispering.schema import (
CURRENT_PROTOCOL_VERSION,
Context,
StdoutWriter,
WhisperConfig,
)
from whispering.serve import serve_with_websocket from whispering.serve import serve_with_websocket
from whispering.transcriber import WhisperStreamingTranscriber from whispering.transcriber import WhisperStreamingTranscriber
from whispering.websocket_client import run_websocket_client from whispering.websocket_client import run_websocket_client
@ -214,6 +219,7 @@ def get_wshiper(*, opts) -> WhisperStreamingTranscriber:
def get_context(*, opts) -> Context: def get_context(*, opts) -> Context:
ctx = Context( ctx = Context(
protocol_version=CURRENT_PROTOCOL_VERSION,
beam_size=opts.beam_size, beam_size=opts.beam_size,
temperatures=opts.temperature, temperatures=opts.temperature,
allow_padding=opts.allow_padding, allow_padding=opts.allow_padding,

View file

@ -1,7 +1,7 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import sys import sys
from typing import List, Optional from typing import Final, List, Optional
import numpy as np import numpy as np
import torch import torch
@ -24,7 +24,11 @@ class WhisperConfig(BaseModel):
return values return values
CURRENT_PROTOCOL_VERSION: Final[int] = int("000_006_000")
class Context(BaseModel, arbitrary_types_allowed=True): class Context(BaseModel, arbitrary_types_allowed=True):
protocol_version: int
timestamp: float = 0.0 timestamp: float = 0.0
buffer_tokens: List[torch.Tensor] = [] buffer_tokens: List[torch.Tensor] = []
buffer_mel: Optional[torch.Tensor] = None buffer_mel: Optional[torch.Tensor] = None

View file

@ -3,7 +3,7 @@
import asyncio import asyncio
import json import json
from logging import getLogger from logging import getLogger
from typing import Optional from typing import Final, Optional
import numpy as np import numpy as np
import websockets import websockets
@ -13,6 +13,9 @@ from whispering.transcriber import Context, WhisperStreamingTranscriber
logger = getLogger(__name__) logger = getLogger(__name__)
MIN_PROTOCOL_VERSION: Final[int] = int("000_006_000")
MAX_PROTOCOL_VERSION: Final[int] = int("000_006_000")
async def serve_with_websocket_main(websocket): async def serve_with_websocket_main(websocket):
global g_wsp global g_wsp
@ -41,6 +44,25 @@ async def serve_with_websocket_main(websocket):
) )
) )
return return
if ctx.protocol_version < MIN_PROTOCOL_VERSION:
await websocket.send(
json.dumps(
{
"error": f"protocol_version is older than {MIN_PROTOCOL_VERSION}"
}
)
)
elif ctx.protocol_version > MAX_PROTOCOL_VERSION:
await websocket.send(
json.dumps(
{
"error": f"protocol_version is newer than {MAX_PROTOCOL_VERSION}"
}
)
)
return
continue continue
logger.debug(f"Message size: {len(message)}") logger.debug(f"Message size: {len(message)}")