Fix dtype (Fix #2)

This commit is contained in:
Yuta Hayashibe 2022-09-23 22:39:27 +09:00
parent 56508bc282
commit 64060ee8b4
2 changed files with 16 additions and 2 deletions

View file

@ -10,6 +10,7 @@ class WhisperConfig(BaseModel):
device: str
language: str
fp16: bool = True
temperatures: Tuple[float, ...] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0)
compression_ratio_threshold: Optional[float] = 2.4
logprob_threshold: Optional[float] = -1.0

View file

@ -23,6 +23,19 @@ logger = getLogger(__name__)
class WhisperStreamingTranscriber:
def _set_dtype(self, fp16: bool):
self.fp16 = fp16
self.dtype = torch.float16 if fp16 else torch.float32
if self.model.device == torch.device("cpu"):
if torch.cuda.is_available():
logger.warning("Performing inference on CPU when CUDA is available")
if self.dtype == torch.float16:
logger.warning("FP16 is not supported on CPU; using FP32 instead")
self.dtype = torch.float32
if self.dtype == torch.float32:
self.fp16 = False
def __init__(self, *, config: WhisperConfig):
self.config: WhisperConfig = config
self.model: Whisper = load_model(config.model_name, device=config.device)
@ -31,7 +44,7 @@ class WhisperStreamingTranscriber:
language=config.language,
task="transcribe",
)
self.dtype = torch.float16
self._set_dtype(config.fp16)
self.timestamp: float = 0.0
self.input_stride = exact_div(
N_FRAMES, self.model.dims.n_audio_ctx
@ -67,7 +80,7 @@ class WhisperStreamingTranscriber:
suppress_tokens="-1",
without_timestamps=False,
max_initial_timestamp=0.0,
fp16=True,
fp16=self.fp16,
)
def _decode_with_fallback(