Add --output option (Resolve #14)

This commit is contained in:
Yuta Hayashibe 2022-10-03 22:38:35 +09:00
parent 01f5a21020
commit 2364aca0ea
3 changed files with 52 additions and 15 deletions

View file

@ -6,7 +6,8 @@ import queue
import sys import sys
from enum import Enum from enum import Enum
from logging import DEBUG, INFO, basicConfig, getLogger from logging import DEBUG, INFO, basicConfig, getLogger
from typing import Optional, Union from pathlib import Path
from typing import Iterator, Optional, Union
import sounddevice as sd import sounddevice as sd
import torch import torch
@ -15,7 +16,7 @@ from whisper.audio import N_FRAMES, SAMPLE_RATE
from whisper.tokenizer import LANGUAGES, TO_LANGUAGE_CODE from whisper.tokenizer import LANGUAGES, TO_LANGUAGE_CODE
from whispering.pbar import ProgressBar from whispering.pbar import ProgressBar
from whispering.schema import Context, WhisperConfig from whispering.schema import Context, StdoutWriter, WhisperConfig
from whispering.serve import serve_with_websocket from whispering.serve import serve_with_websocket
from whispering.transcriber import WhisperStreamingTranscriber from whispering.transcriber import WhisperStreamingTranscriber
from whispering.websocket_client import run_websocket_client from whispering.websocket_client import run_websocket_client
@ -39,7 +40,7 @@ def transcribe_from_mic(
num_block: int, num_block: int,
ctx: Context, ctx: Context,
no_progress: bool, no_progress: bool,
) -> None: ) -> Iterator[str]:
q = queue.Queue() q = queue.Queue()
def sd_callback(indata, frames, time, status): def sd_callback(indata, frames, time, status):
@ -82,7 +83,7 @@ def transcribe_from_mic(
if not no_progress: if not no_progress:
sys.stderr.write("\r") sys.stderr.write("\r")
sys.stderr.flush() sys.stderr.flush()
print(f"{chunk.start:.2f}->{chunk.end:.2f}\t{chunk.text}") yield f"{chunk.start:.2f}->{chunk.end:.2f}\t{chunk.text}\n"
if not no_progress: if not no_progress:
sys.stderr.write("Analyzing") sys.stderr.write("Analyzing")
sys.stderr.flush() sys.stderr.flush()
@ -150,6 +151,13 @@ def get_opts() -> argparse.Namespace:
) )
group_misc = parser.add_argument_group("Other options") group_misc = parser.add_argument_group("Other options")
group_misc.add_argument(
"--output",
"-o",
help="Output file",
type=Path,
default=StdoutWriter(),
)
group_misc.add_argument( group_misc.add_argument(
"--mic", "--mic",
help="Set MIC device", help="Set MIC device",
@ -269,6 +277,7 @@ def main() -> None:
port=opts.port, port=opts.port,
no_progress=opts.no_progress, no_progress=opts.no_progress,
ctx=ctx, ctx=ctx,
path_out=opts.output,
) )
) )
except KeyboardInterrupt: except KeyboardInterrupt:
@ -289,13 +298,16 @@ def main() -> None:
assert opts.model is not None assert opts.model is not None
wsp = get_wshiper(opts=opts) wsp = get_wshiper(opts=opts)
ctx: Context = get_context(opts=opts) ctx: Context = get_context(opts=opts)
transcribe_from_mic( with opts.output.open("w") as outf:
for text in transcribe_from_mic(
wsp=wsp, wsp=wsp,
sd_device=opts.mic, sd_device=opts.mic,
num_block=opts.num_block, num_block=opts.num_block,
no_progress=opts.no_progress, no_progress=opts.no_progress,
ctx=ctx, ctx=ctx,
) ):
outf.write(text)
outf.flush()
if __name__ == "__main__": if __name__ == "__main__":

View file

@ -1,5 +1,6 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import sys
from typing import List, Optional from typing import List, Optional
import numpy as np import numpy as np
@ -59,3 +60,20 @@ class SpeechSegment(BaseModel, arbitrary_types_allowed=True):
start_block_idx: int start_block_idx: int
end_block_idx: int end_block_idx: int
audio: np.ndarray audio: np.ndarray
class StdoutWriter:
    """File-like stand-in that sends text to ``sys.stdout``.

    It exposes the subset of the ``pathlib.Path`` / text-file interface the
    callers use (``open`` as a context manager, then ``write`` and
    ``flush``), so the same ``with target.open("w") as outf:`` code works
    whether the target is a real path or standard output.
    """

    def open(self, *args, **kwargs):
        """Return this writer; any mode/encoding arguments are ignored."""
        return self

    def __enter__(self, *args, **kwargs):
        """Enter the ``with`` block, yielding this writer."""
        return self

    def __exit__(self, exc_type=None, exc_value=None, traceback=None):
        """Leave the ``with`` block without closing ``sys.stdout``.

        The ``with`` statement always passes the three exception arguments,
        so they must be accepted here; the defaults keep a bare
        ``__exit__()`` call working as well.

        Returns:
            ``False``, so an exception raised inside the block propagates.
        """
        return False

    def flush(self):
        """Flush ``sys.stdout``."""
        sys.stdout.flush()

    def write(self, text):
        """Write *text* to ``sys.stdout`` and return that call's result."""
        # ``sys.stdout`` is looked up on every call so that a stream
        # installed later (e.g. by a redirect) is honoured.
        return sys.stdout.write(text)

View file

@ -3,13 +3,14 @@ import asyncio
import json import json
import sys import sys
from logging import getLogger from logging import getLogger
from pathlib import Path
from typing import Optional, Union from typing import Optional, Union
import sounddevice as sd import sounddevice as sd
import websockets import websockets
from whisper.audio import N_FRAMES, SAMPLE_RATE from whisper.audio import N_FRAMES, SAMPLE_RATE
from whispering.schema import ParsedChunk from whispering.schema import ParsedChunk, StdoutWriter
from whispering.transcriber import Context from whispering.transcriber import Context
logger = getLogger(__name__) logger = getLogger(__name__)
@ -28,6 +29,7 @@ async def transcribe_from_mic_and_send(
host: str, host: str,
port: int, port: int,
ctx: Context, ctx: Context,
path_out: Union[Path, StdoutWriter],
) -> None: ) -> None:
uri = f"ws://{host}:{port}" uri = f"ws://{host}:{port}"
@ -38,7 +40,7 @@ async def transcribe_from_mic_and_send(
dtype="float32", dtype="float32",
channels=1, channels=1,
callback=sd_callback, callback=sd_callback,
): ), path_out.open("w") as outf:
async with websockets.connect(uri, max_size=999999999) as ws: # type:ignore async with websockets.connect(uri, max_size=999999999) as ws: # type:ignore
logger.debug("Sent context") logger.debug("Sent context")
v: str = ctx.json() v: str = ctx.json()
@ -69,10 +71,13 @@ async def transcribe_from_mic_and_send(
c = await asyncio.wait_for(recv(), timeout=0.5) c = await asyncio.wait_for(recv(), timeout=0.5)
c_json = json.loads(c) c_json = json.loads(c)
if (err := c_json.get("error")) is not None: if (err := c_json.get("error")) is not None:
print(f"Error: {err}") sys.stderr.write(f"Error: {err}\n")
sys.exit(1) sys.exit(1)
chunk = ParsedChunk.parse_obj(c_json) chunk = ParsedChunk.parse_obj(c_json)
print(f"{chunk.start:.2f}->{chunk.end:.2f}\t{chunk.text}") outf.write(
f"{chunk.start:.2f}->{chunk.end:.2f}\t{chunk.text}\n"
)
outf.flush()
except asyncio.TimeoutError: except asyncio.TimeoutError:
break break
idx += 1 idx += 1
@ -86,6 +91,7 @@ async def run_websocket_client(
port: int, port: int,
ctx: Context, ctx: Context,
no_progress: bool, no_progress: bool,
path_out: Union[Path, StdoutWriter],
) -> None: ) -> None:
global q global q
global loop global loop
@ -98,4 +104,5 @@ async def run_websocket_client(
host=host, host=host,
port=port, port=port,
ctx=ctx, ctx=ctx,
path_out=path_out,
) )