whispering/whispering/cli.py
#!/usr/bin/env python3

import argparse
import asyncio
import queue
import sys
from enum import Enum
from logging import DEBUG, INFO, basicConfig, getLogger
from pathlib import Path
from typing import Iterator, List, Optional, Union

import sounddevice as sd
import torch
from whisper import available_models
from whisper.audio import N_FRAMES, SAMPLE_RATE
from whisper.tokenizer import LANGUAGES, TO_LANGUAGE_CODE

from whispering.pbar import ProgressBar
from whispering.schema import (
    CURRENT_PROTOCOL_VERSION,
    Context,
    StdoutWriter,
    WhisperConfig,
)
from whispering.serve import serve_with_websocket
from whispering.transcriber import WhisperStreamingTranscriber
from whispering.websocket_client import run_websocket_client

logger = getLogger(__name__)


class Mode(Enum):
    client = "client"
    server = "server"
    mic = "mic"

    def __str__(self):
        return self.value


def transcribe_from_mic(
    *,
    wsp: WhisperStreamingTranscriber,
    sd_device: Optional[Union[int, str]],
    num_block: int,
    ctx: Context,
    no_progress: bool,
) -> Iterator[str]:
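    """Capture audio from the microphone and yield transcribed text chunks.

    Audio blocks arrive through a sounddevice callback, are buffered in a
    queue, and are passed to the streaming transcriber one block at a time.
    """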
    q = queue.Queue()

    def sd_callback(indata, frames, time, status):
        if status:
            logger.warning(status)
        q.put(indata.ravel())

    logger.info("Ready to transcribe")
    with sd.InputStream(
        samplerate=SAMPLE_RATE,
        blocksize=N_FRAMES * num_block,
        device=sd_device,
        dtype="float32",
        channels=1,
        callback=sd_callback,
    ):
        idx: int = 0
        while True:
            logger.debug(f"Audio #: {idx}, The rest of queue: {q.qsize()}")
            if no_progress:
                audio = q.get()
            else:
                pbar_thread = ProgressBar(
                    num_block=num_block,  # TODO: set more accurate value
                )
                try:
                    audio = q.get()
                except KeyboardInterrupt:
                    pbar_thread.kill()
                    return
                pbar_thread.kill()
            logger.debug(f"Got. The rest of queue: {q.qsize()}")
            if not no_progress:
                sys.stderr.write("Analyzing")
                sys.stderr.flush()
            for chunk in wsp.transcribe(audio=audio, ctx=ctx):
                if not no_progress:
                    sys.stderr.write("\r")
                    sys.stderr.flush()
                yield f"{chunk.start:.2f}->{chunk.end:.2f}\t{chunk.text}\n"
                if not no_progress:
                    sys.stderr.write("Analyzing")
                    sys.stderr.flush()
            idx += 1
            if not no_progress:
                sys.stderr.write("\r")
                sys.stderr.flush()


def get_opts() -> argparse.Namespace:
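    """Parse command-line options and normalize their values."""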
    parser = argparse.ArgumentParser()

    group_model = parser.add_argument_group("Whisper model options")
    group_model.add_argument(
        "--model",
        type=str,
        choices=available_models(),
    )
    group_model.add_argument(
        "--language",
        type=str,
        choices=sorted(LANGUAGES.keys())
        + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]),
    )
    group_model.add_argument(
        "--device",
        default="cuda" if torch.cuda.is_available() else "cpu",
        help="device to use for PyTorch inference",
    )

    group_ws = parser.add_argument_group("WebSocket options")
    group_ws.add_argument(
        "--host",
        help="Host of websocket server",
    )
    group_ws.add_argument(
        "--port",
        type=int,
        help="Port number of websocket server",
    )

    group_ctx = parser.add_argument_group("Parsing options")
    group_ctx.add_argument(
        "--beam_size",
        "-b",
        type=int,
        default=5,
    )
    group_ctx.add_argument(
        "--temperature",
        "-t",
        type=float,
        action="append",
        default=[],
    )
    group_ctx.add_argument(
        "--vad",
        type=float,
        help="Threshold of VAD",
        default=0.5,
    )
    group_ctx.add_argument(
        "--max_nospeech_skip",
        type=int,
        help="Maximum number of blocks to skip analyzing because of no speech",
        default=16,
    )

    group_misc = parser.add_argument_group("Other options")
    group_misc.add_argument(
        "--output",
        "-o",
        help="Output file",
        type=Path,
        default=StdoutWriter(),
    )
    group_misc.add_argument(
        "--mic",
        help="Set MIC device",
    )
    group_misc.add_argument(
        "--num_block",
        "-n",
        type=int,
        default=20,
        help="Number of audio blocks to process per operation",
    )
    group_misc.add_argument(
        "--mode",
        choices=[v.value for v in Mode],
    )
    group_misc.add_argument(
        "--no-progress",
        action="store_true",
    )
    group_misc.add_argument(
        "--show-devices",
        action="store_true",
        help="Show MIC devices",
    )
    group_misc.add_argument(
        "--debug",
        action="store_true",
    )

    opts = parser.parse_args()

    if opts.beam_size <= 0:
        opts.beam_size = None
    if len(opts.temperature) == 0:
        opts.temperature = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]
    opts.temperature = sorted(set(opts.temperature))

    # If --mic is numeric, treat it as a device index rather than a device name
    try:
        opts.mic = int(opts.mic)
    except Exception:
        pass

    return opts


def get_wshiper(*, opts) -> WhisperStreamingTranscriber:
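    """Instantiate the streaming transcriber from the parsed CLI options."""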
    config = WhisperConfig(
        model_name=opts.model,
        language=opts.language,
        device=opts.device,
    )
    logger.debug(f"WhisperConfig: {config}")
    wsp = WhisperStreamingTranscriber(config=config)
    return wsp


def get_context(*, opts) -> Context:
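    """Build the decoding context used in client and microphone modes."""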
    ctx = Context(
        protocol_version=CURRENT_PROTOCOL_VERSION,
        beam_size=opts.beam_size,
        temperatures=opts.temperature,
        max_nospeech_skip=opts.max_nospeech_skip,
        vad_threshold=opts.vad,
    )
    logger.debug(f"Context: {ctx}")
    return ctx


def show_devices():
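    """Print the index and name of every audio device that has input channels."""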
    devices = sd.query_devices()
    for i, device in enumerate(devices):
        if device["max_input_channels"] > 0:
            print(f"{i}: {device['name']}")


def is_valid_arg(
    *,
    args: List[str],
    mode: str,
) -> bool:
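    """Return False if any given CLI argument is not allowed in this mode."""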
    keys = set()
    if mode == Mode.server.value:
        keys = {
            "--mic",
            "--beam_size",
            "-b",
            "--temperature",
            "-t",
            "--num_block",
            "-n",
            "--vad",
            "--max_nospeech_skip",
            "--output",
            "--show-devices",
            "--no-progress",
        }
    elif mode == Mode.mic.value:
        keys = {
            "--host",
            "--port",
        }

    for arg in args:
        if arg in keys:
            sys.stderr.write(f"{arg} is not an accepted option for {mode} mode\n")
            return False
    return True


def main() -> None:
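    """Entry point: dispatch to client, server, or local microphone mode."""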
    opts = get_opts()
    basicConfig(
        level=DEBUG if opts.debug else INFO,
        format="[%(asctime)s] %(module)s.%(funcName)s:%(lineno)d %(levelname)s -> %(message)s",
    )

    if opts.show_devices:
        return show_devices()

    if (
        opts.host is not None
        and opts.port is not None
        and opts.mode != Mode.client.value
    ):
        opts.mode = Mode.server.value

    if not is_valid_arg(
        args=sys.argv[1:],
        mode=opts.mode,
    ):
        sys.exit(1)

    if opts.mode == Mode.client.value:
        assert opts.language is None
        assert opts.model is None
        ctx: Context = get_context(opts=opts)
        try:
            asyncio.run(
                run_websocket_client(
                    sd_device=opts.mic,
                    num_block=opts.num_block,
                    host=opts.host,
                    port=opts.port,
                    no_progress=opts.no_progress,
                    ctx=ctx,
                    path_out=opts.output,
                )
            )
        except KeyboardInterrupt:
            pass
    elif opts.mode == Mode.server.value:
        assert opts.language is not None
        assert opts.model is not None
        wsp = get_wshiper(opts=opts)
        asyncio.run(
            serve_with_websocket(
                wsp=wsp,
                host=opts.host,
                port=opts.port,
            )
        )
    else:
        assert opts.language is not None
        assert opts.model is not None
        wsp = get_wshiper(opts=opts)
        ctx: Context = get_context(opts=opts)
        with opts.output.open("w") as outf:
            for text in transcribe_from_mic(
                wsp=wsp,
                sd_device=opts.mic,
                num_block=opts.num_block,
                no_progress=opts.no_progress,
                ctx=ctx,
            ):
                outf.write(text)
                outf.flush()


if __name__ == "__main__":
    main()