Add --output option (Resolve #14)

This commit is contained in:
Yuta Hayashibe 2022-10-03 22:38:35 +09:00
parent 01f5a21020
commit 2364aca0ea
3 changed files with 52 additions and 15 deletions

View file

@ -6,7 +6,8 @@ import queue
import sys
from enum import Enum
from logging import DEBUG, INFO, basicConfig, getLogger
from typing import Optional, Union
from pathlib import Path
from typing import Iterator, Optional, Union
import sounddevice as sd
import torch
@ -15,7 +16,7 @@ from whisper.audio import N_FRAMES, SAMPLE_RATE
from whisper.tokenizer import LANGUAGES, TO_LANGUAGE_CODE
from whispering.pbar import ProgressBar
from whispering.schema import Context, WhisperConfig
from whispering.schema import Context, StdoutWriter, WhisperConfig
from whispering.serve import serve_with_websocket
from whispering.transcriber import WhisperStreamingTranscriber
from whispering.websocket_client import run_websocket_client
@ -39,7 +40,7 @@ def transcribe_from_mic(
num_block: int,
ctx: Context,
no_progress: bool,
) -> None:
) -> Iterator[str]:
q = queue.Queue()
def sd_callback(indata, frames, time, status):
@ -82,7 +83,7 @@ def transcribe_from_mic(
if not no_progress:
sys.stderr.write("\r")
sys.stderr.flush()
print(f"{chunk.start:.2f}->{chunk.end:.2f}\t{chunk.text}")
yield f"{chunk.start:.2f}->{chunk.end:.2f}\t{chunk.text}\n"
if not no_progress:
sys.stderr.write("Analyzing")
sys.stderr.flush()
@ -150,6 +151,13 @@ def get_opts() -> argparse.Namespace:
)
group_misc = parser.add_argument_group("Other options")
group_misc.add_argument(
"--output",
"-o",
help="Output file",
type=Path,
default=StdoutWriter(),
)
group_misc.add_argument(
"--mic",
help="Set MIC device",
@ -269,6 +277,7 @@ def main() -> None:
port=opts.port,
no_progress=opts.no_progress,
ctx=ctx,
path_out=opts.output,
)
)
except KeyboardInterrupt:
@ -289,13 +298,16 @@ def main() -> None:
assert opts.model is not None
wsp = get_wshiper(opts=opts)
ctx: Context = get_context(opts=opts)
transcribe_from_mic(
wsp=wsp,
sd_device=opts.mic,
num_block=opts.num_block,
no_progress=opts.no_progress,
ctx=ctx,
)
with opts.output.open("w") as outf:
for text in transcribe_from_mic(
wsp=wsp,
sd_device=opts.mic,
num_block=opts.num_block,
no_progress=opts.no_progress,
ctx=ctx,
):
outf.write(text)
outf.flush()
if __name__ == "__main__":

View file

@ -1,5 +1,6 @@
#!/usr/bin/env python3
import sys
from typing import List, Optional
import numpy as np
@ -59,3 +60,20 @@ class SpeechSegment(BaseModel, arbitrary_types_allowed=True):
start_block_idx: int
end_block_idx: int
audio: np.ndarray
class StdoutWriter:
def open(self, *args, **kwargs):
return self
def __enter__(self, *args, **kwargs):
return self
def __exit__(self):
pass
def flush(self):
sys.stdout.flush()
def write(self, text):
sys.stdout.write(text)

View file

@ -3,13 +3,14 @@ import asyncio
import json
import sys
from logging import getLogger
from pathlib import Path
from typing import Optional, Union
import sounddevice as sd
import websockets
from whisper.audio import N_FRAMES, SAMPLE_RATE
from whispering.schema import ParsedChunk
from whispering.schema import ParsedChunk, StdoutWriter
from whispering.transcriber import Context
logger = getLogger(__name__)
@ -28,6 +29,7 @@ async def transcribe_from_mic_and_send(
host: str,
port: int,
ctx: Context,
path_out: Union[Path, StdoutWriter],
) -> None:
uri = f"ws://{host}:{port}"
@ -38,7 +40,7 @@ async def transcribe_from_mic_and_send(
dtype="float32",
channels=1,
callback=sd_callback,
):
), path_out.open("w") as outf:
async with websockets.connect(uri, max_size=999999999) as ws: # type:ignore
logger.debug("Sent context")
v: str = ctx.json()
@ -69,10 +71,13 @@ async def transcribe_from_mic_and_send(
c = await asyncio.wait_for(recv(), timeout=0.5)
c_json = json.loads(c)
if (err := c_json.get("error")) is not None:
print(f"Error: {err}")
sys.stderr.write(f"Error: {err}\n")
sys.exit(1)
chunk = ParsedChunk.parse_obj(c_json)
print(f"{chunk.start:.2f}->{chunk.end:.2f}\t{chunk.text}")
outf.write(
f"{chunk.start:.2f}->{chunk.end:.2f}\t{chunk.text}\n"
)
outf.flush()
except asyncio.TimeoutError:
break
idx += 1
@ -86,6 +91,7 @@ async def run_websocket_client(
port: int,
ctx: Context,
no_progress: bool,
path_out: Union[Path, StdoutWriter],
) -> None:
global q
global loop
@ -98,4 +104,5 @@ async def run_websocket_client(
host=host,
port=port,
ctx=ctx,
path_out=path_out,
)