mirror of
https://github.com/shirayu/whispering.git
synced 2024-11-28 20:11:08 +00:00
Add --output option (Resolve #14)
This commit is contained in:
parent
01f5a21020
commit
2364aca0ea
3 changed files with 52 additions and 15 deletions
|
@ -6,7 +6,8 @@ import queue
|
||||||
import sys
|
import sys
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from logging import DEBUG, INFO, basicConfig, getLogger
|
from logging import DEBUG, INFO, basicConfig, getLogger
|
||||||
from typing import Optional, Union
|
from pathlib import Path
|
||||||
|
from typing import Iterator, Optional, Union
|
||||||
|
|
||||||
import sounddevice as sd
|
import sounddevice as sd
|
||||||
import torch
|
import torch
|
||||||
|
@ -15,7 +16,7 @@ from whisper.audio import N_FRAMES, SAMPLE_RATE
|
||||||
from whisper.tokenizer import LANGUAGES, TO_LANGUAGE_CODE
|
from whisper.tokenizer import LANGUAGES, TO_LANGUAGE_CODE
|
||||||
|
|
||||||
from whispering.pbar import ProgressBar
|
from whispering.pbar import ProgressBar
|
||||||
from whispering.schema import Context, WhisperConfig
|
from whispering.schema import Context, StdoutWriter, WhisperConfig
|
||||||
from whispering.serve import serve_with_websocket
|
from whispering.serve import serve_with_websocket
|
||||||
from whispering.transcriber import WhisperStreamingTranscriber
|
from whispering.transcriber import WhisperStreamingTranscriber
|
||||||
from whispering.websocket_client import run_websocket_client
|
from whispering.websocket_client import run_websocket_client
|
||||||
|
@ -39,7 +40,7 @@ def transcribe_from_mic(
|
||||||
num_block: int,
|
num_block: int,
|
||||||
ctx: Context,
|
ctx: Context,
|
||||||
no_progress: bool,
|
no_progress: bool,
|
||||||
) -> None:
|
) -> Iterator[str]:
|
||||||
q = queue.Queue()
|
q = queue.Queue()
|
||||||
|
|
||||||
def sd_callback(indata, frames, time, status):
|
def sd_callback(indata, frames, time, status):
|
||||||
|
@ -82,7 +83,7 @@ def transcribe_from_mic(
|
||||||
if not no_progress:
|
if not no_progress:
|
||||||
sys.stderr.write("\r")
|
sys.stderr.write("\r")
|
||||||
sys.stderr.flush()
|
sys.stderr.flush()
|
||||||
print(f"{chunk.start:.2f}->{chunk.end:.2f}\t{chunk.text}")
|
yield f"{chunk.start:.2f}->{chunk.end:.2f}\t{chunk.text}\n"
|
||||||
if not no_progress:
|
if not no_progress:
|
||||||
sys.stderr.write("Analyzing")
|
sys.stderr.write("Analyzing")
|
||||||
sys.stderr.flush()
|
sys.stderr.flush()
|
||||||
|
@ -150,6 +151,13 @@ def get_opts() -> argparse.Namespace:
|
||||||
)
|
)
|
||||||
|
|
||||||
group_misc = parser.add_argument_group("Other options")
|
group_misc = parser.add_argument_group("Other options")
|
||||||
|
group_misc.add_argument(
|
||||||
|
"--output",
|
||||||
|
"-o",
|
||||||
|
help="Output file",
|
||||||
|
type=Path,
|
||||||
|
default=StdoutWriter(),
|
||||||
|
)
|
||||||
group_misc.add_argument(
|
group_misc.add_argument(
|
||||||
"--mic",
|
"--mic",
|
||||||
help="Set MIC device",
|
help="Set MIC device",
|
||||||
|
@ -269,6 +277,7 @@ def main() -> None:
|
||||||
port=opts.port,
|
port=opts.port,
|
||||||
no_progress=opts.no_progress,
|
no_progress=opts.no_progress,
|
||||||
ctx=ctx,
|
ctx=ctx,
|
||||||
|
path_out=opts.output,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
|
@ -289,13 +298,16 @@ def main() -> None:
|
||||||
assert opts.model is not None
|
assert opts.model is not None
|
||||||
wsp = get_wshiper(opts=opts)
|
wsp = get_wshiper(opts=opts)
|
||||||
ctx: Context = get_context(opts=opts)
|
ctx: Context = get_context(opts=opts)
|
||||||
transcribe_from_mic(
|
with opts.output.open("w") as outf:
|
||||||
|
for text in transcribe_from_mic(
|
||||||
wsp=wsp,
|
wsp=wsp,
|
||||||
sd_device=opts.mic,
|
sd_device=opts.mic,
|
||||||
num_block=opts.num_block,
|
num_block=opts.num_block,
|
||||||
no_progress=opts.no_progress,
|
no_progress=opts.no_progress,
|
||||||
ctx=ctx,
|
ctx=ctx,
|
||||||
)
|
):
|
||||||
|
outf.write(text)
|
||||||
|
outf.flush()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
|
import sys
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -59,3 +60,20 @@ class SpeechSegment(BaseModel, arbitrary_types_allowed=True):
|
||||||
start_block_idx: int
|
start_block_idx: int
|
||||||
end_block_idx: int
|
end_block_idx: int
|
||||||
audio: np.ndarray
|
audio: np.ndarray
|
||||||
|
|
||||||
|
|
||||||
|
class StdoutWriter:
|
||||||
|
def open(self, *args, **kwargs):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __enter__(self, *args, **kwargs):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def flush(self):
|
||||||
|
sys.stdout.flush()
|
||||||
|
|
||||||
|
def write(self, text):
|
||||||
|
sys.stdout.write(text)
|
||||||
|
|
|
@ -3,13 +3,14 @@ import asyncio
|
||||||
import json
|
import json
|
||||||
import sys
|
import sys
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
|
from pathlib import Path
|
||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
import sounddevice as sd
|
import sounddevice as sd
|
||||||
import websockets
|
import websockets
|
||||||
from whisper.audio import N_FRAMES, SAMPLE_RATE
|
from whisper.audio import N_FRAMES, SAMPLE_RATE
|
||||||
|
|
||||||
from whispering.schema import ParsedChunk
|
from whispering.schema import ParsedChunk, StdoutWriter
|
||||||
from whispering.transcriber import Context
|
from whispering.transcriber import Context
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
@ -28,6 +29,7 @@ async def transcribe_from_mic_and_send(
|
||||||
host: str,
|
host: str,
|
||||||
port: int,
|
port: int,
|
||||||
ctx: Context,
|
ctx: Context,
|
||||||
|
path_out: Union[Path, StdoutWriter],
|
||||||
) -> None:
|
) -> None:
|
||||||
uri = f"ws://{host}:{port}"
|
uri = f"ws://{host}:{port}"
|
||||||
|
|
||||||
|
@ -38,7 +40,7 @@ async def transcribe_from_mic_and_send(
|
||||||
dtype="float32",
|
dtype="float32",
|
||||||
channels=1,
|
channels=1,
|
||||||
callback=sd_callback,
|
callback=sd_callback,
|
||||||
):
|
), path_out.open("w") as outf:
|
||||||
async with websockets.connect(uri, max_size=999999999) as ws: # type:ignore
|
async with websockets.connect(uri, max_size=999999999) as ws: # type:ignore
|
||||||
logger.debug("Sent context")
|
logger.debug("Sent context")
|
||||||
v: str = ctx.json()
|
v: str = ctx.json()
|
||||||
|
@ -69,10 +71,13 @@ async def transcribe_from_mic_and_send(
|
||||||
c = await asyncio.wait_for(recv(), timeout=0.5)
|
c = await asyncio.wait_for(recv(), timeout=0.5)
|
||||||
c_json = json.loads(c)
|
c_json = json.loads(c)
|
||||||
if (err := c_json.get("error")) is not None:
|
if (err := c_json.get("error")) is not None:
|
||||||
print(f"Error: {err}")
|
sys.stderr.write(f"Error: {err}\n")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
chunk = ParsedChunk.parse_obj(c_json)
|
chunk = ParsedChunk.parse_obj(c_json)
|
||||||
print(f"{chunk.start:.2f}->{chunk.end:.2f}\t{chunk.text}")
|
outf.write(
|
||||||
|
f"{chunk.start:.2f}->{chunk.end:.2f}\t{chunk.text}\n"
|
||||||
|
)
|
||||||
|
outf.flush()
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
break
|
break
|
||||||
idx += 1
|
idx += 1
|
||||||
|
@ -86,6 +91,7 @@ async def run_websocket_client(
|
||||||
port: int,
|
port: int,
|
||||||
ctx: Context,
|
ctx: Context,
|
||||||
no_progress: bool,
|
no_progress: bool,
|
||||||
|
path_out: Union[Path, StdoutWriter],
|
||||||
) -> None:
|
) -> None:
|
||||||
global q
|
global q
|
||||||
global loop
|
global loop
|
||||||
|
@ -98,4 +104,5 @@ async def run_websocket_client(
|
||||||
host=host,
|
host=host,
|
||||||
port=port,
|
port=port,
|
||||||
ctx=ctx,
|
ctx=ctx,
|
||||||
|
path_out=path_out,
|
||||||
)
|
)
|
||||||
|
|
Loading…
Reference in a new issue