This commit is contained in:
Yuta Hayashibe 2022-09-23 22:26:22 +09:00
parent f61f39c577
commit b51d7c6cce

View file

@@ -216,7 +216,9 @@ class WhisperStreamingTranscriber:
seek: int = 0 seek: int = 0
rest_start: Optional[int] = None rest_start: Optional[int] = None
while seek < mel.shape[-1]: while seek < mel.shape[-1]:
logger.debug(f"seek={seek}, timestamp={self.timestamp}") logger.debug(
f"seek={seek}, timestamp={self.timestamp}, rest_start={rest_start}"
)
segment = ( segment = (
pad_or_trim(mel[:, :, seek:], N_FRAMES) pad_or_trim(mel[:, :, seek:], N_FRAMES)
.to(self.model.device) # type: ignore .to(self.model.device) # type: ignore
@@ -233,7 +235,9 @@ class WhisperStreamingTranscriber:
self.config.logprob_threshold is not None self.config.logprob_threshold is not None
and result.avg_logprob > self.config.logprob_threshold and result.avg_logprob > self.config.logprob_threshold
): ):
return seek += segment.shape[-1]
rest_start = None
continue
segment_duration = segment.shape[-1] * HOP_LENGTH / SAMPLE_RATE segment_duration = segment.shape[-1] * HOP_LENGTH / SAMPLE_RATE
last_timestamp_position: Optional[int] = None last_timestamp_position: Optional[int] = None
@@ -251,9 +255,9 @@ class WhisperStreamingTranscriber:
seek += last_timestamp_position seek += last_timestamp_position
rest_start = seek rest_start = seek
logger.debug(f"Last rest_start={rest_start}")
if rest_start is None: if rest_start is None:
return return
logger.debug(f"rest_start={rest_start}")
self.buffer_mel = mel[:, :, rest_start:] self.buffer_mel = mel[:, :, rest_start:]
del mel del mel