Mirror of https://github.com/shirayu/whispering.git (synced 2024-11-10 18:51:08 +00:00)

Merge remote-tracking branch 'origin/master' into vad

Commit 45eb0bc34d: 8 changed files with 91 additions and 73 deletions

README.md

@@ -32,7 +32,7 @@ whispering --language en --model tiny
 - ``--model`` set the [model name](https://github.com/openai/whisper#available-models-and-languages) to use. Larger models will be more accurate, but may not be able to transcribe in real time.
 - ``--language`` sets the language to transcribe. The list of languages are shown with ``whispering -h``
 - ``--no-progress`` disables the progress message
-- ``-t`` sets temperatures to decode. You can set several like (``-t 0.0 -t 0.1 -t 0.5``), but too many temperatures exhaust decoding time
+- ``-t`` sets temperatures to decode. You can set several like ``-t 0.0 -t 0.1 -t 0.5``, but too many temperatures exhaust decoding time
 - ``--debug`` outputs logs for debug

 ### Parse interval

poetry.lock (generated, 6 changed lines)

@@ -519,13 +519,13 @@ dev = ["pytest"]
 [package.source]
 type = "git"
 url = "https://github.com/openai/whisper.git"
-reference = '62fe7f1009a534986ac1d32a4aef8c244d029c28'
-resolved_reference = "62fe7f1009a534986ac1d32a4aef8c244d029c28"
+reference = '0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f'
+resolved_reference = "0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f"

 [metadata]
 lock-version = "1.1"
 python-versions = ">=3.8,<3.11"
-content-hash = "75e53434d1d46d54a886ca7a896a2f0ba0072a1848f90d5b6dc46ea2c5b47191"
+content-hash = "ab527970383bc2245dee005627d0695812601115a36e15a5ef9e66d1185791bf"

 [metadata.files]
 black = [

pyproject.toml

@@ -8,7 +8,7 @@ packages = [{include = "whispering"}]

 [tool.poetry.dependencies]
 python = ">=3.8,<3.11"
-whisper = {git = "https://github.com/openai/whisper.git", rev = '62fe7f1009a534986ac1d32a4aef8c244d029c28'}
+whisper = {git = "https://github.com/openai/whisper.git", rev = '0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f'}
 sounddevice = "^0.4.5"
 pydantic = "^1.10.2"
 websockets = "^10.3"
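
These changes travel together: the whisper pin in pyproject.toml and the reference/resolved_reference pair in poetry.lock now point at the same upstream commit, and Poetry recomputes the lock file's content-hash from the dependency section of pyproject.toml, so bumping the rev without regenerating the lock would leave it stale.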

whispering/cli.py

@@ -214,11 +214,14 @@ def main() -> None:
     if opts.mode == "client":
         assert opts.language is None
         assert opts.model is None
-        asyncio.run(
-            run_websocket_client(
-                opts=opts,
-            )
-        )
+        try:
+            asyncio.run(
+                run_websocket_client(
+                    opts=opts,
+                )
+            )
+        except KeyboardInterrupt:
+            pass
     else:
         assert opts.language is not None
         assert opts.model is not None
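
The client now exits quietly on Ctrl-C instead of printing a traceback, since KeyboardInterrupt propagates out of asyncio.run(). A minimal self-contained sketch of the same pattern, with run_forever as a hypothetical stand-in for run_websocket_client:

import asyncio

async def run_forever() -> None:
    # stand-in for the websocket client loop
    while True:
        await asyncio.sleep(1)

def main() -> None:
    try:
        asyncio.run(run_forever())
    except KeyboardInterrupt:
        # Ctrl-C raises KeyboardInterrupt out of asyncio.run();
        # swallowing it here gives a clean shutdown
        pass

if __name__ == "__main__":
    main()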

whispering/schema.py

@@ -2,7 +2,6 @@

 from typing import List, Optional

-import numpy as np
 import torch
 from pydantic import BaseModel, root_validator

@@ -56,4 +55,4 @@ class ParsedChunk(BaseModel):
 class SpeechSegment(BaseModel, arbitrary_types_allowed=True):
     start_block_idx: int
     end_block_idx: int
-    segment: np.ndarray
+    segment: torch.Tensor
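
SpeechSegment now carries a torch.Tensor, which pydantic has no built-in validator for; the class already opts in with arbitrary_types_allowed=True, which makes pydantic v1 fall back to a plain isinstance() check for such fields. A small sketch of the idiom:

import torch
from pydantic import BaseModel

class Segment(BaseModel, arbitrary_types_allowed=True):
    # arbitrary_types_allowed lets pydantic v1 accept field types it cannot
    # validate (like torch.Tensor) via an isinstance() check instead
    start_block_idx: int
    end_block_idx: int
    segment: torch.Tensor

seg = Segment(start_block_idx=0, end_block_idx=1, segment=torch.zeros(16000))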

whispering/serve.py

@@ -5,6 +5,7 @@ from logging import getLogger

 import numpy as np
 import websockets
+from websockets.exceptions import ConnectionClosedOK

 from whispering.transcriber import Context, WhisperStreamingTranscriber

@@ -15,10 +16,16 @@ async def serve_with_websocket_main(websocket):
     global g_wsp
     global g_ctx
     idx: int = 0
+    ctx: Context = g_ctx.copy(
+        deep=True,
+    )

     while True:
         logger.debug(f"Segment #: {idx}")
-        message = await websocket.recv()
+        try:
+            message = await websocket.recv()
+        except ConnectionClosedOK:
+            break

         if isinstance(message, str):
             logger.debug(f"Got str: {message}")
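
Each connection now works on ctx, a deep copy of the global context, instead of mutating g_ctx directly. This matters because pydantic v1's .copy() is shallow by default, so mutable fields such as token buffers would otherwise be shared across clients. A sketch, with Ctx as a hypothetical stand-in for whispering's Context:

from typing import List
from pydantic import BaseModel

class Ctx(BaseModel):
    # hypothetical stand-in for whispering's Context model
    buffer_tokens: List[int] = []

g_ctx = Ctx()
ctx = g_ctx.copy(deep=True)  # deep=True also copies nested containers
ctx.buffer_tokens.append(42)
assert g_ctx.buffer_tokens == []  # the shared template stays untouched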

@@ -26,7 +33,10 @@ async def serve_with_websocket_main(websocket):

         logger.debug(f"Message size: {len(message)}")
         segment = np.frombuffer(message, dtype=np.float32)
-        for chunk in g_wsp.transcribe(segment=segment, ctx=g_ctx):
+        for chunk in g_wsp.transcribe(
+            segment=segment,  # type: ignore
+            ctx=ctx,
+        ):
             await websocket.send(chunk.json())
             idx += 1

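
The receive loop above treats ConnectionClosedOK as the normal end of a session: in the websockets library, a clean close by the peer raises that exception from recv(). A minimal handler sketch with the same shape (echoing where the real server transcribes):

from websockets.exceptions import ConnectionClosedOK

async def handler(websocket) -> None:
    while True:
        try:
            message = await websocket.recv()
        except ConnectionClosedOK:
            break  # peer closed cleanly; stop serving this connection
        await websocket.send(message)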

whispering/transcriber.py

@@ -1,9 +1,8 @@
 #!/usr/bin/env python3

 from logging import getLogger
-from typing import Final, Iterator, List, Optional, Union
+from typing import Final, Iterator, Optional, Union

-import numpy as np
 import torch
 from whisper import Whisper, load_model
 from whisper.audio import (

@@ -77,59 +76,51 @@ class WhisperStreamingTranscriber:
             suppress_blank=True,
             suppress_tokens="-1",
             without_timestamps=False,
-            max_initial_timestamp=0.0,
+            max_initial_timestamp=1.0,
             fp16=self.fp16,
         )

     def _decode_with_fallback(
         self,
         *,
-        segment: np.ndarray,
+        segment: torch.Tensor,
         ctx: Context,
-    ) -> List[DecodingResult]:
+    ) -> DecodingResult:
         assert len(ctx.temperatures) >= 1
-        t = ctx.temperatures[0]
-        logger.debug(f"temperature: {t}")
+        decode_result: Optional[DecodingResult] = None

-        _decode_options1: DecodingOptions = self._get_decoding_options(
-            t=t,
-            prompt=ctx.buffer_tokens,
-            beam_size=ctx.beam_size,
-            patience=None,
-            best_of=None,
-        )
-        results: List[DecodingResult] = self.model.decode(segment, _decode_options1)  # type: ignore
-
-        for t in ctx.temperatures[1:]:
-            needs_fallback = [
-                ctx.compression_ratio_threshold is not None
-                and result.compression_ratio > ctx.compression_ratio_threshold
-                or ctx.logprob_threshold is not None
-                and result.avg_logprob < ctx.logprob_threshold
-                for result in results
-            ]
-            if any(needs_fallback):
-                logger.debug(
-                    f"Fall back with temperature: {t}, needs_fallback: {needs_fallback}"
-                )
-                _decode_options2: DecodingOptions = self._get_decoding_options(
-                    t=t,
-                    prompt=ctx.buffer_tokens,
-                    beam_size=None,
-                    patience=ctx.patience,
-                    best_of=ctx.best_of,
-                )
-                retries: List[DecodingResult] = self.model.decode(
-                    segment[needs_fallback], _decode_options2  # type: ignore
-                )
-                for retry_index, original_index in enumerate(
-                    np.nonzero(needs_fallback)[0]
-                ):
-                    results[original_index] = retries[retry_index]
-            else:
-                break
-        logger.debug(f"# of results: {len(results)}")
-        return results
+        for t in ctx.temperatures:
+            _decode_options: DecodingOptions = self._get_decoding_options(
+                t=t,
+                prompt=ctx.buffer_tokens,
+                beam_size=ctx.beam_size if t <= 0 else None,
+                patience=ctx.patience if t <= 0 else None,
+                best_of=ctx.best_of if t < 0 else None,
+            )
+            logger.debug(f"DecodeOptions: {_decode_options}")
+            decode_result = self.model.decode(
+                segment,
+                _decode_options,
+            )  # type: ignore
+            assert decode_result is not None
+
+            needs_fallback: bool = False
+            if (
+                ctx.compression_ratio_threshold is not None
+                and decode_result.compression_ratio > ctx.compression_ratio_threshold
+            ):
+                needs_fallback = True  # too repetitive
+            if (
+                ctx.logprob_threshold is not None
+                and decode_result.avg_logprob < ctx.logprob_threshold
+            ):
+                needs_fallback = True  # average log probability is too low
+
+            if not needs_fallback:
+                break
+
+        assert isinstance(decode_result, DecodingResult)
+        return decode_result

     def _get_chunk(
         self,
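
The rewritten _decode_with_fallback tries each temperature in order and returns the first result that is neither too repetitive (compression ratio above threshold) nor too unlikely (average log-probability below threshold), instead of batch-decoding and patching failures. A reduced sketch of that control flow; Result and decode_at are hypothetical stand-ins, and the 2.4 and -1.0 defaults follow whisper's customary thresholds:

from dataclasses import dataclass
from typing import Callable, Optional, Sequence

@dataclass
class Result:
    # stand-in for whisper's DecodingResult
    compression_ratio: float
    avg_logprob: float

def decode_with_fallback(
    decode_at: Callable[[float], Result],
    temperatures: Sequence[float],
    compression_ratio_threshold: Optional[float] = 2.4,
    logprob_threshold: Optional[float] = -1.0,
) -> Result:
    assert len(temperatures) >= 1
    for t in temperatures:
        result = decode_at(t)
        too_repetitive = (
            compression_ratio_threshold is not None
            and result.compression_ratio > compression_ratio_threshold
        )
        too_unlikely = (
            logprob_threshold is not None
            and result.avg_logprob < logprob_threshold
        )
        if not (too_repetitive or too_unlikely):
            break  # accept the first result that passes both checks
    return result

# toy usage: temperature 0.0 looks repetitive, 0.2 passes
results = {0.0: Result(compression_ratio=3.1, avg_logprob=-0.4),
           0.2: Result(compression_ratio=1.8, avg_logprob=-0.6)}
best = decode_with_fallback(lambda t: results[t], [0.0, 0.2])
assert best.compression_ratio == 1.8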

@@ -205,7 +196,10 @@ class WhisperStreamingTranscriber:
             duration = segment_duration
             timestamps = tokens[timestamp_tokens.nonzero().flatten()]
             logger.debug(f"Length of consecutive: 0, timestamps: {timestamps}")
-            if len(timestamps) > 0:
+            if (
+                len(timestamps) > 0
+                and timestamps[-1].item() != self.tokenizer.timestamp_begin
+            ):
                 # no consecutive timestamps but it has a timestamp; use the last one.
                 # single timestamp at the end means no speech after the last timestamp.
                 last_timestamp_position = (

@@ -232,13 +226,13 @@ class WhisperStreamingTranscriber:
     def transcribe(
         self,
         *,
-        segment: np.ndarray,
+        segment: torch.Tensor,
         ctx: Context,
     ) -> Iterator[ParsedChunk]:
-        vad_probs = self.vad(segment)
-        logger.debug(f"{vad_probs}")
+        for speech_segment in self.vad(segment=segment):
+            logger.debug(f"{speech_segment}")

-        new_mel = log_mel_spectrogram(audio=segment).unsqueeze(0)
+        new_mel = log_mel_spectrogram(audio=segment)
         logger.debug(f"Incoming new_mel.shape: {new_mel.shape}")
         if ctx.buffer_mel is None:
             mel = new_mel

@@ -251,7 +245,7 @@ class WhisperStreamingTranscriber:
         seek: int = 0
         while seek < mel.shape[-1]:
             segment = (
-                pad_or_trim(mel[:, :, seek:], N_FRAMES)
+                pad_or_trim(mel[:, seek:], N_FRAMES)
                 .to(self.model.device)  # type: ignore
                 .to(self.dtype)
             )
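
Several of the changes above (dropping unsqueeze(0) and turning mel[:, :, seek:] into mel[:, seek:]) follow from keeping the log-mel buffer 2-D: for unbatched audio, log_mel_spectrogram returns (n_mels, n_frames), so seeking happens along the last axis. A quick shape check under that assumption:

import torch

mel = torch.zeros(80, 5000)  # (n_mels, n_frames), no batch dimension
window = mel[:, 3000:]       # seek along the frame axis
assert window.shape == (80, 2000)

batched = mel.unsqueeze(0)   # the old code added a batch dimension ...
assert batched[:, :, 3000:].shape == (1, 80, 2000)  # ... and sliced dim 2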

@@ -262,11 +256,10 @@ class WhisperStreamingTranscriber:
                 f"seek={seek}, timestamp={ctx.timestamp}, "
                 f"mel.shape: {mel.shape}, segment.shape: {segment.shape}"
             )
-            results = self._decode_with_fallback(
+            result = self._decode_with_fallback(
                 segment=segment,
                 ctx=ctx,
             )
-            result = results[0]
             logger.debug(
                 f"Result: temperature={result.temperature:.2f}, no_speech_prob={result.no_speech_prob:.2f}, "
                 f"avg_logprob={result.avg_logprob:.2f}"

@@ -306,7 +299,7 @@ class WhisperStreamingTranscriber:
         if mel.shape[-1] - seek <= 0:
             logger.debug(f"ctx.buffer_mel is None ({mel.shape}, {seek})")
             return
-        ctx.buffer_mel = mel[:, :, seek:]
+        ctx.buffer_mel = mel[:, seek:]
         assert ctx.buffer_mel is not None
         logger.debug(f"ctx.buffer_mel.shape: {ctx.buffer_mel.shape}")
         del mel

whispering/vad.py

@@ -2,7 +2,6 @@

 from typing import Iterator

-import numpy as np
 import torch
 from whisper.audio import N_FRAMES, SAMPLE_RATE

@@ -21,16 +20,26 @@ class VAD:
     def __call__(
         self,
         *,
-        segment: np.ndarray,
+        segment: torch.Tensor,
         thredhold: float = 0.5,
-    ) -> Iterator[SpeechBlock]:
+    ) -> Iterator[SpeechSegment]:
         # segment.shape should be multiple of (N_FRAMES,)

+        def my_ret(
+            *,
+            start_block_idx: int,
+            idx: int,
+        ) -> SpeechSegment:
+            return SpeechSegment(
+                start_block_idx=start_block_idx,
+                end_block_idx=idx,
+                segment=segment[N_FRAMES * start_block_idx : N_FRAMES * idx],
+            )
+
         block_size: int = int(segment.shape[0] / N_FRAMES)

         start_block_idx = None
-        for idx in range(block_size + 1):
-            if idx < block_size:
+        for idx in range(block_size):
             start: int = N_FRAMES * idx
             end: int = N_FRAMES * (idx + 1)
             vad_prob = self.vad_model(

@@ -42,9 +51,13 @@ class VAD:
                 start_block_idx = idx
             else:
                 if start_block_idx is not None:
-                    yield SpeechSegment(
+                    yield my_ret(
                         start_block_idx=start_block_idx,
-                        end_block_idx=idx,
-                        segment=segment[N_FRAMES * start_block_idx : N_FRAMES * idx],
+                        idx=idx,
                     )
                     start_block_idx = None
+        if start_block_idx is not None:
+            yield my_ret(
+                start_block_idx=start_block_idx,
+                idx=block_size,
+            )
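
This hunk replaces the sentinel iteration of the old loop (range(block_size + 1) guarded by if idx < block_size) with an explicit flush after the loop, and routes segment construction through the my_ret helper so the trailing case builds segments the same way. A minimal sketch of the grouping logic in its new shape, over hypothetical per-block speech probabilities:

from typing import Iterator, List, Tuple

def speech_blocks(probs: List[float], threshold: float = 0.5) -> Iterator[Tuple[int, int]]:
    # group consecutive above-threshold blocks into (start, end) index pairs
    start = None
    for idx, p in enumerate(probs):
        if p >= threshold:
            if start is None:
                start = idx
        elif start is not None:
            yield (start, idx)
            start = None
    if start is not None:
        # flush a run that reaches the end of the buffer
        yield (start, len(probs))

assert list(speech_blocks([0.1, 0.9, 0.8, 0.2, 0.7])) == [(1, 3), (4, 5)]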