process last file

This commit is contained in:
ricwo 2022-11-09 17:39:30 +01:00
parent df99efca98
commit 59b1230af7
6 changed files with 193 additions and 39 deletions

View file

@ -11,14 +11,12 @@ RUN pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 torchaudio==0.12.1
RUN pip install ipython
# Copy entire repo
COPY . /app
WORKDIR /app
RUN pip install .
COPY download_whisper_models.py .
RUN python3 download_whisper_models.py --models tiny small base medium large
# open bash

145
stream_file.py Normal file
View file

@ -0,0 +1,145 @@
import sys
import wave
import numpy
import websockets
import asyncio
import base64
import json
import logging
import soundfile as sf
# Configure the root logger so that INFO-level records (and above) are emitted.
logging.basicConfig(level=logging.INFO)
# Alternative endpoints, kept commented out so the target can be switched by hand:
# URL = "ws://ec2-3-71-72-35.eu-central-1.compute.amazonaws.com"
# URL = "ws://localhost:1111/ws"
# Websocket endpoint this script connects to.
URL = "ws://localhost:8000"
# Path of the audio file to stream, taken from the first command-line argument.
# Evaluated at import time: raises IndexError when no argument is supplied.
FNAME = sys.argv[1]
# Session settings sent to the server in the opening message.
# NOTE(review): "nosoeech_skip_count" looks like a misspelling of "nospeech",
# but it is a wire-format key -- confirm against the server-side schema before
# renaming it.
_SESSION_SETTINGS = {
    # Protocol handshake and initial stream state.
    "protocol_version": 6002,
    "timestamp": 0.0,
    "buffer_tokens": [],
    "buffer_mel": None,
    "nosoeech_skip_count": None,
    # Decoding options.
    "temperatures": [0.2],
    "patience": None,
    "compression_ratio_threshold": 2.4,
    "logprob_threshold": -1.0,
    "no_captions_threshold": 0.6,
    "best_of": 5,
    "beam_size": 5,
    "no_speech_threshold": 0.6,
    # Buffering and voice-activity thresholds.
    "buffer_threshold": 0.5,
    "vad_threshold": 0.5,
    "max_nospeech_skip": 16,
    # Matches the dtype the audio blocks are read with in this script.
    "data_type": "float32",
}

# Opening message payload: the settings travel under the top-level "context" key.
initial_context = {"context": _SESSION_SETTINGS}
# # read binary file in chunks
# def wave_generator(chunk_size=50000):
# with open(FNAME, "rb") as f:
# while True:
# data = f.read(chunk_size)
# if not data:
# break
# yield data
# def wave_generator():
# with wave.open(FNAME, "rb") as f:
# while True:
# data = f.readframes(100000)
# if not data:
# break
# yield data
# Default ``blocksize`` handed to ``sf.blocks`` for every read.
CHUNK_SIZE = 100000


def wave_generator(chunk_size: int = CHUNK_SIZE):
    """Yield the audio file at ``FNAME`` as raw byte strings, block by block.

    The file is read lazily through ``sf.blocks`` with ``dtype="float32"``;
    every block is serialized with ``tobytes()`` before being yielded.

    Args:
        chunk_size: Value forwarded to ``sf.blocks`` as ``blocksize``.

    Yields:
        The ``tobytes()`` serialization of each block, in file order.
    """
    block_reader = sf.blocks(FNAME, blocksize=chunk_size, dtype="float32")
    yield from (block.tobytes() for block in block_reader)
# def wave_generator(chunk_size=60000):
# from scipy.io.wavfile import read
#
# a = read(FNAME)
# data = numpy.array(a[1], dtype=numpy.float32)
# print(f"Data: {data.shape}")
# for i in range(0, data.shape[0], chunk_size):
# print(f"From {i} to {i+chunk_size}")
# yield data[i : i + chunk_size]
async def send_receive():
    """Stream the audio file to the websocket server and print the results.

    Opens a websocket connection to ``URL``, sends ``initial_context`` as a
    JSON text message, then runs two coroutines concurrently:

    * ``send`` pushes the byte chunks produced by ``wave_generator``. Every
      chunk except the final one is sent as-is; the final one is
      base64-encoded and wrapped in a JSON object under the key
      ``"last_message"`` so the receiver can tell the stream is complete.
    * ``receive`` parses every incoming message as JSON and prints its
      ``"text"`` field.

    Raises:
        websockets.exceptions.ConnectionClosedError: If the connection is
            closed with a code other than 4008, the only close code this
            client treats as expected.
    """
    print(f"Connecting websocket to url {URL}")
    wave_gen = wave_generator()
    async with websockets.connect(
        URL,
        ping_interval=5,
        ping_timeout=20,
    ) as _ws:
        await asyncio.sleep(0.1)
        await _ws.send(json.dumps(initial_context))
        print("Sending messages ...")

        async def send():
            """Send every audio chunk, flagging the final one as last."""
            print("Gonna send")
            try:
                segment = next(wave_gen, None)
                if segment is None:
                    print("No audio segments to send")
                while segment is not None:
                    # Look one chunk ahead: the final chunk is identified by
                    # the generator being exhausted rather than by its size,
                    # so a file whose length is an exact multiple of the
                    # chunk size still gets its "last_message" marker.
                    upcoming = next(wave_gen, None)
                    print(f"Sending segment of length {len(segment)}")
                    if upcoming is not None:
                        msg = segment
                    else:
                        # b64-encode bytes so they fit into a JSON string
                        payload = base64.b64encode(segment).decode("utf-8")
                        msg = json.dumps({"last_message": payload})
                    await _ws.send(msg)
                    segment = upcoming
                    if segment is not None:
                        # Pace the upload between consecutive chunks.
                        await asyncio.sleep(1)
            except websockets.exceptions.ConnectionClosedError as e:
                print(e)
                # Explicit check instead of ``assert``, which is stripped
                # when Python runs with -O.
                if e.code != 4008:
                    raise
            except Exception as e:
                # Best effort: report the problem and stop sending.
                print(f"Caught send exception {type(e)}")
                print(e)
            return True

        async def receive():
            """Print the "text" field of each message until receiving fails."""
            # NOTE(review): this loop only ends when ``recv`` raises (e.g. the
            # connection is closed); confirm the server closes the connection
            # once the last chunk has been processed.
            while True:
                try:
                    result_str = await _ws.recv()
                    rjs = json.loads(result_str)
                    print(f"Result: '{rjs['text']}'")
                except websockets.exceptions.ConnectionClosedError as e:
                    print(e)
                    if e.code != 4008:
                        raise
                    break
                except Exception as e:
                    print(f"Caught exception {type(e)}")
                    print(e)
                    break

        await asyncio.gather(send(), receive())
if __name__ == "__main__":
    # asyncio.run creates a fresh event loop, runs the coroutine to completion
    # and closes the loop afterwards. Calling asyncio.get_event_loop() when no
    # loop is running is deprecated and left the loop unclosed.
    asyncio.run(send_receive())

View file

@ -1,31 +1,35 @@
#!/usr/bin/env python3
import argparse
import asyncio
import queue
import sys
from enum import Enum
from logging import DEBUG, INFO, basicConfig, getLogger
from logging import basicConfig
from logging import DEBUG
from logging import getLogger
from logging import INFO
from pathlib import Path
from typing import Iterator, List, Optional, Union
from typing import Iterator
from typing import List
from typing import Optional
from typing import Union
from whispering.pbar import ProgressBar
from whispering.schema import Context
from whispering.schema import CURRENT_PROTOCOL_VERSION
from whispering.schema import StdoutWriter
from whispering.schema import WhisperConfig
from whispering.serve import serve_with_websocket
from whispering.transcriber import WhisperStreamingTranscriber
from whispering.websocket_client import run_websocket_client
import sounddevice as sd
import torch
import whisper
from whisper import available_models
from whisper.audio import N_FRAMES, SAMPLE_RATE
from whisper.tokenizer import LANGUAGES, TO_LANGUAGE_CODE
from whispering.pbar import ProgressBar
from whispering.schema import (
CURRENT_PROTOCOL_VERSION,
Context,
StdoutWriter,
WhisperConfig,
)
from whispering.serve import serve_with_websocket
from whispering.transcriber import WhisperStreamingTranscriber
from whispering.websocket_client import run_websocket_client
from whisper.audio import N_FRAMES
from whisper.audio import SAMPLE_RATE
from whisper.tokenizer import LANGUAGES
from whisper.tokenizer import TO_LANGUAGE_CODE
logger = getLogger(__name__)
@ -275,17 +279,6 @@ def is_valid_arg(
return True
import whisper
from pathlib import Path
[
whisper._download(
whisper._MODELS[m], str(Path("~/.cache/whisper").expanduser()), False
)
for m in ["tiny", "base", "small", "medium", "large"]
]
def main() -> None:
opts = get_opts()

View file

@ -1,6 +1,7 @@
#!/usr/bin/env python3
import asyncio
import base64
import json
from logging import getLogger
from typing import Final, Optional
@ -30,7 +31,9 @@ async def serve_with_websocket_main(websocket):
except ConnectionClosedOK:
break
if isinstance(message, str):
force_padding = False
if isinstance(message, str) and not ctx:
logger.debug(f"Got str: {message}")
d = json.loads(message)
v = d.get("context")
@ -66,6 +69,15 @@ async def serve_with_websocket_main(websocket):
continue
elif isinstance(message, str) and ctx:
d = json.loads(message)
logger.warning(f"Received last message")
if "last_message" in d:
# b64-decode message
message = base64.b64decode(d["last_message"])
force_padding = True
logger.debug(f"Message size: {len(message)}")
if ctx is None:
await websocket.send(
@ -76,10 +88,15 @@ async def serve_with_websocket_main(websocket):
)
)
return
logger.debug(f"Have message of size {len(message)} data type {ctx.data_type}")
# bytes to np array
# audio = np.frombuffer(message, dtype=ctx.data_type)
audio = np.frombuffer(message, dtype=np.dtype(ctx.data_type)).astype(np.float32)
logger.warning(f"Have audio shape {audio.shape}")
for chunk in g_wsp.transcribe(
audio=audio, # type: ignore
ctx=ctx,
audio=audio, ctx=ctx, force_padding=force_padding # type: ignore
):
await websocket.send(chunk.json())
idx += 1

View file

@ -227,13 +227,9 @@ class WhisperStreamingTranscriber:
logger.debug(f"Length of buffer: {len(ctx.buffer_tokens)}")
def transcribe(
self,
*,
audio: np.ndarray,
ctx: Context,
self, *, audio: np.ndarray, ctx: Context, force_padding: bool = False
) -> Iterator[ParsedChunk]:
logger.debug(f"{len(audio)}")
force_padding: bool = False
if ctx.vad_threshold > 0.0:
x = [

View file

@ -19,6 +19,9 @@ logger = getLogger(__name__)
def sd_callback(indata, frames, time, status):
    """Forward one captured audio block to the queue ``q`` as raw bytes.

    Args:
        indata: Captured block; flattened with ``ravel()`` and serialized
            with ``tobytes()``.
        frames: Unused.
        time: Unused.
        status: Logged as a warning when truthy.
    """
    if status:
        logger.warning(status)
    # Lazy %-style arguments: the messages are only formatted when DEBUG
    # logging is enabled.
    logger.debug("Have indata: %s %s", len(indata), type(indata))
    # Flatten and serialize once; the same bytes are logged and enqueued.
    ravelled = indata.ravel().tobytes()
    logger.debug("Have ravelled: %s %s", len(ravelled), type(ravelled))
    # Scheduled via call_soon_threadsafe -- presumably because this callback
    # runs outside the event loop's thread; confirm against the caller.
    loop.call_soon_threadsafe(q.put_nowait, ravelled)
@ -44,7 +47,9 @@ async def transcribe_from_mic_and_send(
async with websockets.connect(uri, max_size=999999999) as ws: # type:ignore
logger.debug("Sent context")
v: str = ctx.json()
await ws.send("""{"context": """ + v + """}""")
msg = """{"context": """ + v + """}"""
logger.debug(f"Sending context\n{msg}")
await ws.send(msg)
idx: int = 0
while True:
@ -59,7 +64,7 @@ async def transcribe_from_mic_and_send(
except asyncio.TimeoutError:
pass
if segment is not None:
logger.debug(f"Segment size: {len(segment)}")
logger.debug(f"Segment size: {len(segment)} {type(segment)}")
await ws.send(segment)
logger.debug("Sent")