From 64060ee8b4ad457585296f86b609f5c681162688 Mon Sep 17 00:00:00 2001 From: Yuta Hayashibe Date: Fri, 23 Sep 2022 22:39:27 +0900 Subject: [PATCH] Fix dtype (Fix #2) --- whisper_streaming/schema.py | 1 + whisper_streaming/transcriber.py | 17 +++++++++++++++-- 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/whisper_streaming/schema.py b/whisper_streaming/schema.py index 59272d3..3dc1b1c 100644 --- a/whisper_streaming/schema.py +++ b/whisper_streaming/schema.py @@ -10,6 +10,7 @@ class WhisperConfig(BaseModel): device: str language: str + fp16: bool = True temperatures: Tuple[float, ...] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0) compression_ratio_threshold: Optional[float] = 2.4 logprob_threshold: Optional[float] = -1.0 diff --git a/whisper_streaming/transcriber.py b/whisper_streaming/transcriber.py index a952dfd..59afed4 100644 --- a/whisper_streaming/transcriber.py +++ b/whisper_streaming/transcriber.py @@ -23,6 +23,19 @@ logger = getLogger(__name__) class WhisperStreamingTranscriber: + def _set_dtype(self, fp16: bool): + self.fp16 = fp16 + self.dtype = torch.float16 if fp16 else torch.float32 + if self.model.device == torch.device("cpu"): + if torch.cuda.is_available(): + logger.warning("Performing inference on CPU when CUDA is available") + if self.dtype == torch.float16: + logger.warning("FP16 is not supported on CPU; using FP32 instead") + self.dtype = torch.float32 + + if self.dtype == torch.float32: + self.fp16 = False + def __init__(self, *, config: WhisperConfig): self.config: WhisperConfig = config self.model: Whisper = load_model(config.model_name, device=config.device) @@ -31,7 +44,7 @@ class WhisperStreamingTranscriber: language=config.language, task="transcribe", ) - self.dtype = torch.float16 + self._set_dtype(config.fp16) self.timestamp: float = 0.0 self.input_stride = exact_div( N_FRAMES, self.model.dims.n_audio_ctx @@ -67,7 +80,7 @@ class WhisperStreamingTranscriber: suppress_tokens="-1", without_timestamps=False, max_initial_timestamp=0.0, - fp16=True, + fp16=self.fp16, ) def _decode_with_fallback(