#!/usr/bin/env python3
# whispering/whisper_streaming/cli.py
import argparse
import queue
from logging import DEBUG, INFO, basicConfig, getLogger
from typing import Optional, Union
import sounddevice as sd
import torch
from whisper import available_models
from whisper.audio import N_FRAMES, SAMPLE_RATE
from whisper.tokenizer import LANGUAGES, TO_LANGUAGE_CODE
from whisper_streaming.schema import WhisperConfig
from whisper_streaming.transcriber import WhisperStreamingTranscriber
logger = getLogger(__name__)
def transcribe_from_mic(
    *,
    config: WhisperConfig,
    sd_device: Optional[Union[int, str]],
    num_block: int,
) -> None:
    """Capture audio from a microphone and print streaming transcriptions.

    Blocks forever (until interrupted). A sounddevice callback enqueues raw
    float32 audio blocks; the main loop drains the queue one segment at a
    time and feeds each segment to the streaming transcriber.

    Args:
        config: Whisper model / decoding configuration.
        sd_device: sounddevice input device index or name
            (None selects the system default input).
        num_block: number of N_FRAMES-sized units per captured block; larger
            blocks consume more memory per segment (see --num_block help).
    """
    logger.debug(f"WhisperConfig: {config}")
    wsp = WhisperStreamingTranscriber(config=config)
    q: queue.Queue = queue.Queue()
    def sd_callback(indata, frames, time, status):
        # Runs on sounddevice's audio thread — keep it cheap: just enqueue.
        if status:
            logger.warning(status)
        # ravel() flattens the (frames, 1) mono buffer to a 1-D array.
        q.put(indata.ravel())
    logger.info("Ready to transcribe")
    with sd.InputStream(
        samplerate=SAMPLE_RATE,
        blocksize=N_FRAMES * num_block,
        device=sd_device,
        dtype="float32",
        channels=1,
        callback=sd_callback,
    ):
        idx: int = 0
        while True:
            logger.debug(f"Segment #: {idx}, The rest of queue: {q.qsize()}")
            # Blocks until the callback delivers the next audio block.
            segment = q.get()
            for chunk in wsp.transcribe(segment=segment):
                print(f"{chunk.start:.2f}->{chunk.end:.2f}\t{chunk.text}")
            idx += 1
def get_opts() -> argparse.Namespace:
    """Parse command-line options for the streaming-transcription CLI.

    Returns:
        argparse.Namespace with attributes: language, model, device,
        beam_size, num_block, temperature (list of floats), mic, debug.
    """
    parser = argparse.ArgumentParser()
    # NOTE: `default=None` was redundant alongside `required=True` and has
    # been dropped; behavior is unchanged.
    parser.add_argument(
        "--language",
        type=str,
        choices=sorted(LANGUAGES.keys())
        + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]),
        required=True,
        help="Language spoken in the audio (language code or title-cased name).",
    )
    parser.add_argument(
        "--model",
        type=str,
        choices=available_models(),
        required=True,
        help="Name of the Whisper model to load.",
    )
    parser.add_argument(
        "--device",
        default="cuda" if torch.cuda.is_available() else "cpu",
        help="device to use for PyTorch inference",
    )
    parser.add_argument(
        "--beam_size",
        "-b",
        type=int,
        default=5,
        help="Beam width for decoding; a value <= 0 disables beam search.",
    )
    parser.add_argument(
        "--num_block",
        "-n",
        type=int,
        default=20,
        help="Number of operation unit. Larger values can improve accuracy but consume more memory.",
    )
    parser.add_argument(
        "--temperature",
        "-t",
        type=float,
        action="append",
        default=[],
        help="Sampling temperature; repeat the flag to give a fallback schedule.",
    )
    parser.add_argument(
        "--mic",
        help="sounddevice input device index or name (default: system default input).",
    )
    parser.add_argument(
        "--debug",
        action="store_true",
        help="Enable DEBUG-level logging.",
    )
    return parser.parse_args()
def main() -> None:
    """CLI entry point: configure logging, build the config, start transcribing."""
    opts = get_opts()
    basicConfig(
        level=DEBUG if opts.debug else INFO,
        format="[%(asctime)s] %(module)s.%(funcName)s:%(lineno)d %(levelname)s -> %(message)s",
    )
    # A non-positive beam size means "disable beam search".
    if opts.beam_size <= 0:
        opts.beam_size = None
    # Fallback temperature schedule when none was given on the command line
    # (presumably mirrors whisper's default schedule — confirm upstream).
    if len(opts.temperature) == 0:
        opts.temperature = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]
    # --mic may be a numeric device index or a device-name string.
    # int(None) raises TypeError and int("name") raises ValueError; in either
    # case pass the value through unchanged. (Narrowed from a bare
    # `except Exception`, which also hid unrelated bugs.)
    try:
        opts.mic = int(opts.mic)
    except (TypeError, ValueError):
        pass
    config = WhisperConfig(
        model_name=opts.model,
        language=opts.language,
        device=opts.device,
        beam_size=opts.beam_size,
        temperatures=opts.temperature,
    )
    transcribe_from_mic(
        config=config,
        sd_device=opts.mic,
        num_block=opts.num_block,
    )
# Allow running this module directly as a script.
if __name__ == "__main__":
    main()