mirror of
https://github.com/shirayu/whispering.git
synced 2024-11-25 10:21:00 +00:00
Add protocol_version
This commit is contained in:
parent
3d293c868c
commit
dce9719fea
3 changed files with 35 additions and 3 deletions
|
@ -16,7 +16,12 @@ from whisper.audio import N_FRAMES, SAMPLE_RATE
|
|||
from whisper.tokenizer import LANGUAGES, TO_LANGUAGE_CODE
|
||||
|
||||
from whispering.pbar import ProgressBar
|
||||
from whispering.schema import Context, StdoutWriter, WhisperConfig
|
||||
from whispering.schema import (
|
||||
CURRENT_PROTOCOL_VERSION,
|
||||
Context,
|
||||
StdoutWriter,
|
||||
WhisperConfig,
|
||||
)
|
||||
from whispering.serve import serve_with_websocket
|
||||
from whispering.transcriber import WhisperStreamingTranscriber
|
||||
from whispering.websocket_client import run_websocket_client
|
||||
|
@ -214,6 +219,7 @@ def get_wshiper(*, opts) -> WhisperStreamingTranscriber:
|
|||
|
||||
def get_context(*, opts) -> Context:
|
||||
ctx = Context(
|
||||
protocol_version=CURRENT_PROTOCOL_VERSION,
|
||||
beam_size=opts.beam_size,
|
||||
temperatures=opts.temperature,
|
||||
allow_padding=opts.allow_padding,
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
import sys
|
||||
from typing import List, Optional
|
||||
from typing import Final, List, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
@ -24,7 +24,11 @@ class WhisperConfig(BaseModel):
|
|||
return values
|
||||
|
||||
|
||||
CURRENT_PROTOCOL_VERSION: Final[int] = int("000_006_000")
|
||||
|
||||
|
||||
class Context(BaseModel, arbitrary_types_allowed=True):
|
||||
protocol_version: int
|
||||
timestamp: float = 0.0
|
||||
buffer_tokens: List[torch.Tensor] = []
|
||||
buffer_mel: Optional[torch.Tensor] = None
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
import asyncio
|
||||
import json
|
||||
from logging import getLogger
|
||||
from typing import Optional
|
||||
from typing import Final, Optional
|
||||
|
||||
import numpy as np
|
||||
import websockets
|
||||
|
@ -13,6 +13,9 @@ from whispering.transcriber import Context, WhisperStreamingTranscriber
|
|||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
MIN_PROTOCOL_VERSION: Final[int] = int("000_006_000")
|
||||
MAX_PROTOCOL_VERSION: Final[int] = int("000_006_000")
|
||||
|
||||
|
||||
async def serve_with_websocket_main(websocket):
|
||||
global g_wsp
|
||||
|
@ -41,6 +44,25 @@ async def serve_with_websocket_main(websocket):
|
|||
)
|
||||
)
|
||||
return
|
||||
|
||||
if ctx.protocol_version < MIN_PROTOCOL_VERSION:
|
||||
await websocket.send(
|
||||
json.dumps(
|
||||
{
|
||||
"error": f"protocol_version is older than {MIN_PROTOCOL_VERSION}"
|
||||
}
|
||||
)
|
||||
)
|
||||
elif ctx.protocol_version > MAX_PROTOCOL_VERSION:
|
||||
await websocket.send(
|
||||
json.dumps(
|
||||
{
|
||||
"error": f"protocol_version is newer than {MAX_PROTOCOL_VERSION}"
|
||||
}
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
continue
|
||||
|
||||
logger.debug(f"Message size: {len(message)}")
|
||||
|
|
Loading…
Reference in a new issue