mirror of
https://github.com/shirayu/whispering.git
synced 2024-11-25 02:11:00 +00:00
Add --output option (Resolve #14)
This commit is contained in:
parent
01f5a21020
commit
2364aca0ea
3 changed files with 52 additions and 15 deletions
|
@ -6,7 +6,8 @@ import queue
|
|||
import sys
|
||||
from enum import Enum
|
||||
from logging import DEBUG, INFO, basicConfig, getLogger
|
||||
from typing import Optional, Union
|
||||
from pathlib import Path
|
||||
from typing import Iterator, Optional, Union
|
||||
|
||||
import sounddevice as sd
|
||||
import torch
|
||||
|
@ -15,7 +16,7 @@ from whisper.audio import N_FRAMES, SAMPLE_RATE
|
|||
from whisper.tokenizer import LANGUAGES, TO_LANGUAGE_CODE
|
||||
|
||||
from whispering.pbar import ProgressBar
|
||||
from whispering.schema import Context, WhisperConfig
|
||||
from whispering.schema import Context, StdoutWriter, WhisperConfig
|
||||
from whispering.serve import serve_with_websocket
|
||||
from whispering.transcriber import WhisperStreamingTranscriber
|
||||
from whispering.websocket_client import run_websocket_client
|
||||
|
@ -39,7 +40,7 @@ def transcribe_from_mic(
|
|||
num_block: int,
|
||||
ctx: Context,
|
||||
no_progress: bool,
|
||||
) -> None:
|
||||
) -> Iterator[str]:
|
||||
q = queue.Queue()
|
||||
|
||||
def sd_callback(indata, frames, time, status):
|
||||
|
@ -82,7 +83,7 @@ def transcribe_from_mic(
|
|||
if not no_progress:
|
||||
sys.stderr.write("\r")
|
||||
sys.stderr.flush()
|
||||
print(f"{chunk.start:.2f}->{chunk.end:.2f}\t{chunk.text}")
|
||||
yield f"{chunk.start:.2f}->{chunk.end:.2f}\t{chunk.text}\n"
|
||||
if not no_progress:
|
||||
sys.stderr.write("Analyzing")
|
||||
sys.stderr.flush()
|
||||
|
@ -150,6 +151,13 @@ def get_opts() -> argparse.Namespace:
|
|||
)
|
||||
|
||||
group_misc = parser.add_argument_group("Other options")
|
||||
group_misc.add_argument(
|
||||
"--output",
|
||||
"-o",
|
||||
help="Output file",
|
||||
type=Path,
|
||||
default=StdoutWriter(),
|
||||
)
|
||||
group_misc.add_argument(
|
||||
"--mic",
|
||||
help="Set MIC device",
|
||||
|
@ -269,6 +277,7 @@ def main() -> None:
|
|||
port=opts.port,
|
||||
no_progress=opts.no_progress,
|
||||
ctx=ctx,
|
||||
path_out=opts.output,
|
||||
)
|
||||
)
|
||||
except KeyboardInterrupt:
|
||||
|
@ -289,13 +298,16 @@ def main() -> None:
|
|||
assert opts.model is not None
|
||||
wsp = get_wshiper(opts=opts)
|
||||
ctx: Context = get_context(opts=opts)
|
||||
transcribe_from_mic(
|
||||
wsp=wsp,
|
||||
sd_device=opts.mic,
|
||||
num_block=opts.num_block,
|
||||
no_progress=opts.no_progress,
|
||||
ctx=ctx,
|
||||
)
|
||||
with opts.output.open("w") as outf:
|
||||
for text in transcribe_from_mic(
|
||||
wsp=wsp,
|
||||
sd_device=opts.mic,
|
||||
num_block=opts.num_block,
|
||||
no_progress=opts.no_progress,
|
||||
ctx=ctx,
|
||||
):
|
||||
outf.write(text)
|
||||
outf.flush()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
import sys
|
||||
from typing import List, Optional
|
||||
|
||||
import numpy as np
|
||||
|
@ -59,3 +60,20 @@ class SpeechSegment(BaseModel, arbitrary_types_allowed=True):
|
|||
start_block_idx: int
|
||||
end_block_idx: int
|
||||
audio: np.ndarray
|
||||
|
||||
|
||||
class StdoutWriter:
|
||||
def open(self, *args, **kwargs):
|
||||
return self
|
||||
|
||||
def __enter__(self, *args, **kwargs):
|
||||
return self
|
||||
|
||||
def __exit__(self):
|
||||
pass
|
||||
|
||||
def flush(self):
|
||||
sys.stdout.flush()
|
||||
|
||||
def write(self, text):
|
||||
sys.stdout.write(text)
|
||||
|
|
|
@ -3,13 +3,14 @@ import asyncio
|
|||
import json
|
||||
import sys
|
||||
from logging import getLogger
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
|
||||
import sounddevice as sd
|
||||
import websockets
|
||||
from whisper.audio import N_FRAMES, SAMPLE_RATE
|
||||
|
||||
from whispering.schema import ParsedChunk
|
||||
from whispering.schema import ParsedChunk, StdoutWriter
|
||||
from whispering.transcriber import Context
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
@ -28,6 +29,7 @@ async def transcribe_from_mic_and_send(
|
|||
host: str,
|
||||
port: int,
|
||||
ctx: Context,
|
||||
path_out: Union[Path, StdoutWriter],
|
||||
) -> None:
|
||||
uri = f"ws://{host}:{port}"
|
||||
|
||||
|
@ -38,7 +40,7 @@ async def transcribe_from_mic_and_send(
|
|||
dtype="float32",
|
||||
channels=1,
|
||||
callback=sd_callback,
|
||||
):
|
||||
), path_out.open("w") as outf:
|
||||
async with websockets.connect(uri, max_size=999999999) as ws: # type:ignore
|
||||
logger.debug("Sent context")
|
||||
v: str = ctx.json()
|
||||
|
@ -69,10 +71,13 @@ async def transcribe_from_mic_and_send(
|
|||
c = await asyncio.wait_for(recv(), timeout=0.5)
|
||||
c_json = json.loads(c)
|
||||
if (err := c_json.get("error")) is not None:
|
||||
print(f"Error: {err}")
|
||||
sys.stderr.write(f"Error: {err}\n")
|
||||
sys.exit(1)
|
||||
chunk = ParsedChunk.parse_obj(c_json)
|
||||
print(f"{chunk.start:.2f}->{chunk.end:.2f}\t{chunk.text}")
|
||||
outf.write(
|
||||
f"{chunk.start:.2f}->{chunk.end:.2f}\t{chunk.text}\n"
|
||||
)
|
||||
outf.flush()
|
||||
except asyncio.TimeoutError:
|
||||
break
|
||||
idx += 1
|
||||
|
@ -86,6 +91,7 @@ async def run_websocket_client(
|
|||
port: int,
|
||||
ctx: Context,
|
||||
no_progress: bool,
|
||||
path_out: Union[Path, StdoutWriter],
|
||||
) -> None:
|
||||
global q
|
||||
global loop
|
||||
|
@ -98,4 +104,5 @@ async def run_websocket_client(
|
|||
host=host,
|
||||
port=port,
|
||||
ctx=ctx,
|
||||
path_out=path_out,
|
||||
)
|
||||
|
|
Loading…
Reference in a new issue