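# Gradio app: talk to Mini-Omni-2 over a WebRTC audio stream, with an optional
# image input. Recorded audio is posted to a local model server and the reply
# is streamed back as audio frames plus a text transcript.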
import base64
import io
import logging
import os
from threading import Thread

import gradio as gr
import numpy as np
import requests
from gradio_webrtc import AdditionalOutputs, ReplyOnPause, WebRTC
from pydub import AudioSegment
from twilio.rest import Client

from server import serve
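
# Keep the console at WARNING, but capture full DEBUG output from the
# gradio_webrtc logger in a log file.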
logging.basicConfig(level=logging.WARNING)
file_handler = logging.FileHandler("gradio_webrtc.log")
file_handler.setLevel(logging.DEBUG)
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
file_handler.setFormatter(formatter)
logger = logging.getLogger("gradio_webrtc")
logger.setLevel(logging.DEBUG)
logger.addHandler(file_handler)
IP = "0.0.0.0"
PORT = 60808
thread = Thread(target=serve, daemon=True)
thread.start()
API_URL = "http://0.0.0.0:60808/chat"
account_sid = os.environ.get("TWILIO_ACCOUNT_SID")
auth_token = os.environ.get("TWILIO_AUTH_TOKEN")
if account_sid and auth_token:
    client = Client(account_sid, auth_token)
    token = client.tokens.create()
    rtc_configuration = {
        "iceServers": token.ice_servers,
        "iceTransportPolicy": "relay",
    }
else:
    rtc_configuration = None
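
# Audio format for streamed replies: mono, 16-bit PCM at 24 kHz.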
OUT_CHANNELS = 1
OUT_RATE = 24000
OUT_SAMPLE_WIDTH = 2
OUT_CHUNK = 20 * 4096
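
# Streaming handler invoked by ReplyOnPause whenever the caller pauses: it
# posts the recorded audio (and optional image) to the local chat server and
# streams the reply back as audio frames plus transcript updates.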
def response(audio: tuple[int, np.ndarray], conversation: list[dict], img: str | None):
    # Show the user's turn in the chat history right away.
    conversation.append({"role": "user", "content": gr.Audio(audio)})
    yield AdditionalOutputs(conversation)

    # Re-encode the raw samples as an in-memory WAV file.
    sampling_rate, audio_np = audio
    audio_np = audio_np.squeeze()
    audio_buffer = io.BytesIO()
    segment = AudioSegment(
        audio_np.tobytes(),
        frame_rate=sampling_rate,
        sample_width=audio_np.dtype.itemsize,
        channels=1,
    )
    segment.export(audio_buffer, format="wav")

    conversation.append({"role": "assistant", "content": ""})
    base64_encoded = base64.b64encode(audio_buffer.getvalue()).decode("utf-8")
    if API_URL is not None:
        output_audio_bytes = b""
        payload = {"audio": base64_encoded}
        if img is not None:
            # Read and close the image file instead of leaking the handle.
            with open(img, "rb") as image_file:
                payload["image"] = base64.b64encode(image_file.read()).decode("utf-8")
        logger.debug("sending request to server")
        resp_text = ""
        # The server answers with a multipart stream: "--frame"-delimited parts
        # containing either raw audio bytes (audio/wav) or transcript text.
        with requests.post(API_URL, json=payload, stream=True) as resp:
            try:
                buffer = b""
                for chunk in resp.iter_content(chunk_size=2048):
                    buffer += chunk
                    while b"\r\n--frame\r\n" in buffer:
                        frame, buffer = buffer.split(b"\r\n--frame\r\n", 1)
                        if b"Content-Type: audio/wav" in frame:
                            audio_data = frame.split(b"\r\n\r\n", 1)[1]
                            output_audio_bytes += audio_data
                            # OUT_SAMPLE_WIDTH is 2 bytes, so decode the PCM
                            # payload as int16 rather than int8.
                            audio_array = np.frombuffer(audio_data, dtype=np.int16).reshape(1, -1)
                            yield (OUT_RATE, audio_array, "mono")
                        elif b"Content-Type: text/plain" in frame:
                            text_data = frame.split(b"\r\n\r\n", 1)[1].decode()
                            resp_text += text_data
                            if len(text_data) > 0:
                                conversation[-1]["content"] = resp_text
                                yield AdditionalOutputs(conversation)
            except Exception as e:
                raise Exception(f"Error during audio streaming: {e}") from e
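
# UI: WebRTC audio stream and optional image input on the left, the running
# conversation on the right.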
with gr.Blocks() as demo:
    gr.HTML(
        """
        <h1 style='text-align: center'>
        Mini-Omni-2 Chat (Powered by WebRTC ⚡️)
        </h1>
        """
    )
    with gr.Row():
        with gr.Column():
            with gr.Group():
                audio = WebRTC(
                    label="Stream",
                    rtc_configuration=rtc_configuration,
                    mode="send-receive",
                    modality="audio",
                )
            img = gr.Image(label="Image", type="filepath")
        with gr.Column():
            conversation = gr.Chatbot(label="Conversation", type="messages")

    # Run `response` on each detected pause; ReplyOnPause handles voice
    # activity detection and audio buffering.
    audio.stream(
        fn=ReplyOnPause(
            response, output_sample_rate=OUT_RATE, output_frame_size=480
        ),
        inputs=[audio, conversation, img],
        outputs=[audio],
        time_limit=90,
    )
    # Route the AdditionalOutputs yielded by `response` into the chatbot.
    audio.on_additional_outputs(lambda c: c, outputs=[conversation])

demo.launch()