mirror of
https://github.com/shirayu/whispering.git
synced 2024-11-25 10:21:00 +00:00
Check model name (openai/whisper@2d3032d)
This commit is contained in:
parent
c18f5e4993
commit
c660a01bfb
1 changed files with 10 additions and 1 deletions
|
@ -3,7 +3,7 @@
|
|||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, root_validator
|
||||
|
||||
|
||||
class WhisperConfig(BaseModel):
|
||||
|
@ -24,6 +24,15 @@ class WhisperConfig(BaseModel):
|
|||
compression_ratio_threshold: Optional[float] = 2.4
|
||||
buffer_threshold: Optional[float] = 0.5
|
||||
|
||||
@root_validator
|
||||
def validate_model_name(cls, values):
|
||||
if values["model_name"].endswith(".en") and values["language"] not in {
|
||||
"en",
|
||||
"English",
|
||||
}:
|
||||
raise ValueError("English only model")
|
||||
return values
|
||||
|
||||
|
||||
class Context(BaseModel, arbitrary_types_allowed=True):
|
||||
timestamp: float = 0.0
|
||||
|
|
Loading…
Reference in a new issue