mirror of
https://github.com/shirayu/whispering.git
synced 2024-06-02 21:39:28 +00:00
Fix dtype (Fix #2)
This commit is contained in:
parent
56508bc282
commit
64060ee8b4
|
@ -10,6 +10,7 @@ class WhisperConfig(BaseModel):
|
||||||
device: str
|
device: str
|
||||||
language: str
|
language: str
|
||||||
|
|
||||||
|
fp16: bool = True
|
||||||
temperatures: Tuple[float, ...] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0)
|
temperatures: Tuple[float, ...] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0)
|
||||||
compression_ratio_threshold: Optional[float] = 2.4
|
compression_ratio_threshold: Optional[float] = 2.4
|
||||||
logprob_threshold: Optional[float] = -1.0
|
logprob_threshold: Optional[float] = -1.0
|
||||||
|
|
|
@ -23,6 +23,19 @@ logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class WhisperStreamingTranscriber:
|
class WhisperStreamingTranscriber:
|
||||||
|
def _set_dtype(self, fp16: bool):
    """Select the inference dtype from the requested *fp16* flag.

    Stores both ``self.fp16`` and ``self.dtype``. When the model lives on
    the CPU, half precision is not supported, so the choice is downgraded
    to FP32 (with a warning); the two attributes are kept consistent so
    that ``self.fp16`` is False whenever ``self.dtype`` is float32.
    """
    self.fp16 = fp16
    self.dtype = torch.float16 if fp16 else torch.float32

    running_on_cpu = self.model.device == torch.device("cpu")
    if running_on_cpu:
        # Warn when a GPU exists but is not being used.
        if torch.cuda.is_available():
            logger.warning("Performing inference on CPU when CUDA is available")
        # CPU inference cannot use half precision; fall back to FP32.
        if self.dtype == torch.float16:
            logger.warning("FP16 is not supported on CPU; using FP32 instead")
            self.dtype = torch.float32

    # Keep the boolean flag in sync with the effective dtype.
    if self.dtype == torch.float32:
        self.fp16 = False
|
||||||
|
|
||||||
def __init__(self, *, config: WhisperConfig):
|
def __init__(self, *, config: WhisperConfig):
|
||||||
self.config: WhisperConfig = config
|
self.config: WhisperConfig = config
|
||||||
self.model: Whisper = load_model(config.model_name, device=config.device)
|
self.model: Whisper = load_model(config.model_name, device=config.device)
|
||||||
|
@ -31,7 +44,7 @@ class WhisperStreamingTranscriber:
|
||||||
language=config.language,
|
language=config.language,
|
||||||
task="transcribe",
|
task="transcribe",
|
||||||
)
|
)
|
||||||
self.dtype = torch.float16
|
self._set_dtype(config.fp16)
|
||||||
self.timestamp: float = 0.0
|
self.timestamp: float = 0.0
|
||||||
self.input_stride = exact_div(
|
self.input_stride = exact_div(
|
||||||
N_FRAMES, self.model.dims.n_audio_ctx
|
N_FRAMES, self.model.dims.n_audio_ctx
|
||||||
|
@ -67,7 +80,7 @@ class WhisperStreamingTranscriber:
|
||||||
suppress_tokens="-1",
|
suppress_tokens="-1",
|
||||||
without_timestamps=False,
|
without_timestamps=False,
|
||||||
max_initial_timestamp=0.0,
|
max_initial_timestamp=0.0,
|
||||||
fp16=True,
|
fp16=self.fp16,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _decode_with_fallback(
|
def _decode_with_fallback(
|
||||||
|
|
Loading…
Reference in a new issue