Fix -n (Resolve #3)

This commit is contained in:
Yuta Hayashibe 2022-09-24 15:39:41 +09:00
parent f75b13e0ba
commit 501b37f4a7
3 changed files with 11 additions and 13 deletions

View file

@ -21,13 +21,12 @@ poetry install --only main
poetry run pip install -U torch torchaudio --extra-index-url https://download.pytorch.org/whl/cu116
# Run in English
poetry run whisper_streaming --language en --model base -n 20
poetry run whisper_streaming --language en --model base
```
- ``--help`` shows full options
- ``--language`` sets the language to transcribe. The list of languages is shown with ``poetry run whisper_streaming -h``
- ``-t`` sets the temperatures used for decoding. You can specify several (e.g. ``-t 0.0 -t 0.1 -t 0.5``), but too many temperatures increase decoding time
- ``-n`` sets interval of parsing. Larger values can improve accuracy but consume more memory.
- ``--debug`` outputs logs for debug
## Tips

View file

@ -81,8 +81,8 @@ def get_opts() -> argparse.Namespace:
"--num_block",
"-n",
type=int,
default=20,
help="Number of operation unit. Larger values can improve accuracy but consume more memory.",
default=160,
help="Number of operation unit",
)
parser.add_argument(
"--temperature",

View file

@ -239,16 +239,17 @@ class WhisperStreamingTranscriber:
self.buffer_mel = None
seek: int = 0
rest_start: Optional[int] = None
while seek < mel.shape[-1]:
segment = (
pad_or_trim(mel[:, :, seek:], N_FRAMES)
.to(self.model.device) # type: ignore
.to(self.dtype)
)
if segment.shape[-1] > mel.shape[-1]:
logger.warning("Padding is not expected while speaking")
logger.debug(
f"seek={seek}, timestamp={self.timestamp}, rest_start={rest_start},"
f"seek={seek}, timestamp={self.timestamp}"
f"mel.shape: {mel.shape}, segment.shape: {segment.shape}"
)
results = self._decode_with_fallback(
@ -266,7 +267,6 @@ class WhisperStreamingTranscriber:
and result.avg_logprob > self.config.logprob_threshold
):
seek += segment.shape[-1]
rest_start = None
logger.debug(
f"Skip: {segment.shape[-1]}, new seek={seek}, mel.shape: {mel.shape}"
)
@ -283,15 +283,14 @@ class WhisperStreamingTranscriber:
yield v
if last_timestamp_position is None:
seek += segment.shape[-1]
rest_start = None
else:
seek += last_timestamp_position * self.input_stride
rest_start = seek
logger.debug(f"new seek={seek}, mel.shape: {mel.shape}")
logger.debug(f"Last rest_start={rest_start}, mel.shape: {mel.shape}")
if rest_start is None:
return
if mel.shape[-1] < N_FRAMES:
break
self.buffer_mel = mel[:, :, rest_start:]
if mel.shape[-1] - seek < 0:
return
self.buffer_mel = mel[:, :, seek:]
del mel