whispering/whispering/schema.py

88 lines
2.1 KiB
Python
Raw Normal View History

2022-09-23 10:20:11 +00:00
#!/usr/bin/env python3
2022-10-03 13:38:35 +00:00
import sys
2022-10-15 04:23:00 +00:00
from typing import Final, List, Optional
2022-09-23 10:20:11 +00:00
2022-10-02 10:47:17 +00:00
import numpy as np
2022-09-29 11:14:56 +00:00
import torch
2022-11-08 14:42:11 +00:00
from pydantic import BaseModel, Field, root_validator
from whisper.audio import N_FRAMES
2022-09-23 10:20:11 +00:00
class WhisperConfig(BaseModel):
    """Settings naming the Whisper model, device, language and fp16 flag.

    A ``model_name`` ending in ``.en`` is treated as an English-only model,
    in which case ``language`` must be ``"en"`` or ``"English"``.
    """

    model_name: str
    device: str
    language: str
    fp16: bool = True

    # skip_on_failure=True: only cross-check the fields once each of them has
    # validated on its own. Without it pydantic (v1) still runs this root
    # validator after a field error; ``values`` then lacks the failed key and
    # the lookup raises a bare KeyError, which pydantic does not convert into
    # a ValidationError. (pydantic v2 requires this flag for post validators.)
    @root_validator(skip_on_failure=True)
    def validate_model_name(cls, values):
        """Reject a non-English ``language`` for an English-only model.

        Raises:
            ValueError: if ``model_name`` ends with ``.en`` and ``language``
                is neither ``"en"`` nor ``"English"`` (pydantic reports it as
                a validation error).
        """
        if values["model_name"].endswith(".en") and values["language"] not in {
            "en",
            "English",
        }:
            raise ValueError("English only model")
        return values
2022-09-23 11:03:00 +00:00
2022-11-08 14:42:11 +00:00
# Protocol version packed as three zero-padded three-digit groups; the
# underscores are only for readability and int() accepts them, so
# "000_006_003" evaluates to 6003 (presumably major/minor/patch = 0.6.3).
CURRENT_PROTOCOL_VERSION: Final[int] = int("000_006_003", base=10)
2022-10-15 04:23:00 +00:00
2022-09-29 11:14:56 +00:00
class Context(BaseModel, arbitrary_types_allowed=True):
    """Combines what appear to be running-state fields with decoding options.

    The state-like fields are the timestamp, the token/mel buffers and the
    skip counter; the rest are decoding and threshold settings.
    ``arbitrary_types_allowed`` is required because some fields hold raw
    ``torch.Tensor`` values. Fields without a default (``protocol_version``,
    ``temperatures``, ``vad_threshold``, ``max_nospeech_skip``) must be
    supplied by the caller.
    """

    # NOTE(review): presumably checked against CURRENT_PROTOCOL_VERSION by
    # the caller; this model does not compare them itself.
    protocol_version: int

    timestamp: float = 0.0
    # pydantic copies field defaults per instance, so this [] is not shared.
    buffer_tokens: List[torch.Tensor] = []
    buffer_mel: Optional[torch.Tensor] = None
    # NOTE: "nosoeech" is a misspelling of "nospeech". The name is kept as-is
    # because renaming the field would break any code (or, presumably,
    # serialized data) that uses the current name.
    nosoeech_skip_count: Optional[int] = None

    temperatures: List[float]
    patience: Optional[float] = None
    # Each option is declared exactly once; the file previously redeclared
    # compression_ratio_threshold and logprob_threshold further down, so the
    # later copy silently overrode this one.
    compression_ratio_threshold: Optional[float] = 2.4
    logprob_threshold: Optional[float] = -1.0
    no_captions_threshold: Optional[float] = 0.6
    best_of: int = 5
    beam_size: Optional[int] = None
    no_speech_threshold: Optional[float] = 0.6
    buffer_threshold: Optional[float] = 0.5
    vad_threshold: float
    max_nospeech_skip: int
    # Validated to lie in 1..N_FRAMES (N_FRAMES comes from whisper.audio).
    mel_frame_min_num: int = Field(N_FRAMES, ge=1, le=N_FRAMES)

    data_type: str = "float32"
2022-10-17 08:53:59 +00:00
2022-09-29 11:14:56 +00:00
2022-09-23 11:03:00 +00:00
class ParsedChunk(BaseModel):
    """One transcribed chunk: its time span, text, tokens and decoding stats.

    The statistic fields (temperature, avg_logprob, compression_ratio,
    no_speech_prob) appear to mirror Whisper's per-segment decoding
    results — NOTE(review): confirm against the producer.
    """

    # NOTE(review): units not shown here; presumably seconds — confirm.
    start: float
    end: float
    text: str
    # Token ids of the decoded text.
    tokens: List[int]
    temperature: float
    avg_logprob: float
    compression_ratio: float
    no_speech_prob: float
2022-10-01 14:21:58 +00:00
class SpeechSegment(BaseModel, arbitrary_types_allowed=True):
    """A range of audio blocks, identified by index, together with its audio.

    ``arbitrary_types_allowed`` is needed because ``audio`` is a raw
    ``np.ndarray``, which pydantic does not validate natively.
    """

    # NOTE(review): whether end_block_idx is inclusive or exclusive is not
    # visible here — confirm with the code that builds segments.
    start_block_idx: int
    end_block_idx: int
    # NOTE(review): dtype/shape are not enforced by this model; presumably a
    # 1-D array of samples — confirm.
    audio: np.ndarray
2022-10-03 13:38:35 +00:00
class StdoutWriter:
    """Minimal file-like object that forwards writes to ``sys.stdout``.

    Implements only ``open()``, context-manager use, ``write()`` and
    ``flush()``, presumably so it can stand in where an opened output file
    would otherwise be used. Extra positional/keyword arguments are accepted
    and ignored. Leaving the ``with`` block does nothing: stdout is never
    closed.
    """

    def open(self, *args, **kwargs):
        """Return ``self``; any arguments (e.g. a mode) are ignored."""
        return self

    def __enter__(self, *args, **kwargs):
        return self

    def __exit__(self, *args, **kwargs):
        # Nothing to release; returning None lets exceptions propagate.
        return None

    def flush(self, *args, **kwargs):
        """Flush the current ``sys.stdout``."""
        sys.stdout.flush()

    def write(self, text, *args, **kwargs):
        """Write ``text`` to ``sys.stdout``.

        Returns:
            The number of characters written, as reported by
            ``sys.stdout.write`` (the ``io.TextIOBase.write`` contract).
        """
        # sys.stdout is looked up on every call, so redirection such as
        # contextlib.redirect_stdout is honoured.
        return sys.stdout.write(text)