whispering/whispering/cli.py

349 lines
8.7 KiB
Python
Raw Normal View History

2022-09-23 10:20:11 +00:00
#!/usr/bin/env python3
import argparse
2022-09-24 11:45:20 +00:00
import asyncio
2022-09-23 10:20:11 +00:00
import queue
import sys
2022-10-02 12:31:05 +00:00
from enum import Enum
2022-09-23 13:13:25 +00:00
from logging import DEBUG, INFO, basicConfig, getLogger
2022-10-03 13:38:35 +00:00
from pathlib import Path
2022-10-17 12:29:12 +00:00
from typing import Iterator, List, Optional, Union
2022-09-23 10:20:11 +00:00
import sounddevice as sd
import torch
from whisper import available_models
from whisper.audio import N_FRAMES, SAMPLE_RATE
from whisper.tokenizer import LANGUAGES, TO_LANGUAGE_CODE
from whispering.pbar import ProgressBar
2022-10-15 04:23:00 +00:00
from whispering.schema import (
CURRENT_PROTOCOL_VERSION,
Context,
StdoutWriter,
WhisperConfig,
)
from whispering.serve import serve_with_websocket
from whispering.transcriber import WhisperStreamingTranscriber
from whispering.websocket_client import run_websocket_client
2022-09-23 10:20:11 +00:00
# Module-level logger, named after this module via __name__.
logger = getLogger(__name__)
2022-10-02 12:31:05 +00:00
class Mode(Enum):
    """Run modes offered as choices of the ``--mode`` option."""

    # Declaration order matters: iterating the enum builds the CLI choices.
    client = "client"
    server = "server"
    mic = "mic"

    def __str__(self) -> str:
        # Render as the bare value (e.g. "mic") instead of "Mode.mic".
        return self.value
2022-09-23 10:28:11 +00:00
def transcribe_from_mic(
    *,
    wsp: WhisperStreamingTranscriber,
    sd_device: Optional[Union[int, str]],
    num_block: int,
    ctx: Context,
    no_progress: bool,
) -> Iterator[str]:
    """Record from an input device and yield one text line per transcribed chunk.

    A ``sounddevice`` input stream (mono, float32, ``SAMPLE_RATE``) delivers
    blocks of ``N_FRAMES * num_block`` frames to a callback, which flattens
    them and puts them on a queue. Each queued block is handed to
    ``wsp.transcribe`` and every chunk it produces is yielded formatted as
    start time, ``->``, end time, a tab, the text and a newline.

    Args:
        wsp: Transcriber whose ``transcribe(audio=..., ctx=...)`` is called
            for every audio block.
        sd_device: Device passed to ``sd.InputStream``.
        num_block: Multiplier for the stream block size; also passed to the
            progress bar.
        ctx: Context forwarded to ``wsp.transcribe``.
        no_progress: When true, no progress bar is created and no status
            text is written to stderr.

    Yields:
        Formatted transcription lines.

    Note:
        The loop never ends on its own. With the progress bar enabled, a
        ``KeyboardInterrupt`` raised while waiting for audio stops the bar
        and ends the generator; with ``no_progress`` it propagates.
    """
    show_progress = not no_progress
    audio_queue = queue.Queue()

    def write_status(text):
        # Status text goes to stderr only, and only when progress is enabled.
        if show_progress:
            sys.stderr.write(text)
            sys.stderr.flush()

    def sd_callback(indata, frames, time, status):
        # Invoked by the input stream for every captured block.
        if status:
            logger.warning(status)
        audio_queue.put(indata.ravel())

    logger.info("Ready to transcribe")
    stream = sd.InputStream(
        samplerate=SAMPLE_RATE,
        blocksize=N_FRAMES * num_block,
        device=sd_device,
        dtype="float32",
        channels=1,
        callback=sd_callback,
    )
    with stream:
        idx: int = 0
        while True:
            logger.debug(f"Audio #: {idx}, The rest of queue: {audio_queue.qsize()}")

            if show_progress:
                pbar_thread = ProgressBar(
                    num_block=num_block,  # TODO: set more accurate value
                )
                try:
                    audio = audio_queue.get()
                except KeyboardInterrupt:
                    pbar_thread.kill()
                    return
                pbar_thread.kill()
            else:
                audio = audio_queue.get()

            logger.debug(f"Got. The rest of queue: {audio_queue.qsize()}")
            write_status("Analyzing")

            for chunk in wsp.transcribe(audio=audio, ctx=ctx):
                # Return the cursor to column 0 before the caller prints the line.
                write_status("\r")
                yield f"{chunk.start:.2f}->{chunk.end:.2f}\t{chunk.text}\n"
                write_status("Analyzing")

            idx += 1
            write_status("\r")
2022-09-23 10:20:11 +00:00
def get_opts(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Build the command-line parser, parse the arguments and normalize them.

    Args:
        argv: Argument list to parse. ``None`` (the default) lets argparse
            read ``sys.argv[1:]``, which is what a plain ``get_opts()`` call
            has always done.

    Returns:
        The parsed namespace with these adjustments applied:

        * ``beam_size`` is replaced by ``None`` when it is not positive.
        * ``temperature`` falls back to ``[0.0, 0.2, 0.4, 0.6, 0.8, 1.0]``
          when the option was not given, and is always sorted and
          de-duplicated.
        * ``mic`` is converted to ``int`` when it is numeric; otherwise it
          is left unchanged (``None`` or a device name).
    """
    parser = argparse.ArgumentParser()

    group_model = parser.add_argument_group("Whisper model options")
    group_model.add_argument(
        "--model",
        type=str,
        choices=available_models(),
    )
    group_model.add_argument(
        "--language",
        type=str,
        # Accept both language codes and capitalized language names.
        choices=sorted(LANGUAGES) + sorted(k.title() for k in TO_LANGUAGE_CODE),
    )
    group_model.add_argument(
        "--device",
        default="cuda" if torch.cuda.is_available() else "cpu",
        help="device to use for PyTorch inference",
    )

    group_ws = parser.add_argument_group("WebSocket options")
    group_ws.add_argument(
        "--host",
        help="host of websocket server",
    )
    group_ws.add_argument(
        "--port",
        type=int,
        help="Port number of websocket server",
    )

    group_ctx = parser.add_argument_group("Parsing options")
    group_ctx.add_argument(
        "--beam_size",
        "-b",
        type=int,
        default=5,
    )
    group_ctx.add_argument(
        "--temperature",
        "-t",
        type=float,
        action="append",
        default=[],
    )
    group_ctx.add_argument(
        "--vad",
        type=float,
        help="Threshold of VAD",
        default=0.5,
    )
    group_ctx.add_argument(
        "--max_nospeech_skip",
        type=int,
        help="Maximum number of skip to analyze because of nospeech",
        default=16,
    )

    group_misc = parser.add_argument_group("Other options")
    group_misc.add_argument(
        "--output",
        "-o",
        help="Output file",
        type=Path,
        default=StdoutWriter(),
    )
    group_misc.add_argument(
        "--mic",
        help="Set MIC device",
    )
    group_misc.add_argument(
        "--num_block",
        "-n",
        type=int,
        default=20,
        help="Number of operation unit",
    )
    group_misc.add_argument(
        "--mode",
        choices=[v.value for v in Mode],
    )
    group_misc.add_argument(
        "--no-progress",
        action="store_true",
    )
    group_misc.add_argument(
        "--show-devices",
        action="store_true",
        help="Show MIC devices",
    )
    group_misc.add_argument(
        "--debug",
        action="store_true",
    )

    opts = parser.parse_args(argv)

    if opts.beam_size <= 0:
        opts.beam_size = None
    if not opts.temperature:
        opts.temperature = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]
    opts.temperature = sorted(set(opts.temperature))

    try:
        opts.mic = int(opts.mic)
    except (TypeError, ValueError):
        # TypeError: option not given (None). ValueError: a non-numeric
        # device name. Both are legitimate, so the value is kept as is.
        pass
    return opts
2022-09-23 10:20:11 +00:00
2022-09-29 11:43:49 +00:00
def get_wshiper(*, opts) -> WhisperStreamingTranscriber:
    """Create the streaming transcriber from the parsed model options.

    Reads ``opts.model``, ``opts.language`` and ``opts.device``, logs the
    resulting configuration at debug level and returns the transcriber
    built from it.
    """
    whisper_config = WhisperConfig(
        device=opts.device,
        language=opts.language,
        model_name=opts.model,
    )
    logger.debug(f"WhisperConfig: {whisper_config}")
    return WhisperStreamingTranscriber(config=whisper_config)
2022-09-29 11:43:49 +00:00
def get_context(*, opts) -> Context:
    """Create the transcription context from the parsed options.

    Combines the current protocol version with ``opts.beam_size``,
    ``opts.temperature``, ``opts.max_nospeech_skip`` and ``opts.vad``, logs
    the result at debug level and returns it.
    """
    context_fields = {
        "protocol_version": CURRENT_PROTOCOL_VERSION,
        "beam_size": opts.beam_size,
        "temperatures": opts.temperature,
        "max_nospeech_skip": opts.max_nospeech_skip,
        "vad_threshold": opts.vad,
    }
    context = Context(**context_fields)
    logger.debug(f"Context: {context}")
    return context
2022-09-27 03:27:22 +00:00
def show_devices():
    """Print ``<index>: <name>`` for every device that has input channels."""
    for index, info in enumerate(sd.query_devices()):
        # Skip devices that report no input channels.
        if not info["max_input_channels"] > 0:
            continue
        print(f"{index}: {info['name']}")
2022-10-17 12:29:12 +00:00
def is_valid_arg(
    *,
    args: List[str],
    mode: str,
) -> bool:
    """Check that no command-line option conflicts with the selected mode.

    Args:
        args: Raw command-line arguments (without the program name).
        mode: Selected mode value. Server mode rejects the microphone,
            parsing and output options; mic mode rejects the WebSocket
            options; any other value rejects nothing.

    Returns:
        ``True`` when every argument is acceptable. ``False`` after writing
        a message to stderr for the first rejected option.
    """
    if mode == Mode.server.value:
        rejected = {
            "--mic",
            "--beam_size",
            "-b",
            "--temperature",
            "-t",
            "--num_block",
            "-n",
            "--vad",
            "--max_nospeech_skip",
            "--output",
            "-o",  # alias of --output; must be rejected together with it
            "--show-devices",
            "--no-progress",
        }
    elif mode == Mode.mic.value:
        rejected = {
            "--host",
            "--port",
        }
    else:
        rejected = set()

    for arg in args:
        # argparse also accepts "--option=value", so compare only the option
        # name; otherwise e.g. "--port=8000" would slip past the check.
        option = arg.split("=", 1)[0]
        if option in rejected:
            sys.stderr.write(f"{option} is not accepted option for {mode} mode\n")
            return False
    return True
2022-10-02 12:59:02 +00:00
2022-09-27 03:27:22 +00:00
def main() -> None:
    """Command-line entry point.

    Parses the options, configures logging and then either lists the input
    devices or runs one of three modes:

    * client: streams microphone audio to a WebSocket server. ``--language``
      and ``--model`` must not be given.
    * server: serves transcription over WebSocket. Selected explicitly, or
      implicitly when both ``--host`` and ``--port`` are given and the mode
      is not client. Requires ``--language`` and ``--model``.
    * otherwise: transcribes from the local microphone into the output.
      Requires ``--language`` and ``--model``.

    Invalid option combinations are reported on stderr and the process
    exits with status 1.
    """

    def fail(message: str) -> None:
        # Explicit validation instead of ``assert``: assertions are removed
        # under ``python -O`` and give the user only a bare traceback.
        sys.stderr.write(f"{message}\n")
        sys.exit(1)

    opts = get_opts()
    basicConfig(
        level=DEBUG if opts.debug else INFO,
        format="[%(asctime)s] %(module)s.%(funcName)s:%(lineno)d %(levelname)s -> %(message)s",
    )

    if opts.show_devices:
        show_devices()
        return

    # Giving both host and port implies server mode unless client was requested.
    if (
        opts.host is not None
        and opts.port is not None
        and opts.mode != Mode.client.value
    ):
        opts.mode = Mode.server.value

    if not is_valid_arg(
        args=sys.argv[1:],
        mode=opts.mode,
    ):
        sys.exit(1)

    if opts.mode == Mode.client.value:
        if opts.language is not None or opts.model is not None:
            fail("--language and --model are not accepted in client mode")
        ctx = get_context(opts=opts)
        try:
            asyncio.run(
                run_websocket_client(
                    sd_device=opts.mic,
                    num_block=opts.num_block,
                    host=opts.host,
                    port=opts.port,
                    no_progress=opts.no_progress,
                    ctx=ctx,
                    path_out=opts.output,
                )
            )
        except KeyboardInterrupt:
            # Ctrl-C is the normal way to stop the client.
            pass
    elif opts.mode == Mode.server.value:
        if opts.language is None or opts.model is None:
            fail("--language and --model are required in server mode")
        wsp = get_wshiper(opts=opts)
        asyncio.run(
            serve_with_websocket(
                wsp=wsp,
                host=opts.host,
                port=opts.port,
            )
        )
    else:
        if opts.language is None or opts.model is None:
            fail("--language and --model are required in mic mode")
        wsp = get_wshiper(opts=opts)
        ctx = get_context(opts=opts)
        with opts.output.open("w") as outf:
            for text in transcribe_from_mic(
                wsp=wsp,
                sd_device=opts.mic,
                num_block=opts.num_block,
                no_progress=opts.no_progress,
                ctx=ctx,
            ):
                outf.write(text)
                # Flush per line so output appears as soon as it is transcribed.
                outf.flush()
2022-09-23 10:20:11 +00:00
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()