mirror of
https://github.com/shirayu/whispering.git
synced 2024-11-10 18:51:08 +00:00
Merge branch 'feature/ctx_msg'
This commit is contained in:
commit
8d32e607df
4 changed files with 74 additions and 20 deletions
|
@ -57,15 +57,13 @@ Run with ``--host`` and ``--port``.
|
||||||
whispering --language en --model tiny --host 0.0.0.0 --port 8000
|
whispering --language en --model tiny --host 0.0.0.0 --port 8000
|
||||||
```
|
```
|
||||||
|
|
||||||
You can set ``--allow-padding`` and other options.
|
|
||||||
|
|
||||||
### Client
|
### Client
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
whispering --host ADDRESS_OF_HOST --port 8000 --mode client
|
whispering --host ADDRESS_OF_HOST --port 8000 --mode client
|
||||||
```
|
```
|
||||||
|
|
||||||
You can set ``-n`` and other options.
|
You can set ``-n``, ``--allow-padding`` and other options.
|
||||||
|
|
||||||
## Tips
|
## Tips
|
||||||
|
|
||||||
|
|
|
@ -224,6 +224,19 @@ def show_devices():
|
||||||
print(f"{i}: {device['name']}")
|
print(f"{i}: {device['name']}")
|
||||||
|
|
||||||
|
|
||||||
|
def check_invalid_arg(opts):
    """Abort when an option that the selected mode does not accept is set.

    In server mode the ``mic`` and ``allow_padding`` options are rejected.
    For any other mode the list of rejected options is empty, so the
    function returns without doing anything.

    Args:
        opts: Parsed command-line options. ``opts.mode`` is compared with
            ``Mode.server.value`` and the remaining attributes are read
            through ``vars(opts)``.

    Raises:
        SystemExit: With status 1, after a message is written to stderr,
            when a rejected option holds a value other than ``None`` or
            ``False``.
    """
    ngs = []
    if opts.mode == Mode.server.value:
        ngs = [
            "mic",
            "allow_padding",
        ]

    values = vars(opts)
    for ng in ngs:
        value = values.get(ng)
        # Use identity checks rather than ``value not in {None, False}``:
        # set membership compares by equality, and ``0 == False``, so a
        # value of 0 (presumably a valid device index for ``mic`` -- TODO
        # confirm) would be treated as "not set". Membership would also
        # raise TypeError for an unhashable value.
        if value is not None and value is not False:
            sys.stderr.write(f"{ng} is not accepted option for {opts.mode} mode\n")
            sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
def main() -> None:
|
def main() -> None:
|
||||||
opts = get_opts()
|
opts = get_opts()
|
||||||
|
|
||||||
|
@ -242,13 +255,20 @@ def main() -> None:
|
||||||
):
|
):
|
||||||
opts.mode = Mode.server.value
|
opts.mode = Mode.server.value
|
||||||
|
|
||||||
|
check_invalid_arg(opts)
|
||||||
if opts.mode == Mode.client.value:
|
if opts.mode == Mode.client.value:
|
||||||
assert opts.language is None
|
assert opts.language is None
|
||||||
assert opts.model is None
|
assert opts.model is None
|
||||||
|
ctx: Context = get_context(opts=opts)
|
||||||
try:
|
try:
|
||||||
asyncio.run(
|
asyncio.run(
|
||||||
run_websocket_client(
|
run_websocket_client(
|
||||||
opts=opts,
|
sd_device=opts.mic,
|
||||||
|
num_block=opts.num_block,
|
||||||
|
host=opts.host,
|
||||||
|
port=opts.port,
|
||||||
|
no_progress=opts.no_progress,
|
||||||
|
ctx=ctx,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
|
@ -257,13 +277,11 @@ def main() -> None:
|
||||||
assert opts.language is not None
|
assert opts.language is not None
|
||||||
assert opts.model is not None
|
assert opts.model is not None
|
||||||
wsp = get_wshiper(opts=opts)
|
wsp = get_wshiper(opts=opts)
|
||||||
ctx: Context = get_context(opts=opts)
|
|
||||||
asyncio.run(
|
asyncio.run(
|
||||||
serve_with_websocket(
|
serve_with_websocket(
|
||||||
wsp=wsp,
|
wsp=wsp,
|
||||||
host=opts.host,
|
host=opts.host,
|
||||||
port=opts.port,
|
port=opts.port,
|
||||||
ctx=ctx,
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -1,7 +1,9 @@
|
||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import json
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import websockets
|
import websockets
|
||||||
|
@ -14,11 +16,8 @@ logger = getLogger(__name__)
|
||||||
|
|
||||||
async def serve_with_websocket_main(websocket):
|
async def serve_with_websocket_main(websocket):
|
||||||
global g_wsp
|
global g_wsp
|
||||||
global g_ctx
|
|
||||||
idx: int = 0
|
idx: int = 0
|
||||||
ctx: Context = g_ctx.copy(
|
ctx: Optional[Context] = None
|
||||||
deep=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
logger.debug(f"Audio #: {idx}")
|
logger.debug(f"Audio #: {idx}")
|
||||||
|
@ -29,10 +28,32 @@ async def serve_with_websocket_main(websocket):
|
||||||
|
|
||||||
if isinstance(message, str):
|
if isinstance(message, str):
|
||||||
logger.debug(f"Got str: {message}")
|
logger.debug(f"Got str: {message}")
|
||||||
|
d = json.loads(message)
|
||||||
|
v = d.get("context")
|
||||||
|
if v is not None:
|
||||||
|
ctx = Context.parse_obj(v)
|
||||||
|
else:
|
||||||
|
await websocket.send(
|
||||||
|
json.dumps(
|
||||||
|
{
|
||||||
|
"error": "unsupported message",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return
|
||||||
continue
|
continue
|
||||||
|
|
||||||
logger.debug(f"Message size: {len(message)}")
|
logger.debug(f"Message size: {len(message)}")
|
||||||
audio = np.frombuffer(message, dtype=np.float32)
|
audio = np.frombuffer(message, dtype=np.float32)
|
||||||
|
if ctx is None:
|
||||||
|
await websocket.send(
|
||||||
|
json.dumps(
|
||||||
|
{
|
||||||
|
"error": "no context",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return
|
||||||
for chunk in g_wsp.transcribe(
|
for chunk in g_wsp.transcribe(
|
||||||
audio=audio, # type: ignore
|
audio=audio, # type: ignore
|
||||||
ctx=ctx,
|
ctx=ctx,
|
||||||
|
@ -46,14 +67,11 @@ async def serve_with_websocket(
|
||||||
wsp: WhisperStreamingTranscriber,
|
wsp: WhisperStreamingTranscriber,
|
||||||
host: str,
|
host: str,
|
||||||
port: int,
|
port: int,
|
||||||
ctx: Context,
|
|
||||||
):
|
):
|
||||||
logger.info(f"Serve at {host}:{port}")
|
logger.info(f"Serve at {host}:{port}")
|
||||||
logger.info("Make secure with your responsibility!")
|
logger.info("Make secure with your responsibility!")
|
||||||
global g_wsp
|
global g_wsp
|
||||||
global g_ctx
|
|
||||||
g_wsp = wsp
|
g_wsp = wsp
|
||||||
g_ctx = ctx
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async with websockets.serve( # type: ignore
|
async with websockets.serve( # type: ignore
|
||||||
|
|
|
@ -1,5 +1,7 @@
|
||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
|
@ -8,6 +10,7 @@ import websockets
|
||||||
from whisper.audio import N_FRAMES, SAMPLE_RATE
|
from whisper.audio import N_FRAMES, SAMPLE_RATE
|
||||||
|
|
||||||
from whispering.schema import ParsedChunk
|
from whispering.schema import ParsedChunk
|
||||||
|
from whispering.transcriber import Context
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
@ -24,6 +27,7 @@ async def transcribe_from_mic_and_send(
|
||||||
num_block: int,
|
num_block: int,
|
||||||
host: str,
|
host: str,
|
||||||
port: int,
|
port: int,
|
||||||
|
ctx: Context,
|
||||||
) -> None:
|
) -> None:
|
||||||
uri = f"ws://{host}:{port}"
|
uri = f"ws://{host}:{port}"
|
||||||
|
|
||||||
|
@ -36,6 +40,10 @@ async def transcribe_from_mic_and_send(
|
||||||
callback=sd_callback,
|
callback=sd_callback,
|
||||||
):
|
):
|
||||||
async with websockets.connect(uri, max_size=999999999) as ws: # type:ignore
|
async with websockets.connect(uri, max_size=999999999) as ws: # type:ignore
|
||||||
|
logger.debug("Sent context")
|
||||||
|
v: str = ctx.json()
|
||||||
|
await ws.send("""{"context": """ + v + """}""")
|
||||||
|
|
||||||
idx: int = 0
|
idx: int = 0
|
||||||
while True:
|
while True:
|
||||||
|
|
||||||
|
@ -59,23 +67,35 @@ async def transcribe_from_mic_and_send(
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
c = await asyncio.wait_for(recv(), timeout=0.5)
|
c = await asyncio.wait_for(recv(), timeout=0.5)
|
||||||
chunk = ParsedChunk.parse_raw(c)
|
c_json = json.loads(c)
|
||||||
|
if (err := c_json.get("error")) is not None:
|
||||||
|
print(f"Error: {err}")
|
||||||
|
sys.exit(1)
|
||||||
|
chunk = ParsedChunk.parse_obj(c_json)
|
||||||
print(f"{chunk.start:.2f}->{chunk.end:.2f}\t{chunk.text}")
|
print(f"{chunk.start:.2f}->{chunk.end:.2f}\t{chunk.text}")
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
break
|
break
|
||||||
|
|
||||||
idx += 1
|
idx += 1
|
||||||
|
|
||||||
|
|
||||||
async def run_websocket_client(*, opts) -> None:
|
async def run_websocket_client(
|
||||||
|
*,
|
||||||
|
sd_device: Optional[Union[int, str]],
|
||||||
|
num_block: int,
|
||||||
|
host: str,
|
||||||
|
port: int,
|
||||||
|
ctx: Context,
|
||||||
|
no_progress: bool,
|
||||||
|
) -> None:
|
||||||
global q
|
global q
|
||||||
global loop
|
global loop
|
||||||
loop = asyncio.get_running_loop()
|
loop = asyncio.get_running_loop()
|
||||||
q = asyncio.Queue()
|
q = asyncio.Queue()
|
||||||
|
|
||||||
await transcribe_from_mic_and_send(
|
await transcribe_from_mic_and_send(
|
||||||
sd_device=opts.mic,
|
sd_device=sd_device,
|
||||||
num_block=opts.num_block,
|
num_block=num_block,
|
||||||
host=opts.host,
|
host=host,
|
||||||
port=opts.port,
|
port=port,
|
||||||
|
ctx=ctx,
|
||||||
)
|
)
|
||||||
|
|
Loading…
Reference in a new issue