mirror of
https://github.com/shirayu/whispering.git
synced 2025-02-16 10:35:16 +00:00
Added prompt
This commit is contained in:
parent
dcddd07176
commit
6f49e29838
1 changed files with 14 additions and 3 deletions
|
@ -43,6 +43,7 @@ class WhisperStreamingTranscriber:
|
|||
self,
|
||||
*,
|
||||
t,
|
||||
prompt,
|
||||
beam_size: Optional[int],
|
||||
patience: float,
|
||||
best_of: Optional[int],
|
||||
|
@ -56,7 +57,7 @@ class WhisperStreamingTranscriber:
|
|||
beam_size=beam_size,
|
||||
patience=patience,
|
||||
length_penalty=None,
|
||||
prompt=None,
|
||||
prompt=prompt,
|
||||
prefix=None,
|
||||
suppress_blank=True,
|
||||
suppress_tokens="-1",
|
||||
|
@ -65,12 +66,18 @@ class WhisperStreamingTranscriber:
|
|||
fp16=True,
|
||||
)
|
||||
|
||||
def _decode_with_fallback(self, *, segment: np.ndarray) -> List[DecodingResult]:
|
||||
def _decode_with_fallback(
|
||||
self,
|
||||
*,
|
||||
segment: np.ndarray,
|
||||
prompt,
|
||||
) -> List[DecodingResult]:
|
||||
assert len(self.config.temperatures) >= 1
|
||||
t = self.config.temperatures[0]
|
||||
|
||||
_decode_options1: DecodingOptions = self._get_decoding_options(
|
||||
t=t,
|
||||
prompt=self.buffer_tokens,
|
||||
beam_size=self.config.beam_size,
|
||||
patience=0.0,
|
||||
best_of=None,
|
||||
|
@ -88,6 +95,7 @@ class WhisperStreamingTranscriber:
|
|||
if any(needs_fallback):
|
||||
_decode_options2: DecodingOptions = self._get_decoding_options(
|
||||
t=t,
|
||||
prompt=self.buffer_tokens,
|
||||
beam_size=None,
|
||||
patience=0.0,
|
||||
best_of=self.config.best_of,
|
||||
|
@ -201,7 +209,10 @@ class WhisperStreamingTranscriber:
|
|||
.to(self.model.device) # type:ignore
|
||||
.to(self.dtype) # type:ignore
|
||||
)
|
||||
results = self._decode_with_fallback(segment=segment)
|
||||
results = self._decode_with_fallback(
|
||||
segment=segment,
|
||||
prompt=self.buffer_tokens,
|
||||
)
|
||||
result = results[0]
|
||||
|
||||
if self.config.no_speech_threshold is not None:
|
||||
|
|
Loading…
Reference in a new issue