From 2364aca0eaa224fcee5952d775ccee5ee837cb59 Mon Sep 17 00:00:00 2001 From: Yuta Hayashibe Date: Mon, 3 Oct 2022 22:38:35 +0900 Subject: [PATCH] Add --output option (Resolve #14) --- whispering/cli.py | 34 +++++++++++++++++++++++----------- whispering/schema.py | 18 ++++++++++++++++++ whispering/websocket_client.py | 15 +++++++++++---- 3 files changed, 52 insertions(+), 15 deletions(-) diff --git a/whispering/cli.py b/whispering/cli.py index 1af32e7..f804f1e 100644 --- a/whispering/cli.py +++ b/whispering/cli.py @@ -6,7 +6,8 @@ import queue import sys from enum import Enum from logging import DEBUG, INFO, basicConfig, getLogger -from typing import Optional, Union +from pathlib import Path +from typing import Iterator, Optional, Union import sounddevice as sd import torch @@ -15,7 +16,7 @@ from whisper.audio import N_FRAMES, SAMPLE_RATE from whisper.tokenizer import LANGUAGES, TO_LANGUAGE_CODE from whispering.pbar import ProgressBar -from whispering.schema import Context, WhisperConfig +from whispering.schema import Context, StdoutWriter, WhisperConfig from whispering.serve import serve_with_websocket from whispering.transcriber import WhisperStreamingTranscriber from whispering.websocket_client import run_websocket_client @@ -39,7 +40,7 @@ def transcribe_from_mic( num_block: int, ctx: Context, no_progress: bool, -) -> None: +) -> Iterator[str]: q = queue.Queue() def sd_callback(indata, frames, time, status): @@ -82,7 +83,7 @@ def transcribe_from_mic( if not no_progress: sys.stderr.write("\r") sys.stderr.flush() - print(f"{chunk.start:.2f}->{chunk.end:.2f}\t{chunk.text}") + yield f"{chunk.start:.2f}->{chunk.end:.2f}\t{chunk.text}\n" if not no_progress: sys.stderr.write("Analyzing") sys.stderr.flush() @@ -150,6 +151,13 @@ def get_opts() -> argparse.Namespace: ) group_misc = parser.add_argument_group("Other options") + group_misc.add_argument( + "--output", + "-o", + help="Output file", + type=Path, + default=StdoutWriter(), + ) group_misc.add_argument( "--mic", 
class StdoutWriter:
    """Minimal stand-in for an output ``Path`` that writes to standard output.

    It mirrors the small part of the ``Path`` / file-object API that callers
    use: ``with target.open("w") as outf: outf.write(...); outf.flush()``.
    The same object acts as the "path", the opened "file", and the context
    manager.
    """

    def open(self, *args, **kwargs):
        """Return this writer itself; all arguments (e.g. the mode) are ignored."""
        return self

    def __enter__(self, *args, **kwargs):
        """Enter the ``with`` block and hand back this writer."""
        return self

    def __exit__(self, exc_type=None, exc_value=None, traceback=None):
        """Leave the ``with`` block without closing ``sys.stdout``.

        The ``with`` statement always passes the exception triple, so the
        method has to accept it; a signature taking only ``self`` raises
        ``TypeError`` when the block exits. The arguments are defaulted so
        that a bare ``__exit__()`` call keeps working.

        Returns:
            None, so an exception raised inside the block is never suppressed.
        """
        return None

    def flush(self):
        """Flush ``sys.stdout``."""
        sys.stdout.flush()

    def write(self, text):
        """Write ``text`` to ``sys.stdout`` exactly as given (no newline added)."""
        # sys.stdout is looked up on every call so that redirection of stdout
        # performed after this object was created is still honoured.
        sys.stdout.write(text)
Context, + path_out: Union[Path, StdoutWriter], ) -> None: uri = f"ws://{host}:{port}" @@ -38,7 +40,7 @@ async def transcribe_from_mic_and_send( dtype="float32", channels=1, callback=sd_callback, - ): + ), path_out.open("w") as outf: async with websockets.connect(uri, max_size=999999999) as ws: # type:ignore logger.debug("Sent context") v: str = ctx.json() @@ -69,10 +71,13 @@ async def transcribe_from_mic_and_send( c = await asyncio.wait_for(recv(), timeout=0.5) c_json = json.loads(c) if (err := c_json.get("error")) is not None: - print(f"Error: {err}") + sys.stderr.write(f"Error: {err}\n") sys.exit(1) chunk = ParsedChunk.parse_obj(c_json) - print(f"{chunk.start:.2f}->{chunk.end:.2f}\t{chunk.text}") + outf.write( + f"{chunk.start:.2f}->{chunk.end:.2f}\t{chunk.text}\n" + ) + outf.flush() except asyncio.TimeoutError: break idx += 1 @@ -86,6 +91,7 @@ async def run_websocket_client( port: int, ctx: Context, no_progress: bool, + path_out: Union[Path, StdoutWriter], ) -> None: global q global loop @@ -98,4 +104,5 @@ async def run_websocket_client( host=host, port=port, ctx=ctx, + path_out=path_out, )