From 6f49e2983823350ef4557da58057634de5318022 Mon Sep 17 00:00:00 2001
From: Yuta Hayashibe
Date: Fri, 23 Sep 2022 21:25:05 +0900
Subject: [PATCH] Added prompt

---
 whisper_streaming/transcriber.py | 17 ++++++++++++++---
 1 file changed, 14 insertions(+), 3 deletions(-)

diff --git a/whisper_streaming/transcriber.py b/whisper_streaming/transcriber.py
index aad86f4..5fcf6f6 100644
--- a/whisper_streaming/transcriber.py
+++ b/whisper_streaming/transcriber.py
@@ -43,6 +43,7 @@ class WhisperStreamingTranscriber:
         self,
         *,
         t,
+        prompt,
         beam_size: Optional[int],
         patience: float,
         best_of: Optional[int],
@@ -56,7 +57,7 @@ class WhisperStreamingTranscriber:
             beam_size=beam_size,
             patience=patience,
             length_penalty=None,
-            prompt=None,
+            prompt=prompt,
             prefix=None,
             suppress_blank=True,
             suppress_tokens="-1",
@@ -65,12 +66,18 @@ class WhisperStreamingTranscriber:
             fp16=True,
         )
 
-    def _decode_with_fallback(self, *, segment: np.ndarray) -> List[DecodingResult]:
+    def _decode_with_fallback(
+        self,
+        *,
+        segment: np.ndarray,
+        prompt,
+    ) -> List[DecodingResult]:
         assert len(self.config.temperatures) >= 1
         t = self.config.temperatures[0]
 
         _decode_options1: DecodingOptions = self._get_decoding_options(
             t=t,
+            prompt=self.buffer_tokens,
             beam_size=self.config.beam_size,
             patience=0.0,
             best_of=None,
@@ -88,6 +95,7 @@ class WhisperStreamingTranscriber:
         if any(needs_fallback):
             _decode_options2: DecodingOptions = self._get_decoding_options(
                 t=t,
+                prompt=self.buffer_tokens,
                 beam_size=None,
                 patience=0.0,
                 best_of=self.config.best_of,
@@ -201,7 +209,10 @@ class WhisperStreamingTranscriber:
             .to(self.model.device)  # type:ignore
             .to(self.dtype)  # type:ignore
         )
-        results = self._decode_with_fallback(segment=segment)
+        results = self._decode_with_fallback(
+            segment=segment,
+            prompt=self.buffer_tokens,
+        )
         result = results[0]
 
         if self.config.no_speech_threshold is not None:
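
Note (not part of the patch): a minimal, self-contained sketch of the idea the patch wires in, using the openai-whisper API directly. Tokens decoded from earlier segments are passed back as DecodingOptions.prompt so the decoder is conditioned on prior context. The fixed 30-second slicing loop and the "speech.wav" path are illustrative placeholders rather than the project's streaming logic, and buffer_tokens here stands in for self.buffer_tokens in transcriber.py.

import whisper

# Hypothetical input path; replace with a real audio file.
AUDIO_PATH = "speech.wav"

model = whisper.load_model("base")
audio = whisper.load_audio(AUDIO_PATH)

chunk_len = whisper.audio.SAMPLE_RATE * 30  # Whisper decodes 30-second windows
buffer_tokens: list = []  # plays the role of self.buffer_tokens in transcriber.py

for start in range(0, len(audio), chunk_len):
    chunk = whisper.pad_or_trim(audio[start : start + chunk_len])
    mel = whisper.log_mel_spectrogram(chunk).to(model.device)
    options = whisper.DecodingOptions(
        prompt=buffer_tokens,  # tokens decoded so far condition this segment
        fp16=False,
    )
    result = whisper.decode(model, mel, options)
    buffer_tokens.extend(result.tokens)  # carry context into the next chunk
    print(result.text)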