mirror of
https://github.com/shirayu/whispering.git
synced 2024-11-25 10:21:00 +00:00
Check model name (openai/whisper@2d3032d)
This commit is contained in:
parent
c18f5e4993
commit
c660a01bfb
1 changed files with 10 additions and 1 deletions
|
@ -3,7 +3,7 @@
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel, root_validator
|
||||||
|
|
||||||
|
|
||||||
class WhisperConfig(BaseModel):
|
class WhisperConfig(BaseModel):
|
||||||
|
@ -24,6 +24,15 @@ class WhisperConfig(BaseModel):
|
||||||
compression_ratio_threshold: Optional[float] = 2.4
|
compression_ratio_threshold: Optional[float] = 2.4
|
||||||
buffer_threshold: Optional[float] = 0.5
|
buffer_threshold: Optional[float] = 0.5
|
||||||
|
|
||||||
|
@root_validator
|
||||||
|
def validate_model_name(cls, values):
|
||||||
|
if values["model_name"].endswith(".en") and values["language"] not in {
|
||||||
|
"en",
|
||||||
|
"English",
|
||||||
|
}:
|
||||||
|
raise ValueError("English only model")
|
||||||
|
return values
|
||||||
|
|
||||||
|
|
||||||
class Context(BaseModel, arbitrary_types_allowed=True):
|
class Context(BaseModel, arbitrary_types_allowed=True):
|
||||||
timestamp: float = 0.0
|
timestamp: float = 0.0
|
||||||
|
|
Loading…
Reference in a new issue