feat: Update Dockerfile and requirements.txt to resolve PyAudio build issues
The main changes are:
1. Added `build-essential` and `libasound2-dev` to the system dependencies in the Dockerfile to ensure the necessary build tools are available.
2. Removed PyAudio from the `requirements.txt` file to avoid the pip installation issues.
3. Added a separate `RUN pip install PyAudio==0.2.14` command in the Dockerfile to install PyAudio manually.
These changes should resolve the build issues with PyAudio on the CUDA server.
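For reference, a rough sketch of the equivalent manual steps inside the container (assuming a Debian-based base image; the exact final Dockerfile may differ):

```sh
# Build tools and ALSA headers needed to compile PyAudio's C extension
apt-get update && apt-get install -y build-essential libasound2-dev

# requirements.txt no longer lists PyAudio, so the rest installs cleanly...
pip install --no-cache-dir -r requirements.txt

# ...and PyAudio is installed separately, pinned to the version the README uses
pip install PyAudio==0.2.14
```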
Revert "fix: Handle CUDA availability in OmniChatServer"
This reverts commit 28ed763269f75cea8298b3d64449fd7776d05f52.
docs: add PyAudio to dependencies
feat: Replace PyAudio with streamlit-webrtc for user recording
fix: Replace PyAudio with streamlit-webrtc for audio recording
feat: Serve HTML demo instead of Streamlit app
fix: Update API_URL and error handling in webui/omni_html_demo.html
fix: Replace audio playback with text-to-speech
feat: Implement audio processing and response generation
fix: Use a Docker data volume for caching
feat: Add Docker data volume and environment variables for caching
diff --git a/inference.py b/inference.py
index 4d4d4d1..d4d4d1a 100644
--- a/inference.py
+++ b/inference.py
@@ -1,6 +1,7 @@
 def download_model(ckpt_dir):
     repo_id = "gpt-omni/mini-omni"
-    snapshot_download(repo_id, local_dir=ckpt_dir, revision="main")
+    cache_dir = os.environ.get('XDG_CACHE_HOME', '/tmp')
+    snapshot_download(repo_id, local_dir=ckpt_dir, revision="main", cache_dir=cache_dir)
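A quick way to sanity-check the new cache behaviour locally (a sketch; `/data/cache` is just the value the Dockerfile exports later, and any writable directory works):

```sh
# With XDG_CACHE_HOME set, snapshot_download caches under it; unset, it falls back to /tmp
export PYTHONPATH=./
export XDG_CACHE_HOME=/data/cache
python -c "from inference import download_model; download_model('./checkpoint')"
```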
rm
fix: Remove cache-related code and update Dockerfile
fix: Add Docker volume and set permissions for model download
fix: Set correct permissions for checkpoint directory
feat: Use DATA volume to store model checkpoint
fix: Set permissions and create necessary subdirectories in the DATA volume
fix: Implement error handling and CUDA Tensor Cores optimization in serve_html.py
fix: Improve error handling and logging in chat endpoint
- Dockerfile +23 -11
- README.md +125 -124
- inference.py +7 -5
- requirements.txt +8 -2
- serve_html.py +70 -0
- server.py +4 -7
- webui/index.html +0 -258
- webui/omni_html_demo.html +13 -8
- webui/omni_streamlit.py +134 -257
Dockerfile
@@ -7,7 +7,6 @@ WORKDIR /app
 # Install system dependencies
 RUN apt-get update && apt-get install -y \
     ffmpeg \
-    portaudio19-dev \
     && rm -rf /var/lib/apt/lists/*
 
 # Copy the current directory contents into the container at /app
@@ -16,20 +15,33 @@ COPY . /app
 # Install any needed packages specified in requirements.txt
 RUN pip install --no-cache-dir -r requirements.txt
 
-#
-
-
-# Make ports 7860 and 60808 available to the world outside this container
-EXPOSE 7860 60808
+# Make port 7860 available to the world outside this container
+EXPOSE 7860
 
 # Set environment variable for API_URL
-ENV API_URL=http://0.0.0.0:
+ENV API_URL=http://0.0.0.0:7860/chat
 
 # Set PYTHONPATH
 ENV PYTHONPATH=./
 
-#
-
+# Set environment variables
+ENV MPLCONFIGDIR=/tmp/matplotlib
+ENV HF_HOME=/data/huggingface
+ENV XDG_CACHE_HOME=/data/cache
+
+# Create a volume for data
+VOLUME /data
+
+# Set permissions for the /data directory and create necessary subdirectories
+RUN mkdir -p /data/checkpoint /data/cache /data/huggingface && \
+    chown -R 1000:1000 /data && \
+    chmod -R 777 /data
+
+# Install Flask
+RUN pip install flask
+
+# Copy the HTML demo file
+COPY webui/omni_html_demo.html .
 
-# Run
-CMD ["
+# Run the Flask app to serve the HTML demo
+CMD ["python", "serve_html.py"]
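A hedged example of building and running the resulting image; the image tag and volume name are placeholders, and `--gpus all` assumes the NVIDIA container runtime is installed:

```sh
docker build -t mini-omni-space .

# Mount a named volume at /data so checkpoints and caches persist across runs
docker run --gpus all -p 7860:7860 -v omni-data:/data mini-omni-space
```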
README.md (new version):

---
title: Omni Docker
emoji: 🦀
colorFrom: green
colorTo: red
sdk: docker
pinned: false
---

Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference

# Mini-Omni

<p align="center"><strong style="font-size: 18px;">
Mini-Omni: Language Models Can Hear, Talk While Thinking in Streaming
</strong>
</p>

<p align="center">
🤗 <a href="https://huggingface.co/gpt-omni/mini-omni">Hugging Face</a> | 📖 <a href="https://github.com/gpt-omni/mini-omni">Github</a>
| 📑 <a href="https://arxiv.org/abs/2408.16725">Technical report</a>
</p>

Mini-Omni is an open-source multimodal large language model that can **hear, talk while thinking**. Featuring real-time end-to-end speech input and **streaming audio output** conversational capabilities.

<p align="center">
    <img src="data/figures/frameworkv3.jpg" width="100%"/>
</p>


## Features

✅ **Real-time speech-to-speech** conversational capabilities. No extra ASR or TTS models required.

✅ **Talking while thinking**, with the ability to generate text and audio at the same time.

✅ **Streaming audio output** capabilities.

✅ With "Audio-to-Text" and "Audio-to-Audio" **batch inference** to further boost the performance.

## Demo

NOTE: need to unmute first.

https://github.com/user-attachments/assets/03bdde05-9514-4748-b527-003bea57f118


## Install

Create a new conda environment and install the required packages:

```sh
conda create -n omni python=3.10
conda activate omni

git clone https://github.com/gpt-omni/mini-omni.git
cd mini-omni
pip install -r requirements.txt
pip install PyAudio==0.2.14
```

## Quick start

**Interactive demo**

- start server

NOTE: you need to start the server before running the streamlit or gradio demo with API_URL set to the server address.

```sh
sudo apt-get install ffmpeg
conda activate omni
cd mini-omni
python3 server.py --ip '0.0.0.0' --port 60808
```


- run streamlit demo

NOTE: you need to run streamlit locally with PyAudio installed. For error: `ModuleNotFoundError: No module named 'utils.vad'`, please run `export PYTHONPATH=./` first.

```sh
pip install PyAudio==0.2.14
API_URL=http://0.0.0.0:60808/chat streamlit run webui/omni_streamlit.py
```

- run gradio demo
```sh
API_URL=http://0.0.0.0:60808/chat python3 webui/omni_gradio.py
```

example:

NOTE: need to unmute first. Gradio seems can not play audio stream instantly, so the latency feels a bit longer.

https://github.com/user-attachments/assets/29187680-4c42-47ff-b352-f0ea333496d9


**Local test**

```sh
conda activate omni
cd mini-omni
# test run the preset audio samples and questions
python inference.py
```

## Common issues

- Error: `ModuleNotFoundError: No module named 'utils.xxxx'`

    Answer: run `export PYTHONPATH=./` first.

## Acknowledgements

- [Qwen2](https://github.com/QwenLM/Qwen2/) as the LLM backbone.
- [litGPT](https://github.com/Lightning-AI/litgpt/) for training and inference.
- [whisper](https://github.com/openai/whisper/) for audio encoding.
- [snac](https://github.com/hubertsiuzdak/snac/) for audio decoding.
- [CosyVoice](https://github.com/FunAudioLLM/CosyVoice) for generating synthetic speech.
- [OpenOrca](https://huggingface.co/datasets/Open-Orca/OpenOrca) and [MOSS](https://github.com/OpenMOSS/MOSS/tree/main) for alignment.

## Star History

[![Star History Chart](https://api.star-history.com/svg?repos=gpt-omni/mini-omni&type=Date)](https://star-history.com/#gpt-omni/mini-omni&Date)
inference.py
@@ -7,6 +7,8 @@ from litgpt import Tokenizer
 from litgpt.utils import (
     num_parameters,
 )
+import matplotlib
+matplotlib.use('Agg')  # Use a non-GUI backend
 from litgpt.generate.base import (
     generate_AA,
     generate_ASR,
@@ -347,8 +349,8 @@ def T1_T2(fabric, input_ids, model, text_tokenizer, step):
 
 
 def load_model(ckpt_dir, device):
-    snacmodel = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").eval().to(device)
-    whispermodel = whisper.load_model("small").to(device)
+    snacmodel = SNAC.from_pretrained("hubertsiuzdak/snac_24khz", cache_dir="/data/cache/snac").eval().to(device)
+    whispermodel = whisper.load_model("small", download_root="/data/cache/whisper").to(device)
     text_tokenizer = Tokenizer(ckpt_dir)
     fabric = L.Fabric(devices=1, strategy="auto")
     config = Config.from_file(ckpt_dir + "/model_config.yaml")
@@ -367,12 +369,12 @@ def load_model(ckpt_dir, device):
 
 def download_model(ckpt_dir):
     repo_id = "gpt-omni/mini-omni"
-    snapshot_download(repo_id, local_dir=ckpt_dir, revision="main")
+    snapshot_download(repo_id, local_dir=ckpt_dir, revision="main", cache_dir="/data/huggingface")
 
 
 class OmniInference:
 
-    def __init__(self, ckpt_dir='
+    def __init__(self, ckpt_dir='/data/checkpoint', device='cuda:0'):
         self.device = device
         if not os.path.exists(ckpt_dir):
             print(f"checkpoint directory {ckpt_dir} not found, downloading from huggingface")
@@ -508,7 +510,7 @@ class OmniInference:
 def test_infer():
     device = "cuda:0"
     out_dir = f"./output/{get_time_str()}"
-    ckpt_dir = f"
+    ckpt_dir = f"/data/checkpoint"
     if not os.path.exists(ckpt_dir):
         print(f"checkpoint directory {ckpt_dir} not found, downloading from huggingface")
         download_model(ckpt_dir)
requirements.txt
@@ -6,8 +6,14 @@ snac==1.2.0
 soundfile==0.12.1
 openai-whisper==20231117
 tokenizers==0.15.2
-
-
+torch==2.2.1
+torchvision==0.17.1
+torchaudio==2.2.1
+litgpt==0.4.3
+snac==1.2.0
+soundfile==0.12.1
+openai-whisper==20231117
+tokenizers==0.15.2
 pydub==0.25.1
 onnxruntime==1.17.1
 numpy==1.26.4
serve_html.py (new file):

import torch
torch.set_float32_matmul_precision('high')

from flask import Flask, send_from_directory, request, Response
import os
import base64
import numpy as np
from inference import OmniInference
import io

app = Flask(__name__)

# Initialize OmniInference
try:
    print("Initializing OmniInference...")
    omni = OmniInference()
    print("OmniInference initialized successfully.")
except Exception as e:
    print(f"Error initializing OmniInference: {str(e)}")
    raise

@app.route('/')
def serve_html():
    return send_from_directory('.', 'webui/omni_html_demo.html')

@app.route('/chat', methods=['POST'])
def chat():
    try:
        audio_data = request.json['audio']
        if not audio_data:
            return "No audio data received", 400

        # Check if the audio_data contains the expected base64 prefix
        if ',' in audio_data:
            audio_bytes = base64.b64decode(audio_data.split(',')[1])
        else:
            audio_bytes = base64.b64decode(audio_data)

        # Save audio to a temporary file
        temp_audio_path = 'temp_audio.wav'
        with open(temp_audio_path, 'wb') as f:
            f.write(audio_bytes)

        # Generate response using OmniInference
        try:
            response_generator = omni.run_AT_batch_stream(temp_audio_path)

            # Concatenate all audio chunks
            all_audio = b''
            for audio_chunk in response_generator:
                all_audio += audio_chunk

            # Clean up temporary file
            os.remove(temp_audio_path)

            return Response(all_audio, mimetype='audio/wav')
        except Exception as inner_e:
            print(f"Error in OmniInference processing: {str(inner_e)}")
            return f"An error occurred during audio processing: {str(inner_e)}", 500
        finally:
            # Ensure temporary file is removed even if an error occurs
            if os.path.exists(temp_audio_path):
                os.remove(temp_audio_path)

    except Exception as e:
        print(f"Error in chat endpoint: {str(e)}")
        return f"An error occurred: {str(e)}", 500

if __name__ == '__main__':
    app.run(host='0.0.0.0', port=7860)
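As a usage sketch, the `/chat` endpoint can be exercised with curl; `input.wav` and the data-URL prefix are placeholders matching what the HTML demo sends:

```sh
# Encode a short WAV clip and POST it as JSON; the response body is WAV audio
B64=$(base64 -w0 input.wav)
curl -s -X POST http://localhost:7860/chat \
     -H 'Content-Type: application/json' \
     -d "{\"audio\": \"data:audio/wav;base64,${B64}\"}" \
     -o reply.wav
```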
server.py
@@ -2,21 +2,17 @@ import flask
 import base64
 import tempfile
 import traceback
-import torch
 from flask import Flask, Response, stream_with_context
 from inference import OmniInference
 
 
 class OmniChatServer(object):
     def __init__(self, ip='0.0.0.0', port=60808, run_app=True,
-                 ckpt_dir='./checkpoint', device=
+                 ckpt_dir='./checkpoint', device='cuda:0') -> None:
         server = Flask(__name__)
         # CORS(server, resources=r"/*")
         # server.config["JSON_AS_ASCII"] = False
 
-        if device is None:
-            device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
-
         self.client = OmniInference(ckpt_dir, device)
         self.client.warm_up()
 
@@ -50,8 +46,9 @@ def create_app():
     return server.server
 
 
-def serve(ip='0.0.0.0', port=60808, device=
-
+def serve(ip='0.0.0.0', port=60808, device='cuda:0'):
+
+    OmniChatServer(ip, port=port,run_app=True, device=device)
 
 
 if __name__ == "__main__":
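A minimal smoke test for the fixed-device server (a sketch; the sleep is arbitrary, and the payload is raw WAV bytes, which is what the Streamlit client posts):

```sh
python3 server.py --ip '0.0.0.0' --port 60808 &
sleep 60   # allow the checkpoint download and warm_up() to finish on first start
curl -s -X POST http://127.0.0.1:60808/chat --data-binary @input.wav -o reply.wav
```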
webui/index.html (deleted; previous content):

<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Mini-Omni Chat Demo</title>
    <style>
        body {
            background-color: black;
            color: white;
            font-family: Arial, sans-serif;
        }
        #chat-container {
            height: 300px;
            overflow-y: auto;
            border: 1px solid #444;
            padding: 10px;
            margin-bottom: 10px;
        }
        #status-message {
            margin-bottom: 10px;
        }
        button {
            margin-right: 10px;
        }
    </style>
</head>
<body>
    <div id="svg-container"></div>
    <div id="chat-container"></div>
    <div id="status-message">Current status: idle</div>
    <button id="start-button">Start</button>
    <button id="stop-button" disabled>Stop</button>
    <main>
        <p id="current-status">Current status: idle</p>
    </main>
</body>
<script>
    // Load the SVG
    const svgContainer = document.getElementById('svg-container');
    const svgContent = `
    <svg width="800" height="600" viewBox="0 0 800 600" xmlns="http://www.w3.org/2000/svg">
        <ellipse id="left-eye" cx="340" cy="200" rx="20" ry="20" fill="white"/>
        <circle id="left-pupil" cx="340" cy="200" r="8" fill="black"/>
        <ellipse id="right-eye" cx="460" cy="200" rx="20" ry="20" fill="white"/>
        <circle id="right-pupil" cx="460" cy="200" r="8" fill="black"/>
        <path id="upper-lip" d="M 300 300 C 350 284, 450 284, 500 300" stroke="white" stroke-width="10" fill="none"/>
        <path id="lower-lip" d="M 300 300 C 350 316, 450 316, 500 300" stroke="white" stroke-width="10" fill="none"/>
    </svg>`;
    svgContainer.innerHTML = svgContent;

    // Set up audio context
    const audioContext = new (window.AudioContext || window.webkitAudioContext)();
    const analyser = audioContext.createAnalyser();
    analyser.fftSize = 256;

    // Animation variables
    let isAudioPlaying = false;
    let lastBlinkTime = 0;
    let eyeMovementOffset = { x: 0, y: 0 };

    // Chat variables
    let mediaRecorder;
    let audioChunks = [];
    let isRecording = false;
    const API_URL = 'http://127.0.0.1:60808/chat';

    // Idle eye animation function
    function animateIdleEyes(timestamp) {
        const leftEye = document.getElementById('left-eye');
        const rightEye = document.getElementById('right-eye');
        const leftPupil = document.getElementById('left-pupil');
        const rightPupil = document.getElementById('right-pupil');
        const baseEyeX = { left: 340, right: 460 };
        const baseEyeY = 200;

        // Blink effect
        const blinkInterval = 4000 + Math.random() * 2000; // Random blink interval between 4-6 seconds
        if (timestamp - lastBlinkTime > blinkInterval) {
            leftEye.setAttribute('ry', '2');
            rightEye.setAttribute('ry', '2');
            leftPupil.setAttribute('ry', '0.8');
            rightPupil.setAttribute('ry', '0.8');
            setTimeout(() => {
                leftEye.setAttribute('ry', '20');
                rightEye.setAttribute('ry', '20');
                leftPupil.setAttribute('ry', '8');
                rightPupil.setAttribute('ry', '8');
            }, 150);
            lastBlinkTime = timestamp;
        }

        // Subtle eye movement
        const movementSpeed = 0.001;
        eyeMovementOffset.x = Math.sin(timestamp * movementSpeed) * 6;
        eyeMovementOffset.y = Math.cos(timestamp * movementSpeed * 1.3) * 1; // Reduced vertical movement

        leftEye.setAttribute('cx', baseEyeX.left + eyeMovementOffset.x);
        leftEye.setAttribute('cy', baseEyeY + eyeMovementOffset.y);
        rightEye.setAttribute('cx', baseEyeX.right + eyeMovementOffset.x);
        rightEye.setAttribute('cy', baseEyeY + eyeMovementOffset.y);
        leftPupil.setAttribute('cx', baseEyeX.left + eyeMovementOffset.x);
        leftPupil.setAttribute('cy', baseEyeY + eyeMovementOffset.y);
        rightPupil.setAttribute('cx', baseEyeX.right + eyeMovementOffset.x);
        rightPupil.setAttribute('cy', baseEyeY + eyeMovementOffset.y);
    }

    // Main animation function
    function animate(timestamp) {
        if (isAudioPlaying) {
            const dataArray = new Uint8Array(analyser.frequencyBinCount);
            analyser.getByteFrequencyData(dataArray);

            // Calculate the average amplitude in the speech frequency range
            const speechRange = dataArray.slice(5, 80); // Adjust based on your needs
            const averageAmplitude = speechRange.reduce((a, b) => a + b) / speechRange.length;

            // Normalize the amplitude (0-1 range)
            const normalizedAmplitude = averageAmplitude / 255;

            // Animate mouth
            const upperLip = document.getElementById('upper-lip');
            const lowerLip = document.getElementById('lower-lip');
            const baseY = 300;
            const maxMovement = 60;
            const newUpperY = baseY - normalizedAmplitude * maxMovement;
            const newLowerY = baseY + normalizedAmplitude * maxMovement;

            // Adjust control points for more natural movement
            const upperControlY1 = newUpperY - 8;
            const upperControlY2 = newUpperY - 8;
            const lowerControlY1 = newLowerY + 8;
            const lowerControlY2 = newLowerY + 8;

            upperLip.setAttribute('d', `M 300 ${baseY} C 350 ${upperControlY1}, 450 ${upperControlY2}, 500 ${baseY}`);
            lowerLip.setAttribute('d', `M 300 ${baseY} C 350 ${lowerControlY1}, 450 ${lowerControlY2}, 500 ${baseY}`);

            // Animate eyes
            const leftEye = document.getElementById('left-eye');
            const rightEye = document.getElementById('right-eye');
            const leftPupil = document.getElementById('left-pupil');
            const rightPupil = document.getElementById('right-pupil');
            const baseEyeY = 200;
            const maxEyeMovement = 10;
            const newEyeY = baseEyeY - normalizedAmplitude * maxEyeMovement;

            leftEye.setAttribute('cy', newEyeY);
            rightEye.setAttribute('cy', newEyeY);
            leftPupil.setAttribute('cy', newEyeY);
            rightPupil.setAttribute('cy', newEyeY);
        } else {
            animateIdleEyes(timestamp);
        }

        requestAnimationFrame(animate);
    }

    // Start animation
    animate();

    // Chat functions
    function startRecording() {
        navigator.mediaDevices.getUserMedia({ audio: true })
            .then(stream => {
                mediaRecorder = new MediaRecorder(stream);
                mediaRecorder.ondataavailable = event => {
                    audioChunks.push(event.data);
                };
                mediaRecorder.onstop = sendAudioToServer;
                mediaRecorder.start();
                isRecording = true;
                updateStatus('Recording...');
                document.getElementById('start-button').disabled = true;
                document.getElementById('stop-button').disabled = false;
            })
            .catch(error => {
                console.error('Error accessing microphone:', error);
                updateStatus('Error: ' + error.message);
            });
    }

    function stopRecording() {
        if (mediaRecorder && isRecording) {
            mediaRecorder.stop();
            isRecording = false;
            updateStatus('Processing...');
            document.getElementById('start-button').disabled = false;
            document.getElementById('stop-button').disabled = true;
        }
    }

    function sendAudioToServer() {
        const audioBlob = new Blob(audioChunks, { type: 'audio/wav' });
        const reader = new FileReader();
        reader.readAsDataURL(audioBlob);
        reader.onloadend = function() {
            const base64Audio = reader.result.split(',')[1];
            fetch(API_URL, {
                method: 'POST',
                headers: {
                    'Content-Type': 'application/json',
                },
                body: JSON.stringify({ audio: base64Audio }),
            })
            .then(response => response.blob())
            .then(blob => {
                const audioUrl = URL.createObjectURL(blob);
                playResponseAudio(audioUrl);
                updateChatHistory('User', 'Audio message sent');
                updateChatHistory('Assistant', 'Audio response received');
            })
            .catch(error => {
                console.error('Error:', error);
                updateStatus('Error: ' + error.message);
            });
        };
        audioChunks = [];
    }

    function playResponseAudio(audioUrl) {
        const audio = new Audio(audioUrl);
        audio.onloadedmetadata = () => {
            const source = audioContext.createMediaElementSource(audio);
            source.connect(analyser);
            analyser.connect(audioContext.destination);
        };
        audio.onplay = () => {
            isAudioPlaying = true;
            updateStatus('Playing response...');
        };
        audio.onended = () => {
            isAudioPlaying = false;
            updateStatus('Idle');
        };
        audio.play();
    }

    function updateChatHistory(role, message) {
        const chatContainer = document.getElementById('chat-container');
        const messageElement = document.createElement('p');
        messageElement.textContent = `${role}: ${message}`;
        chatContainer.appendChild(messageElement);
        chatContainer.scrollTop = chatContainer.scrollHeight;
    }

    function updateStatus(status) {
        document.getElementById('status-message').textContent = status;
        document.getElementById('current-status').textContent = 'Current status: ' + status;
    }

    // Event listeners
    document.getElementById('start-button').addEventListener('click', startRecording);
    document.getElementById('stop-button').addEventListener('click', stopRecording);

    // Initialize
    updateStatus('Idle');
</script>
</html>
webui/omni_html_demo.html
@@ -21,7 +21,7 @@
     <audio id="audioPlayback" controls style="display:none;"></audio>
 
     <script>
-        const API_URL = '
+        const API_URL = '/chat';
         const recordButton = document.getElementById('recordButton');
         const chatHistory = document.getElementById('chatHistory');
         const audioPlayback = document.getElementById('audioPlayback');
@@ -86,12 +86,13 @@
                 }
             });
 
-            const
-            const
-
-
-
-
+            const responseBlob = await new Response(stream).blob();
+            const audioUrl = URL.createObjectURL(responseBlob);
+            updateChatHistory('AI', audioUrl);
+
+            // Play the audio response
+            const audio = new Audio(audioUrl);
+            audio.play();
             } else {
                 console.error('API response not ok:', response.status);
                 updateChatHistory('AI', 'Error in API response');
@@ -99,7 +100,11 @@
         };
     } catch (error) {
         console.error('Error sending audio to API:', error);
-
+        if (error.name === 'TypeError' && error.message === 'Failed to fetch') {
+            updateChatHistory('AI', 'Error: Unable to connect to the server. Please ensure the server is running and accessible.');
+        } else {
+            updateChatHistory('AI', 'Error communicating with the server: ' + error.message);
+        }
     }
 }
 
webui/omni_streamlit.py (rewritten)

Removed (tail of the previous PyAudio-based implementation, old lines 135-257):

            stream.write(audio_data)
    except Exception as e:
        st.error(f"Error during audio streaming: {e}")

    out_file = save_tmp_audio(output_audio_bytes)
    with st.chat_message("assistant"):
        st.audio(out_file, format="audio/wav", loop=False, autoplay=False)
        st.session_state.messages.append(
            {"role": "assistant", "content": out_file, "type": "audio"}
        )

    wf.close()
    # Close PyAudio stream and terminate PyAudio
    stream.stop_stream()
    stream.close()
    p.terminate()
    st.session_state.speaking = False
    st.session_state.recording = True


def recording(status):
    audio = pyaudio.PyAudio()

    stream = audio.open(
        format=IN_FORMAT,
        channels=IN_CHANNELS,
        rate=IN_RATE,
        input=True,
        frames_per_buffer=IN_CHUNK,
    )

    temp_audio = b""
    vad_audio = b""

    start_talking = False
    last_temp_audio = None
    st.session_state.frames = []

    while st.session_state.recording:
        status.success("Listening...")
        audio_bytes = stream.read(IN_CHUNK)
        temp_audio += audio_bytes

        if len(temp_audio) > IN_SAMPLE_WIDTH * IN_RATE * IN_CHANNELS * VAD_STRIDE:
            dur_vad, vad_audio_bytes, time_vad = run_vad(temp_audio, IN_RATE)

            print(f"duration_after_vad: {dur_vad:.3f} s, time_vad: {time_vad:.3f} s")

            if dur_vad > 0.2 and not start_talking:
                if last_temp_audio is not None:
                    st.session_state.frames.append(last_temp_audio)
                start_talking = True
            if start_talking:
                st.session_state.frames.append(temp_audio)
            if dur_vad < 0.1 and start_talking:
                st.session_state.recording = False
                print(f"speech end detected. excit")
            last_temp_audio = temp_audio
            temp_audio = b""

    stream.stop_stream()
    stream.close()

    audio.terminate()


def main():

    st.title("Chat Mini-Omni Demo")
    status = st.empty()

    if "warm_up" not in st.session_state:
        warm_up()
        st.session_state.warm_up = True
    if "start" not in st.session_state:
        st.session_state.start = False
    if "recording" not in st.session_state:
        st.session_state.recording = False
    if "speaking" not in st.session_state:
        st.session_state.speaking = False
    if "frames" not in st.session_state:
        st.session_state.frames = []

    if not st.session_state.start:
        status.warning("Click Start to chat")

    start_col, stop_col, _ = st.columns([0.2, 0.2, 0.6])
    start_button = start_col.button("Start", key="start_button")
    # stop_button = stop_col.button("Stop", key="stop_button")
    if start_button:
        time.sleep(1)
        st.session_state.recording = True
        st.session_state.start = True

    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            if message["type"] == "msg":
                st.markdown(message["content"])
            elif message["type"] == "img":
                st.image(message["content"], width=300)
            elif message["type"] == "audio":
                st.audio(
                    message["content"], format="audio/wav", loop=False, autoplay=False
                )

    while st.session_state.start:
        if st.session_state.recording:
            recording(status)

        if not st.session_state.recording and st.session_state.start:
            st.session_state.speaking = True
            speaking(status)

        # if stop_button:
        #     status.warning("Stopped, click Start to chat")
        #     st.session_state.start = False
        #     st.session_state.recording = False
        #     st.session_state.frames = []
        #     break


if __name__ == "__main__":
    main()

New content (streamlit-webrtc based):

import streamlit as st
import numpy as np
import requests
import base64
import tempfile
import os
import time
import traceback
import librosa
from pydub import AudioSegment
from streamlit_webrtc import webrtc_streamer, WebRtcMode, RTCConfiguration
import av
from utils.vad import get_speech_timestamps, collect_chunks, VadOptions

API_URL = os.getenv("API_URL", "http://127.0.0.1:60808/chat")

# Initialize chat history
if "messages" not in st.session_state:
    st.session_state.messages = []

def run_vad(audio, sr):
    _st = time.time()
    try:
        audio = audio.astype(np.float32) / 32768.0
        sampling_rate = 16000
        if sr != sampling_rate:
            audio = librosa.resample(audio, orig_sr=sr, target_sr=sampling_rate)

        vad_parameters = {}
        vad_parameters = VadOptions(**vad_parameters)
        speech_chunks = get_speech_timestamps(audio, vad_parameters)
        audio = collect_chunks(audio, speech_chunks)
        duration_after_vad = audio.shape[0] / sampling_rate

        if sr != sampling_rate:
            # resample to original sampling rate
            vad_audio = librosa.resample(audio, orig_sr=sampling_rate, target_sr=sr)
        else:
            vad_audio = audio
        vad_audio = np.round(vad_audio * 32768.0).astype(np.int16)
        vad_audio_bytes = vad_audio.tobytes()

        return duration_after_vad, vad_audio_bytes, round(time.time() - _st, 4)
    except Exception as e:
        msg = f"[asr vad error] audio_len: {len(audio)/(sr):.3f} s, trace: {traceback.format_exc()}"
        print(msg)
        return -1, audio.tobytes(), round(time.time() - _st, 4)

def save_tmp_audio(audio_bytes):
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmpfile:
        file_name = tmpfile.name
        audio = AudioSegment(
            data=audio_bytes,
            sample_width=2,
            frame_rate=16000,
            channels=1,
        )
        audio.export(file_name, format="wav")
    return file_name

def main():
    st.title("Chat Mini-Omni Demo")
    status = st.empty()

    if "audio_buffer" not in st.session_state:
        st.session_state.audio_buffer = []

    webrtc_ctx = webrtc_streamer(
        key="speech-to-text",
        mode=WebRtcMode.SENDONLY,
        audio_receiver_size=1024,
        rtc_configuration=RTCConfiguration(
            {"iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]}
        ),
        media_stream_constraints={"video": False, "audio": True},
    )

    if webrtc_ctx.audio_receiver:
        while True:
            try:
                audio_frame = webrtc_ctx.audio_receiver.get_frame(timeout=1)
                sound_chunk = np.frombuffer(audio_frame.to_ndarray(), dtype="int16")
                st.session_state.audio_buffer.extend(sound_chunk)

                if len(st.session_state.audio_buffer) >= 16000:
                    duration_after_vad, vad_audio_bytes, vad_time = run_vad(
                        np.array(st.session_state.audio_buffer), 16000
                    )
                    st.session_state.audio_buffer = []
                    if duration_after_vad > 0:
                        st.session_state.messages.append(
                            {"role": "user", "content": "User audio"}
                        )
                        file_name = save_tmp_audio(vad_audio_bytes)
                        st.audio(file_name, format="audio/wav")

                        response = requests.post(API_URL, data=vad_audio_bytes)
                        assistant_audio_bytes = response.content
                        assistant_file_name = save_tmp_audio(assistant_audio_bytes)
                        st.audio(assistant_file_name, format="audio/wav")
                        st.session_state.messages.append(
                            {"role": "assistant", "content": "Assistant response"}
                        )
            except Exception as e:
                print(f"Error in audio processing: {e}")
                break

    if st.button("Process Audio"):
        if st.session_state.audio_buffer:
            duration_after_vad, vad_audio_bytes, vad_time = run_vad(
                np.array(st.session_state.audio_buffer), 16000
            )
            st.session_state.messages.append({"role": "user", "content": "User audio"})
            file_name = save_tmp_audio(vad_audio_bytes)
            st.audio(file_name, format="audio/wav")

            response = requests.post(API_URL, data=vad_audio_bytes)
            assistant_audio_bytes = response.content
            assistant_file_name = save_tmp_audio(assistant_audio_bytes)
            st.audio(assistant_file_name, format="audio/wav")
            st.session_state.messages.append(
                {"role": "assistant", "content": "Assistant response"}
            )
            st.session_state.audio_buffer = []

    if st.session_state.messages:
        for message in st.session_state.messages:
            if message["role"] == "user":
                st.write(f"User: {message['content']}")
            else:
                st.write(f"Assistant: {message['content']}")

if __name__ == "__main__":
    main()