"""A simple web interactive chat demo based on gradio.""" import os import time import gradio as gr import base64 import numpy as np import requests API_URL = os.getenv("API_URL", None) client = None if API_URL is None: from inference import OmniInference omni_client = OmniInference('./checkpoint', 'cuda:0') omni_client.warm_up() OUT_CHUNK = 4096 OUT_RATE = 24000 OUT_CHANNELS = 1 def process_audio(audio): filepath = audio print(f"filepath: {filepath}") if filepath is None: return cnt = 0 if API_URL is not None: with open(filepath, "rb") as f: data = f.read() base64_encoded = str(base64.b64encode(data), encoding="utf-8") files = {"audio": base64_encoded} tik = time.time() with requests.post(API_URL, json=files, stream=True) as response: try: for chunk in response.iter_content(chunk_size=OUT_CHUNK): if chunk: # Convert chunk to numpy array if cnt == 0: print(f"first chunk time cost: {time.time() - tik:.3f}") cnt += 1 audio_data = np.frombuffer(chunk, dtype=np.int16) audio_data = audio_data.reshape(-1, OUT_CHANNELS) yield OUT_RATE, audio_data.astype(np.int16) except Exception as e: print(f"error: {e}") else: tik = time.time() for chunk in omni_client.run_AT_batch_stream(filepath): # Convert chunk to numpy array if cnt == 0: print(f"first chunk time cost: {time.time() - tik:.3f}") cnt += 1 audio_data = np.frombuffer(chunk, dtype=np.int16) audio_data = audio_data.reshape(-1, OUT_CHANNELS) yield OUT_RATE, audio_data.astype(np.int16) def main(port=None): demo = gr.Interface( process_audio, inputs=gr.Audio(type="filepath", label="Microphone"), outputs=[gr.Audio(label="Response", streaming=True, autoplay=True)], title="Chat Mini-Omni Demo", live=True, ) if port is not None: demo.queue().launch(share=False, server_name="0.0.0.0", server_port=port) else: demo.queue().launch() if __name__ == "__main__": import fire fire.Fire(main)