mirror of
https://github.com/shirayu/whispering.git
synced 2024-06-13 02:39:23 +00:00
process last file
This commit is contained in:
parent
df99efca98
commit
59b1230af7
|
@ -11,14 +11,12 @@ RUN pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 torchaudio==0.12.1
|
|||
|
||||
RUN pip install ipython
|
||||
|
||||
# Copy entire repo
|
||||
COPY . /app
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
RUN pip install .
|
||||
|
||||
COPY download_whisper_models.py .
|
||||
RUN python3 download_whisper_models.py --models tiny small base medium large
|
||||
|
||||
# open bash
|
||||
|
|
145
stream_file.py
Normal file
145
stream_file.py
Normal file
|
@ -0,0 +1,145 @@
|
|||
import sys
|
||||
import wave
|
||||
|
||||
import numpy
|
||||
import websockets
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
import soundfile as sf
|
||||
|
||||
logging.basicConfig(level=logging.INFO)

# Alternative endpoints, kept for reference:
# URL = "ws://ec2-3-71-72-35.eu-central-1.compute.amazonaws.com"

# URL = "ws://localhost:1111/ws"

# Websocket endpoint that send_receive() connects to.
URL = "ws://localhost:8000"

# Path of the audio file to stream, taken from the first command-line
# argument.  Raises IndexError at import time when no argument is given.
FNAME = sys.argv[1]

# Handshake payload: send_receive() serialises this dict to JSON and sends it
# as the first websocket message, before any audio data.
initial_context = {
    "context": {
        "protocol_version": 6002,
        "timestamp": 0.0,
        "buffer_tokens": [],
        "buffer_mel": None,
        # NOTE(review): "nosoeech" looks like a misspelling of "nospeech".  The
        # key presumably has to match the field name the server expects, so
        # confirm against the server-side schema before renaming it.
        "nosoeech_skip_count": None,
        "temperatures": [0.2],
        "patience": None,
        "compression_ratio_threshold": 2.4,
        "logprob_threshold": -1.0,
        "no_captions_threshold": 0.6,
        "best_of": 5,
        "beam_size": 5,
        "no_speech_threshold": 0.6,
        "buffer_threshold": 0.5,
        "vad_threshold": 0.5,
        "max_nospeech_skip": 16,
        # Sample format of the audio bytes that follow; wave_generator() reads
        # the file with dtype="float32" as well.
        "data_type": "float32",
    }
}
|
||||
|
||||
# # read binary file in chunks
|
||||
# def wave_generator(chunk_size=50000):
|
||||
# with open(FNAME, "rb") as f:
|
||||
# while True:
|
||||
# data = f.read(chunk_size)
|
||||
# if not data:
|
||||
# break
|
||||
# yield data
|
||||
|
||||
|
||||
# def wave_generator():
|
||||
# with wave.open(FNAME, "rb") as f:
|
||||
# while True:
|
||||
# data = f.readframes(100000)
|
||||
# if not data:
|
||||
# break
|
||||
# yield data
|
||||
|
||||
# Default number of frames read from the audio file per yielded chunk.
CHUNK_SIZE = 100000


def wave_generator(chunk_size: int = CHUNK_SIZE):
    """Yield the audio file ``FNAME`` as consecutive chunks of raw bytes.

    The file is read through ``soundfile.blocks`` as float32 samples,
    ``chunk_size`` frames at a time, and every block is serialised with
    ``tobytes()`` before being yielded.

    Args:
        chunk_size: Value passed to ``soundfile.blocks`` as ``blocksize``.

    Yields:
        bytes: The raw float32 buffer of one block.
    """
    blocks = sf.blocks(FNAME, blocksize=chunk_size, dtype="float32")
    yield from (block.tobytes() for block in blocks)
|
||||
|
||||
|
||||
# def wave_generator(chunk_size=60000):
|
||||
# from scipy.io.wavfile import read
|
||||
#
|
||||
# a = read(FNAME)
|
||||
# data = numpy.array(a[1], dtype=numpy.float32)
|
||||
# print(f"Data: {data.shape}")
|
||||
# for i in range(0, data.shape[0], chunk_size):
|
||||
# print(f"From {i} to {i+chunk_size}")
|
||||
# yield data[i : i + chunk_size]
|
||||
|
||||
|
||||
async def send_receive():
    """Stream the audio file to the websocket server and print transcripts.

    Opens a websocket connection to ``URL``, sends ``initial_context`` as the
    first (JSON text) message and then runs two coroutines concurrently:

    * ``send`` pushes the chunks produced by ``wave_generator()``, pausing one
      second after each.  Every chunk except the final one is sent as a raw
      binary frame; the final chunk is base64-encoded and wrapped in a
      ``{"last_message": ...}`` JSON text frame.
    * ``receive`` prints the ``text`` field of every JSON message received.

    ``receive`` only stops when the connection is closed or a message cannot
    be handled, so this coroutine keeps running after the last chunk has been
    sent for as long as the server keeps the socket open.
    """
    print(f"Connecting websocket to url {URL}")
    wave_gen = wave_generator()
    async with websockets.connect(
        URL,
        ping_interval=5,
        ping_timeout=20,
    ) as _ws:
        await asyncio.sleep(0.1)
        # The context must be the first message on the connection.
        await _ws.send(json.dumps(initial_context))
        print("Sending messages ...")

        async def send():
            """Send every audio chunk, flagging the final one as the last."""
            print("Gonna send")
            try:
                # Read one chunk ahead: a chunk is the last one exactly when
                # nothing follows it.  This is independent of the chunk's byte
                # length, so it also works when the file length is an exact
                # multiple of the chunk size or the file has several channels.
                segment = next(wave_gen, None)
                while segment is not None:
                    upcoming = next(wave_gen, None)
                    print(f"Sending segment of length {len(segment)}")
                    if upcoming is not None:
                        msg = segment
                    else:
                        # Text frames cannot carry raw bytes: b64-encode them.
                        payload = base64.b64encode(segment).decode("utf-8")
                        msg = json.dumps({"last_message": payload})

                    await _ws.send(msg)
                    segment = upcoming
                    await asyncio.sleep(1)
            except websockets.exceptions.ConnectionClosedError as e:
                print(e)
                # 4008 is the only close code this client expects here.
                assert e.code == 4008
            except Exception as e:
                print(f"Caught send exception {type(e)}")
                print(e)

            return True

        async def receive():
            """Print transcription results until the connection goes away."""
            while True:
                try:
                    result_str = await _ws.recv()
                    rjs = json.loads(result_str)
                    print(f"Result: '{rjs['text']}'")
                except websockets.exceptions.ConnectionClosedError as e:
                    print(e)
                    assert e.code == 4008
                    break
                except Exception as e:
                    print(f"Caught exception {type(e)}")
                    print(e)
                    break

        await asyncio.gather(send(), receive())
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # asyncio.run() creates a fresh event loop, runs the coroutine to
    # completion and closes the loop afterwards.  Calling
    # asyncio.get_event_loop() without a running loop is deprecated.
    asyncio.run(send_receive())
|
|
@ -1,31 +1,35 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import queue
|
||||
import sys
|
||||
from enum import Enum
|
||||
from logging import DEBUG, INFO, basicConfig, getLogger
|
||||
from logging import basicConfig
|
||||
from logging import DEBUG
|
||||
from logging import getLogger
|
||||
from logging import INFO
|
||||
from pathlib import Path
|
||||
from typing import Iterator, List, Optional, Union
|
||||
from typing import Iterator
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Union
|
||||
from whispering.pbar import ProgressBar
|
||||
from whispering.schema import Context
|
||||
from whispering.schema import CURRENT_PROTOCOL_VERSION
|
||||
from whispering.schema import StdoutWriter
|
||||
from whispering.schema import WhisperConfig
|
||||
from whispering.serve import serve_with_websocket
|
||||
from whispering.transcriber import WhisperStreamingTranscriber
|
||||
from whispering.websocket_client import run_websocket_client
|
||||
|
||||
import sounddevice as sd
|
||||
import torch
|
||||
import whisper
|
||||
from whisper import available_models
|
||||
from whisper.audio import N_FRAMES, SAMPLE_RATE
|
||||
from whisper.tokenizer import LANGUAGES, TO_LANGUAGE_CODE
|
||||
|
||||
from whispering.pbar import ProgressBar
|
||||
from whispering.schema import (
|
||||
CURRENT_PROTOCOL_VERSION,
|
||||
Context,
|
||||
StdoutWriter,
|
||||
WhisperConfig,
|
||||
)
|
||||
from whispering.serve import serve_with_websocket
|
||||
from whispering.transcriber import WhisperStreamingTranscriber
|
||||
from whispering.websocket_client import run_websocket_client
|
||||
from whisper.audio import N_FRAMES
|
||||
from whisper.audio import SAMPLE_RATE
|
||||
from whisper.tokenizer import LANGUAGES
|
||||
from whisper.tokenizer import TO_LANGUAGE_CODE
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
@ -275,17 +279,6 @@ def is_valid_arg(
|
|||
return True
|
||||
|
||||
|
||||
import whisper
|
||||
from pathlib import Path
|
||||
|
||||
[
|
||||
whisper._download(
|
||||
whisper._MODELS[m], str(Path("~/.cache/whisper").expanduser()), False
|
||||
)
|
||||
for m in ["tiny", "base", "small", "medium", "large"]
|
||||
]
|
||||
|
||||
|
||||
def main() -> None:
|
||||
opts = get_opts()
|
||||
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
from logging import getLogger
|
||||
from typing import Final, Optional
|
||||
|
@ -30,7 +31,9 @@ async def serve_with_websocket_main(websocket):
|
|||
except ConnectionClosedOK:
|
||||
break
|
||||
|
||||
if isinstance(message, str):
|
||||
force_padding = False
|
||||
|
||||
if isinstance(message, str) and not ctx:
|
||||
logger.debug(f"Got str: {message}")
|
||||
d = json.loads(message)
|
||||
v = d.get("context")
|
||||
|
@ -66,6 +69,15 @@ async def serve_with_websocket_main(websocket):
|
|||
|
||||
continue
|
||||
|
||||
elif isinstance(message, str) and ctx:
|
||||
d = json.loads(message)
|
||||
logger.warning(f"Received last message")
|
||||
if "last_message" in d:
|
||||
# b64-decode message
|
||||
message = base64.b64decode(d["last_message"])
|
||||
|
||||
force_padding = True
|
||||
|
||||
logger.debug(f"Message size: {len(message)}")
|
||||
if ctx is None:
|
||||
await websocket.send(
|
||||
|
@ -76,10 +88,15 @@ async def serve_with_websocket_main(websocket):
|
|||
)
|
||||
)
|
||||
return
|
||||
logger.debug(f"Have message of size {len(message)} data type {ctx.data_type}")
|
||||
# bytes to np array
|
||||
# audio = np.frombuffer(message, dtype=ctx.data_type)
|
||||
|
||||
audio = np.frombuffer(message, dtype=np.dtype(ctx.data_type)).astype(np.float32)
|
||||
logger.warning(f"Have audio shape {audio.shape}")
|
||||
|
||||
for chunk in g_wsp.transcribe(
|
||||
audio=audio, # type: ignore
|
||||
ctx=ctx,
|
||||
audio=audio, ctx=ctx, force_padding=force_padding # type: ignore
|
||||
):
|
||||
await websocket.send(chunk.json())
|
||||
idx += 1
|
||||
|
|
|
@ -227,13 +227,9 @@ class WhisperStreamingTranscriber:
|
|||
logger.debug(f"Length of buffer: {len(ctx.buffer_tokens)}")
|
||||
|
||||
def transcribe(
|
||||
self,
|
||||
*,
|
||||
audio: np.ndarray,
|
||||
ctx: Context,
|
||||
self, *, audio: np.ndarray, ctx: Context, force_padding: bool = False
|
||||
) -> Iterator[ParsedChunk]:
|
||||
logger.debug(f"{len(audio)}")
|
||||
force_padding: bool = False
|
||||
|
||||
if ctx.vad_threshold > 0.0:
|
||||
x = [
|
||||
|
|
|
@ -19,6 +19,9 @@ logger = getLogger(__name__)
|
|||
def sd_callback(indata, frames, time, status):
    """Forward one block of captured audio to the queue ``q``.

    The signature matches a sounddevice stream callback (presumably this is
    registered as one; confirm at the call site).  The block is flattened,
    serialised to bytes and handed to ``q.put_nowait`` through
    ``loop.call_soon_threadsafe``, so the enqueue runs on the event loop.

    Args:
        indata: Array-like audio block; must support ``ravel().tobytes()``.
        frames: Unused.
        time: Unused.
        status: Logged as a warning when truthy.
    """
    if status:
        logger.warning(status)
    logger.debug("Have indata: %s %s", len(indata), type(indata))
    # Serialise once and reuse the result, so the buffer is not flattened and
    # copied a second time on every callback.
    ravelled = indata.ravel().tobytes()
    logger.debug("Have ravelled: %s %s", len(ravelled), type(ravelled))
    loop.call_soon_threadsafe(q.put_nowait, ravelled)
|
||||
|
||||
|
||||
|
@ -44,7 +47,9 @@ async def transcribe_from_mic_and_send(
|
|||
async with websockets.connect(uri, max_size=999999999) as ws: # type:ignore
|
||||
logger.debug("Sent context")
|
||||
v: str = ctx.json()
|
||||
await ws.send("""{"context": """ + v + """}""")
|
||||
msg = """{"context": """ + v + """}"""
|
||||
logger.debug(f"Sending context\n{msg}")
|
||||
await ws.send(msg)
|
||||
|
||||
idx: int = 0
|
||||
while True:
|
||||
|
@ -59,7 +64,7 @@ async def transcribe_from_mic_and_send(
|
|||
except asyncio.TimeoutError:
|
||||
pass
|
||||
if segment is not None:
|
||||
logger.debug(f"Segment size: {len(segment)}")
|
||||
logger.debug(f"Segment size: {len(segment)} {type(segment)}")
|
||||
await ws.send(segment)
|
||||
logger.debug("Sent")
|
||||
|
||||
|
|
Loading…
Reference in a new issue