Removed --allow-padding and add --max_nospeech_skip option (Resolve #13)

This commit is contained in:
Yuta Hayashibe 2022-10-15 14:48:08 +09:00
parent 20b8970aa9
commit 75147cae86
6 changed files with 41 additions and 24 deletions

View file

@ -41,7 +41,7 @@ whispering --language en --model tiny
- ``--no-progress`` disables the progress message - ``--no-progress`` disables the progress message
- ``-t`` sets temperatures to decode. You can set several like ``-t 0.0 -t 0.1 -t 0.5``, but too many temperatures exhaust decoding time - ``-t`` sets temperatures to decode. You can set several like ``-t 0.0 -t 0.1 -t 0.5``, but too many temperatures exhaust decoding time
- ``--debug`` outputs logs for debug - ``--debug`` outputs logs for debug
- ``--vad`` sets VAD (Voice Activity Detection) threshold. 0 disables VAD and forces whisper to analyze non-voice activity sound period - ``--vad`` sets VAD (Voice Activity Detection) threshold. The default is ``0.5``. 0 disables VAD and forces whisper to analyze non-voice activity sound period
- ``--output`` sets output file (Default: Standard output) - ``--output`` sets output file (Default: Standard output)
### Parse interval ### Parse interval
@ -49,19 +49,10 @@ whispering --language en --model tiny
By default, whispering performs VAD for every 3.75 second. By default, whispering performs VAD for every 3.75 second.
This interval is determined by the value of ``-n`` and its default is ``20``. This interval is determined by the value of ``-n`` and its default is ``20``.
When an interval is predicted as "silence", it will not be passed to whisper. When an interval is predicted as "silence", it will not be passed to whisper.
If you want to disable VAD, please use ``--no-vad`` option. If you want to disable VAD, please make VAD threshold 0 by adding ``--vad 0``.
By default, Whisper does not perform analysis until the total length of the segments determined by VAD to have speech exceeds 30 seconds. By default, Whisper does not perform analysis until the total length of the segments determined by VAD to have speech exceeds 30 seconds.
This is because Whisper is trained to make predictions for 30-second intervals. However, if silence segments appear 16 times (the default value of ``--max_nospeech_skip``) after speech is detected, the analysis is performed.
Nevertheless, if you want to force Whisper to perform analysis even if a segment is less than 30 seconds, please use ``--allow-padding`` option like this.
```bash
whispering --language en --model tiny -n 20 --allow-padding
```
This forces Whisper to analyze every 3.75 seconds speech segment.
Using ``--allow-padding`` may sacrifice the accuracy, while you can get quick response.
The smaller value of ``-n`` with ``--allow-padding`` is, the worse the accuracy becomes.
## Example of web socket ## Example of web socket
@ -81,7 +72,7 @@ whispering --language en --model tiny --host 0.0.0.0 --port 8000
whispering --host ADDRESS_OF_HOST --port 8000 --mode client whispering --host ADDRESS_OF_HOST --port 8000 --mode client
``` ```
You can set ``-n``, ``--allow-padding`` and other options. You can set ``-n`` and other options.
## License ## License

View file

@ -14,7 +14,6 @@ def test_options():
"--mode server --mic 1", "--mode server --mic 1",
"--mode server --beam_size 3", "--mode server --beam_size 3",
"--mode server --temperature 0", "--mode server --temperature 0",
"--mode server --allow-padding",
"--mode server --num_block 3", "--mode server --num_block 3",
"--mode mic --host 0.0.0.0", "--mode mic --host 0.0.0.0",
"--mode mic --port 8000", "--mode mic --port 8000",

View file

@ -144,16 +144,18 @@ def get_opts() -> argparse.Namespace:
action="append", action="append",
default=[], default=[],
) )
group_ctx.add_argument(
"--allow-padding",
action="store_true",
)
group_ctx.add_argument( group_ctx.add_argument(
"--vad", "--vad",
type=float, type=float,
help="Threshold of VAD", help="Threshold of VAD",
default=0.5, default=0.5,
) )
group_ctx.add_argument(
"--max_nospeech_skip",
type=int,
help="Maximum number of skip to analyze because of nospeech",
default=16,
)
group_misc = parser.add_argument_group("Other options") group_misc = parser.add_argument_group("Other options")
group_misc.add_argument( group_misc.add_argument(
@ -224,7 +226,7 @@ def get_context(*, opts) -> Context:
protocol_version=CURRENT_PROTOCOL_VERSION, protocol_version=CURRENT_PROTOCOL_VERSION,
beam_size=opts.beam_size, beam_size=opts.beam_size,
temperatures=opts.temperature, temperatures=opts.temperature,
allow_padding=opts.allow_padding, max_nospeech_skip=opts.max_nospeech_skip,
vad_threshold=opts.vad, vad_threshold=opts.vad,
) )
logger.debug(f"Context: {ctx}") logger.debug(f"Context: {ctx}")
@ -245,7 +247,6 @@ def is_valid_arg(opts) -> bool:
"mic", "mic",
"beam_size", "beam_size",
"temperature", "temperature",
"allow_padding",
] ]
elif opts.mode == Mode.mic.value: elif opts.mode == Mode.mic.value:
keys = [ keys = [

View file

@ -32,9 +32,9 @@ class Context(BaseModel, arbitrary_types_allowed=True):
timestamp: float = 0.0 timestamp: float = 0.0
buffer_tokens: List[torch.Tensor] = [] buffer_tokens: List[torch.Tensor] = []
buffer_mel: Optional[torch.Tensor] = None buffer_mel: Optional[torch.Tensor] = None
nosoeech_skip_count: Optional[int] = None
temperatures: List[float] temperatures: List[float]
allow_padding: bool = False
patience: Optional[float] = None patience: Optional[float] = None
compression_ratio_threshold: Optional[float] = 2.4 compression_ratio_threshold: Optional[float] = 2.4
logprob_threshold: Optional[float] = -1.0 logprob_threshold: Optional[float] = -1.0
@ -46,6 +46,7 @@ class Context(BaseModel, arbitrary_types_allowed=True):
compression_ratio_threshold: Optional[float] = 2.4 compression_ratio_threshold: Optional[float] = 2.4
buffer_threshold: Optional[float] = 0.5 buffer_threshold: Optional[float] = 0.5
vad_threshold: float vad_threshold: float
max_nospeech_skip: int
class ParsedChunk(BaseModel): class ParsedChunk(BaseModel):

View file

@ -233,6 +233,7 @@ class WhisperStreamingTranscriber:
ctx: Context, ctx: Context,
) -> Iterator[ParsedChunk]: ) -> Iterator[ParsedChunk]:
logger.debug(f"{len(audio)}") logger.debug(f"{len(audio)}")
force_padding: bool = False
if ctx.vad_threshold > 0.0: if ctx.vad_threshold > 0.0:
x = [ x = [
@ -246,7 +247,20 @@ class WhisperStreamingTranscriber:
if len(x) == 0: # No speech if len(x) == 0: # No speech
logger.debug("No speech") logger.debug("No speech")
ctx.timestamp += len(audio) / N_FRAMES * self.duration_pre_one_mel ctx.timestamp += len(audio) / N_FRAMES * self.duration_pre_one_mel
if ctx.nosoeech_skip_count is not None:
ctx.nosoeech_skip_count += 1
if (
ctx.nosoeech_skip_count is None
or ctx.nosoeech_skip_count <= ctx.max_nospeech_skip
):
logger.debug(
f"nosoeech_skip_count: {ctx.nosoeech_skip_count} (<= {ctx.max_nospeech_skip})"
)
return return
ctx.nosoeech_skip_count = None
force_padding = True
new_mel = log_mel_spectrogram(audio=audio) new_mel = log_mel_spectrogram(audio=audio)
logger.debug(f"Incoming new_mel.shape: {new_mel.shape}") logger.debug(f"Incoming new_mel.shape: {new_mel.shape}")
@ -261,12 +275,15 @@ class WhisperStreamingTranscriber:
seek: int = 0 seek: int = 0
while seek < mel.shape[-1]: while seek < mel.shape[-1]:
logger.debug(f"seek: {seek}") logger.debug(f"seek: {seek}")
if mel.shape[-1] - seek <= 0:
logger.debug(f"No more seek: mel.shape={mel.shape}, seek={seek}")
break
if mel.shape[-1] - seek < N_FRAMES: if mel.shape[-1] - seek < N_FRAMES:
logger.debug( logger.debug(
f"mel.shape ({mel.shape[-1]}) - seek ({seek}) < N_FRAMES ({N_FRAMES})" f"mel.shape ({mel.shape[-1]}) - seek ({seek}) < N_FRAMES ({N_FRAMES})"
) )
if ctx.allow_padding: if force_padding:
logger.warning("Padding is not expected while speaking") logger.debug("Padding")
else: else:
logger.debug("No padding") logger.debug("No padding")
break break
@ -319,9 +336,13 @@ class WhisperStreamingTranscriber:
logger.debug(f"new seek={seek}, mel.shape: {mel.shape}") logger.debug(f"new seek={seek}, mel.shape: {mel.shape}")
if mel.shape[-1] - seek <= 0: if mel.shape[-1] - seek <= 0:
ctx.buffer_mel = None
ctx.nosoeech_skip_count = None
logger.debug(f"ctx.buffer_mel is None ({mel.shape}, {seek})") logger.debug(f"ctx.buffer_mel is None ({mel.shape}, {seek})")
return return
ctx.buffer_mel = mel[:, seek:] ctx.buffer_mel = mel[:, seek:]
assert ctx.buffer_mel is not None assert ctx.buffer_mel is not None
logger.debug(f"ctx.buffer_mel.shape: {ctx.buffer_mel.shape}") logger.debug(f"ctx.buffer_mel.shape: {ctx.buffer_mel.shape}")
del mel del mel
if ctx.nosoeech_skip_count is None:
ctx.nosoeech_skip_count = 0 # start count

View file

@ -1,5 +1,6 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
from logging import getLogger
from typing import Iterator, Optional from typing import Iterator, Optional
import numpy as np import numpy as np
@ -8,6 +9,8 @@ from whisper.audio import N_FRAMES, SAMPLE_RATE
from whispering.schema import SpeechSegment from whispering.schema import SpeechSegment
logger = getLogger(__name__)
class VAD: class VAD:
def __init__( def __init__(
@ -50,6 +53,7 @@ class VAD:
torch.from_numpy(audio[start:end]), torch.from_numpy(audio[start:end]),
SAMPLE_RATE, SAMPLE_RATE,
).item() ).item()
logger.debug(f"VAD: {vad_prob} (threshold={threshold})")
if vad_prob > threshold: if vad_prob > threshold:
if start_block_idx is None: if start_block_idx is None:
start_block_idx = idx start_block_idx = idx