This commit is contained in:
Yuta Hayashibe 2022-09-29 20:26:03 +09:00
parent c18f5e4993
commit c660a01bfb

View file

@ -3,7 +3,7 @@
from typing import List, Optional from typing import List, Optional
import torch import torch
from pydantic import BaseModel from pydantic import BaseModel, root_validator
class WhisperConfig(BaseModel): class WhisperConfig(BaseModel):
@ -24,6 +24,15 @@ class WhisperConfig(BaseModel):
compression_ratio_threshold: Optional[float] = 2.4 compression_ratio_threshold: Optional[float] = 2.4
buffer_threshold: Optional[float] = 0.5 buffer_threshold: Optional[float] = 0.5
@root_validator
def validate_model_name(cls, values):
if values["model_name"].endswith(".en") and values["language"] not in {
"en",
"English",
}:
raise ValueError("English only model")
return values
class Context(BaseModel, arbitrary_types_allowed=True): class Context(BaseModel, arbitrary_types_allowed=True):
timestamp: float = 0.0 timestamp: float = 0.0