mirror of
https://github.com/shirayu/whispering.git
synced 2024-11-25 10:21:00 +00:00
Add protocol_version
This commit is contained in:
parent
3d293c868c
commit
dce9719fea
3 changed files with 35 additions and 3 deletions
|
@ -16,7 +16,12 @@ from whisper.audio import N_FRAMES, SAMPLE_RATE
|
||||||
from whisper.tokenizer import LANGUAGES, TO_LANGUAGE_CODE
|
from whisper.tokenizer import LANGUAGES, TO_LANGUAGE_CODE
|
||||||
|
|
||||||
from whispering.pbar import ProgressBar
|
from whispering.pbar import ProgressBar
|
||||||
from whispering.schema import Context, StdoutWriter, WhisperConfig
|
from whispering.schema import (
|
||||||
|
CURRENT_PROTOCOL_VERSION,
|
||||||
|
Context,
|
||||||
|
StdoutWriter,
|
||||||
|
WhisperConfig,
|
||||||
|
)
|
||||||
from whispering.serve import serve_with_websocket
|
from whispering.serve import serve_with_websocket
|
||||||
from whispering.transcriber import WhisperStreamingTranscriber
|
from whispering.transcriber import WhisperStreamingTranscriber
|
||||||
from whispering.websocket_client import run_websocket_client
|
from whispering.websocket_client import run_websocket_client
|
||||||
|
@ -214,6 +219,7 @@ def get_wshiper(*, opts) -> WhisperStreamingTranscriber:
|
||||||
|
|
||||||
def get_context(*, opts) -> Context:
|
def get_context(*, opts) -> Context:
|
||||||
ctx = Context(
|
ctx = Context(
|
||||||
|
protocol_version=CURRENT_PROTOCOL_VERSION,
|
||||||
beam_size=opts.beam_size,
|
beam_size=opts.beam_size,
|
||||||
temperatures=opts.temperature,
|
temperatures=opts.temperature,
|
||||||
allow_padding=opts.allow_padding,
|
allow_padding=opts.allow_padding,
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
import sys
|
import sys
|
||||||
from typing import List, Optional
|
from typing import Final, List, Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
@ -24,7 +24,11 @@ class WhisperConfig(BaseModel):
|
||||||
return values
|
return values
|
||||||
|
|
||||||
|
|
||||||
|
CURRENT_PROTOCOL_VERSION: Final[int] = int("000_006_000")
|
||||||
|
|
||||||
|
|
||||||
class Context(BaseModel, arbitrary_types_allowed=True):
|
class Context(BaseModel, arbitrary_types_allowed=True):
|
||||||
|
protocol_version: int
|
||||||
timestamp: float = 0.0
|
timestamp: float = 0.0
|
||||||
buffer_tokens: List[torch.Tensor] = []
|
buffer_tokens: List[torch.Tensor] = []
|
||||||
buffer_mel: Optional[torch.Tensor] = None
|
buffer_mel: Optional[torch.Tensor] = None
|
||||||
|
|
|
@ -3,7 +3,7 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from typing import Optional
|
from typing import Final, Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import websockets
|
import websockets
|
||||||
|
@ -13,6 +13,9 @@ from whispering.transcriber import Context, WhisperStreamingTranscriber
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
MIN_PROTOCOL_VERSION: Final[int] = int("000_006_000")
|
||||||
|
MAX_PROTOCOL_VERSION: Final[int] = int("000_006_000")
|
||||||
|
|
||||||
|
|
||||||
async def serve_with_websocket_main(websocket):
|
async def serve_with_websocket_main(websocket):
|
||||||
global g_wsp
|
global g_wsp
|
||||||
|
@ -41,6 +44,25 @@ async def serve_with_websocket_main(websocket):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
if ctx.protocol_version < MIN_PROTOCOL_VERSION:
|
||||||
|
await websocket.send(
|
||||||
|
json.dumps(
|
||||||
|
{
|
||||||
|
"error": f"protocol_version is older than {MIN_PROTOCOL_VERSION}"
|
||||||
|
}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
elif ctx.protocol_version > MAX_PROTOCOL_VERSION:
|
||||||
|
await websocket.send(
|
||||||
|
json.dumps(
|
||||||
|
{
|
||||||
|
"error": f"protocol_version is newer than {MAX_PROTOCOL_VERSION}"
|
||||||
|
}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
logger.debug(f"Message size: {len(message)}")
|
logger.debug(f"Message size: {len(message)}")
|
||||||
|
|
Loading…
Reference in a new issue