Added prompt

This commit is contained in:
Yuta Hayashibe 2022-09-23 21:25:05 +09:00
parent dcddd07176
commit 6f49e29838

View file

@ -43,6 +43,7 @@ class WhisperStreamingTranscriber:
self,
*,
t,
prompt,
beam_size: Optional[int],
patience: float,
best_of: Optional[int],
@ -56,7 +57,7 @@ class WhisperStreamingTranscriber:
beam_size=beam_size,
patience=patience,
length_penalty=None,
prompt=None,
prompt=prompt,
prefix=None,
suppress_blank=True,
suppress_tokens="-1",
@ -65,12 +66,18 @@ class WhisperStreamingTranscriber:
fp16=True,
)
def _decode_with_fallback(self, *, segment: np.ndarray) -> List[DecodingResult]:
def _decode_with_fallback(
self,
*,
segment: np.ndarray,
prompt,
) -> List[DecodingResult]:
assert len(self.config.temperatures) >= 1
t = self.config.temperatures[0]
_decode_options1: DecodingOptions = self._get_decoding_options(
t=t,
prompt=self.buffer_tokens,
beam_size=self.config.beam_size,
patience=0.0,
best_of=None,
@ -88,6 +95,7 @@ class WhisperStreamingTranscriber:
if any(needs_fallback):
_decode_options2: DecodingOptions = self._get_decoding_options(
t=t,
prompt=self.buffer_tokens,
beam_size=None,
patience=0.0,
best_of=self.config.best_of,
@ -201,7 +209,10 @@ class WhisperStreamingTranscriber:
.to(self.model.device) # type:ignore
.to(self.dtype) # type:ignore
)
results = self._decode_with_fallback(segment=segment)
results = self._decode_with_fallback(
segment=segment,
prompt=self.buffer_tokens,
)
result = results[0]
if self.config.no_speech_threshold is not None: