mirror of
https://github.com/shirayu/whispering.git
synced 2025-02-22 05:16:18 +00:00
Fix dtype (Fix #2)
This commit is contained in:
parent
56508bc282
commit
64060ee8b4
2 changed files with 16 additions and 2 deletions
|
@ -10,6 +10,7 @@ class WhisperConfig(BaseModel):
|
|||
device: str
|
||||
language: str
|
||||
|
||||
fp16: bool = True
|
||||
temperatures: Tuple[float, ...] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0)
|
||||
compression_ratio_threshold: Optional[float] = 2.4
|
||||
logprob_threshold: Optional[float] = -1.0
|
||||
|
|
|
@ -23,6 +23,19 @@ logger = getLogger(__name__)
|
|||
|
||||
|
||||
class WhisperStreamingTranscriber:
|
||||
def _set_dtype(self, fp16: bool):
    """Pick the inference dtype and keep ``self.fp16`` in step with it.

    ``fp16`` requests half precision. When the model's device is the CPU
    the request is overridden (with a warning) and float32 is used, so
    after this call ``self.fp16`` is only truthy when ``self.dtype`` is
    ``torch.float16``.
    """
    on_cpu = self.model.device == torch.device("cpu")

    # Running on CPU while a CUDA device exists is allowed, but worth flagging.
    if on_cpu and torch.cuda.is_available():
        logger.warning("Performing inference on CPU when CUDA is available")

    use_half = bool(fp16)
    if use_half and on_cpu:
        logger.warning("FP16 is not supported on CPU; using FP32 instead")
        use_half = False

    if use_half:
        self.fp16 = fp16
        self.dtype = torch.float16
    else:
        # float32 in use: the flag must not claim half precision.
        self.fp16 = False
        self.dtype = torch.float32
|
||||
|
||||
def __init__(self, *, config: WhisperConfig):
|
||||
self.config: WhisperConfig = config
|
||||
self.model: Whisper = load_model(config.model_name, device=config.device)
|
||||
|
@ -31,7 +44,7 @@ class WhisperStreamingTranscriber:
|
|||
language=config.language,
|
||||
task="transcribe",
|
||||
)
|
||||
self.dtype = torch.float16
|
||||
self._set_dtype(config.fp16)
|
||||
self.timestamp: float = 0.0
|
||||
self.input_stride = exact_div(
|
||||
N_FRAMES, self.model.dims.n_audio_ctx
|
||||
|
@ -67,7 +80,7 @@ class WhisperStreamingTranscriber:
|
|||
suppress_tokens="-1",
|
||||
without_timestamps=False,
|
||||
max_initial_timestamp=0.0,
|
||||
fp16=True,
|
||||
fp16=self.fp16,
|
||||
)
|
||||
|
||||
def _decode_with_fallback(
|
||||
|
|
Loading…
Reference in a new issue