# mini-omni / app.py
# Source: StevenChen16's Hugging Face Space (commit 136d7a0, verified), 2.46 kB.
"""A simple web interactive chat demo based on gradio."""
import os
import time
import gradio as gr
import base64
import numpy as np
import requests
# Remote inference endpoint. When unset, fall back to running the model locally.
API_URL = os.getenv("API_URL", None)
# NOTE(review): `client` is assigned here but never read or reassigned below —
# looks vestigial; `omni_client` is the name actually used. Confirm before removing.
client = None
if API_URL is None:
    # Local mode: load the Mini-Omni checkpoint onto the first CUDA device.
    from inference import OmniInference
    omni_client = OmniInference('./checkpoint', 'cuda:0')
    # Presumably pre-runs the model so the first real request isn't slow — TODO confirm.
    omni_client.warm_up()
# Streaming-output parameters: bytes per HTTP chunk, output sample rate (Hz), mono.
OUT_CHUNK = 4096
OUT_RATE = 24000
OUT_CHANNELS = 1
def process_audio(audio):
    """Stream synthesized speech for a recorded audio file.

    Generator used as the Gradio handler: for each raw PCM chunk produced by
    either the remote API (``API_URL`` set) or the local ``omni_client``, yields
    a ``(OUT_RATE, int16 ndarray)`` tuple that Gradio plays back incrementally.
    Yields nothing when no recording was provided.
    """
    path = audio
    print(f"filepath: {path}")
    if path is None:
        return

    def _frames(raw):
        # Reinterpret the raw bytes as 16-bit PCM, one column per channel.
        samples = np.frombuffer(raw, dtype=np.int16)
        return OUT_RATE, samples.reshape(-1, OUT_CHANNELS).astype(np.int16)

    first = True
    if API_URL is not None:
        # Remote mode: ship the recording as base64 JSON and stream the reply.
        with open(path, "rb") as fh:
            encoded = str(base64.b64encode(fh.read()), encoding="utf-8")
        payload = {"audio": encoded}
        t0 = time.time()
        with requests.post(API_URL, json=payload, stream=True) as response:
            try:
                for raw in response.iter_content(chunk_size=OUT_CHUNK):
                    if not raw:
                        continue
                    if first:
                        # Time-to-first-byte for the remote call.
                        print(f"first chunk time cost: {time.time() - t0:.3f}")
                        first = False
                    yield _frames(raw)
            except Exception as e:
                # Best-effort streaming: report and stop rather than crash the UI.
                print(f"error: {e}")
    else:
        # Local mode: stream chunks straight from the in-process model.
        t0 = time.time()
        for raw in omni_client.run_AT_batch_stream(path):
            if first:
                print(f"first chunk time cost: {time.time() - t0:.3f}")
                first = False
            yield _frames(raw)
def main(port=None):
    """Build the Gradio chat demo and launch it.

    :param port: optional server port; when given, the app binds to
        ``0.0.0.0:port`` with sharing disabled, otherwise Gradio picks
        its defaults.
    """
    demo = gr.Interface(
        process_audio,
        inputs=gr.Audio(type="filepath", label="Microphone"),
        outputs=[gr.Audio(label="Response", streaming=True, autoplay=True)],
        title="Chat Mini-Omni Demo",
        live=True,
    )
    app = demo.queue()
    if port is None:
        app.launch()
    else:
        app.launch(share=False, server_name="0.0.0.0", server_port=port)
if __name__ == "__main__":
    # fire exposes main()'s signature as a CLI, e.g. `python app.py --port 7860`.
    import fire
    fire.Fire(main)