mirror of
https://github.com/shirayu/whispering.git
synced 2024-06-12 10:19:31 +00:00
Fix -n (Resolve #3)
This commit is contained in:
parent
f75b13e0ba
commit
501b37f4a7
|
@ -21,13 +21,12 @@ poetry install --only main
|
||||||
poetry run pip install -U torch torchaudio --extra-index-url https://download.pytorch.org/whl/cu116
|
poetry run pip install -U torch torchaudio --extra-index-url https://download.pytorch.org/whl/cu116
|
||||||
|
|
||||||
# Run in English
|
# Run in English
|
||||||
poetry run whisper_streaming --language en --model base -n 20
|
poetry run whisper_streaming --language en --model base
|
||||||
```
|
```
|
||||||
|
|
||||||
- ``--help`` shows full options
|
- ``--help`` shows full options
|
||||||
- ``--language`` sets the language to transcribe. The list of languages are shown with ``poetry run whisper_streaming -h``
|
- ``--language`` sets the language to transcribe. The list of languages are shown with ``poetry run whisper_streaming -h``
|
||||||
- ``-t`` sets temperatures to decode. You can set several like (``-t 0.0 -t 0.1 -t 0.5``), but too many temperatures exhaust decoding time
|
- ``-t`` sets temperatures to decode. You can set several like (``-t 0.0 -t 0.1 -t 0.5``), but too many temperatures exhaust decoding time
|
||||||
- ``-n`` sets interval of parsing. Larger values can improve accuracy but consume more memory.
|
|
||||||
- ``--debug`` outputs logs for debug
|
- ``--debug`` outputs logs for debug
|
||||||
|
|
||||||
## Tips
|
## Tips
|
||||||
|
|
|
@ -81,8 +81,8 @@ def get_opts() -> argparse.Namespace:
|
||||||
"--num_block",
|
"--num_block",
|
||||||
"-n",
|
"-n",
|
||||||
type=int,
|
type=int,
|
||||||
default=20,
|
default=160,
|
||||||
help="Number of operation unit. Larger values can improve accuracy but consume more memory.",
|
help="Number of operation unit",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--temperature",
|
"--temperature",
|
||||||
|
|
|
@ -239,16 +239,17 @@ class WhisperStreamingTranscriber:
|
||||||
self.buffer_mel = None
|
self.buffer_mel = None
|
||||||
|
|
||||||
seek: int = 0
|
seek: int = 0
|
||||||
rest_start: Optional[int] = None
|
|
||||||
while seek < mel.shape[-1]:
|
while seek < mel.shape[-1]:
|
||||||
segment = (
|
segment = (
|
||||||
pad_or_trim(mel[:, :, seek:], N_FRAMES)
|
pad_or_trim(mel[:, :, seek:], N_FRAMES)
|
||||||
.to(self.model.device) # type: ignore
|
.to(self.model.device) # type: ignore
|
||||||
.to(self.dtype)
|
.to(self.dtype)
|
||||||
)
|
)
|
||||||
|
if segment.shape[-1] > mel.shape[-1]:
|
||||||
|
logger.warning("Padding is not expected while speaking")
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"seek={seek}, timestamp={self.timestamp}, rest_start={rest_start},"
|
f"seek={seek}, timestamp={self.timestamp}"
|
||||||
f"mel.shape: {mel.shape}, segment.shape: {segment.shape}"
|
f"mel.shape: {mel.shape}, segment.shape: {segment.shape}"
|
||||||
)
|
)
|
||||||
results = self._decode_with_fallback(
|
results = self._decode_with_fallback(
|
||||||
|
@ -266,7 +267,6 @@ class WhisperStreamingTranscriber:
|
||||||
and result.avg_logprob > self.config.logprob_threshold
|
and result.avg_logprob > self.config.logprob_threshold
|
||||||
):
|
):
|
||||||
seek += segment.shape[-1]
|
seek += segment.shape[-1]
|
||||||
rest_start = None
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"Skip: {segment.shape[-1]}, new seek={seek}, mel.shape: {mel.shape}"
|
f"Skip: {segment.shape[-1]}, new seek={seek}, mel.shape: {mel.shape}"
|
||||||
)
|
)
|
||||||
|
@ -283,15 +283,14 @@ class WhisperStreamingTranscriber:
|
||||||
yield v
|
yield v
|
||||||
if last_timestamp_position is None:
|
if last_timestamp_position is None:
|
||||||
seek += segment.shape[-1]
|
seek += segment.shape[-1]
|
||||||
rest_start = None
|
|
||||||
else:
|
else:
|
||||||
seek += last_timestamp_position * self.input_stride
|
seek += last_timestamp_position * self.input_stride
|
||||||
rest_start = seek
|
|
||||||
logger.debug(f"new seek={seek}, mel.shape: {mel.shape}")
|
logger.debug(f"new seek={seek}, mel.shape: {mel.shape}")
|
||||||
|
|
||||||
logger.debug(f"Last rest_start={rest_start}, mel.shape: {mel.shape}")
|
if mel.shape[-1] < N_FRAMES:
|
||||||
if rest_start is None:
|
break
|
||||||
return
|
|
||||||
|
|
||||||
self.buffer_mel = mel[:, :, rest_start:]
|
if mel.shape[-1] - seek < 0:
|
||||||
|
return
|
||||||
|
self.buffer_mel = mel[:, :, seek:]
|
||||||
del mel
|
del mel
|
||||||
|
|
Loading…
Reference in a new issue