mirror of
https://github.com/shirayu/whispering.git
synced 2025-02-16 10:35:16 +00:00
Fix -n (Resolve #3)
This commit is contained in:
parent
f75b13e0ba
commit
501b37f4a7
3 changed files with 11 additions and 13 deletions
|
@ -21,13 +21,12 @@ poetry install --only main
|
|||
poetry run pip install -U torch torchaudio --extra-index-url https://download.pytorch.org/whl/cu116
|
||||
|
||||
# Run in English
|
||||
poetry run whisper_streaming --language en --model base -n 20
|
||||
poetry run whisper_streaming --language en --model base
|
||||
```
|
||||
|
||||
- ``--help`` shows full options
|
||||
- ``--language`` sets the language to transcribe. The list of languages is shown with ``poetry run whisper_streaming -h``
|
||||
- ``-t`` sets the decoding temperatures. You can set several (e.g. ``-t 0.0 -t 0.1 -t 0.5``), but too many temperatures increase decoding time
|
||||
- ``-n`` sets the interval of parsing. Larger values can improve accuracy but consume more memory.
|
||||
- ``--debug`` outputs debug logs
|
||||
|
||||
## Tips
|
||||
|
|
|
@ -81,8 +81,8 @@ def get_opts() -> argparse.Namespace:
|
|||
"--num_block",
|
||||
"-n",
|
||||
type=int,
|
||||
default=20,
|
||||
help="Number of operation unit. Larger values can improve accuracy but consume more memory.",
|
||||
default=160,
|
||||
help="Number of operation unit",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--temperature",
|
||||
|
|
|
@ -239,16 +239,17 @@ class WhisperStreamingTranscriber:
|
|||
self.buffer_mel = None
|
||||
|
||||
seek: int = 0
|
||||
rest_start: Optional[int] = None
|
||||
while seek < mel.shape[-1]:
|
||||
segment = (
|
||||
pad_or_trim(mel[:, :, seek:], N_FRAMES)
|
||||
.to(self.model.device) # type: ignore
|
||||
.to(self.dtype)
|
||||
)
|
||||
if segment.shape[-1] > mel.shape[-1]:
|
||||
logger.warning("Padding is not expected while speaking")
|
||||
|
||||
logger.debug(
|
||||
f"seek={seek}, timestamp={self.timestamp}, rest_start={rest_start},"
|
||||
f"seek={seek}, timestamp={self.timestamp}"
|
||||
f"mel.shape: {mel.shape}, segment.shape: {segment.shape}"
|
||||
)
|
||||
results = self._decode_with_fallback(
|
||||
|
@ -266,7 +267,6 @@ class WhisperStreamingTranscriber:
|
|||
and result.avg_logprob > self.config.logprob_threshold
|
||||
):
|
||||
seek += segment.shape[-1]
|
||||
rest_start = None
|
||||
logger.debug(
|
||||
f"Skip: {segment.shape[-1]}, new seek={seek}, mel.shape: {mel.shape}"
|
||||
)
|
||||
|
@ -283,15 +283,14 @@ class WhisperStreamingTranscriber:
|
|||
yield v
|
||||
if last_timestamp_position is None:
|
||||
seek += segment.shape[-1]
|
||||
rest_start = None
|
||||
else:
|
||||
seek += last_timestamp_position * self.input_stride
|
||||
rest_start = seek
|
||||
logger.debug(f"new seek={seek}, mel.shape: {mel.shape}")
|
||||
|
||||
logger.debug(f"Last rest_start={rest_start}, mel.shape: {mel.shape}")
|
||||
if rest_start is None:
|
||||
return
|
||||
if mel.shape[-1] < N_FRAMES:
|
||||
break
|
||||
|
||||
self.buffer_mel = mel[:, :, rest_start:]
|
||||
if mel.shape[-1] - seek < 0:
|
||||
return
|
||||
self.buffer_mel = mel[:, :, seek:]
|
||||
del mel
|
||||
|
|
Loading…
Reference in a new issue