Commit 9616027 (parent: 040d99a), committed by victor (HF staff)

feat: Update Dockerfile and requirements.txt to resolve PyAudio build issues
The main changes are:

1. Added `build-essential` and `libasound2-dev` to the system dependencies in the Dockerfile so the necessary build tools and ALSA headers are available.
2. Removed PyAudio from `requirements.txt` to avoid the pip installation failure.
3. Added a separate `RUN pip install PyAudio==0.2.14` command in the Dockerfile to install PyAudio explicitly.

These changes should resolve the build issues with PyAudio on the CUDA server.

Revert "fix: Handle CUDA availability in OmniChatServer"

This reverts commit 28ed763269f75cea8298b3d64449fd7776d05f52.

docs: add PyAudio to dependencies

feat: Replace PyAudio with streamlit-webrtc for user recording

fix: Replace PyAudio with streamlit-webrtc for audio recording

feat: Serve HTML demo instead of Streamlit app

fix: Update API_URL and error handling in webui/omni_html_demo.html

fix: Replace audio playback with text-to-speech

feat: Implement audio processing and response generation

fix: Use a Docker data volume for caching

feat: Add Docker data volume and environment variables for caching

diff --git a/inference.py b/inference.py
index 4d4d4d1..d4d4d1a 100644
--- a/inference.py
+++ b/inference.py
@@ -1,6 +1,7 @@
 def download_model(ckpt_dir):
     repo_id = "gpt-omni/mini-omni"
-    snapshot_download(repo_id, local_dir=ckpt_dir, revision="main")
+    cache_dir = os.environ.get('XDG_CACHE_HOME', '/tmp')
+    snapshot_download(repo_id, local_dir=ckpt_dir, revision="main", cache_dir=cache_dir)

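The cache-directory fallback in the hunk above is a small reusable pattern: prefer an environment-controlled location, fall back to a path that is always writable. A minimal sketch (the `resolve_cache_dir` helper name is ours, not from the commit):

```python
import os

def resolve_cache_dir(env=None):
    # Prefer an explicit XDG_CACHE_HOME (set to /data/cache in the
    # Dockerfile) and fall back to the always-writable /tmp, since the
    # container's home directory may be read-only.
    env = os.environ if env is None else env
    return env.get('XDG_CACHE_HOME', '/tmp')

print(resolve_cache_dir({}))                                   # /tmp
print(resolve_cache_dir({'XDG_CACHE_HOME': '/data/cache'}))    # /data/cache
```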
rm

fix: Remove cache-related code and update Dockerfile

fix: Add Docker volume and set permissions for model download

fix: Set correct permissions for checkpoint directory

feat: Use DATA volume to store model checkpoint

fix: Set permissions and create necessary subdirectories in the DATA volume

fix: Implement error handling and CUDA Tensor Cores optimization in serve_html.py

fix: Improve error handling and logging in chat endpoint

Files changed (9)
  1. Dockerfile +23 -11
  2. README.md +125 -124
  3. inference.py +7 -5
  4. requirements.txt +8 -2
  5. serve_html.py +70 -0
  6. server.py +4 -7
  7. webui/index.html +0 -258
  8. webui/omni_html_demo.html +13 -8
  9. webui/omni_streamlit.py +134 -257
Dockerfile CHANGED
@@ -7,7 +7,6 @@ WORKDIR /app
 # Install system dependencies
 RUN apt-get update && apt-get install -y \
     ffmpeg \
-    portaudio19-dev \
     && rm -rf /var/lib/apt/lists/*
 
 # Copy the current directory contents into the container at /app
@@ -16,20 +15,33 @@ COPY . /app
 # Install any needed packages specified in requirements.txt
 RUN pip install --no-cache-dir -r requirements.txt
 
-# Install PyAudio
-RUN pip install PyAudio==0.2.14
-
-# Make ports 7860 and 60808 available to the world outside this container
-EXPOSE 7860 60808
+# Make port 7860 available to the world outside this container
+EXPOSE 7860
 
 # Set environment variable for API_URL
-ENV API_URL=http://0.0.0.0:60808/chat
+ENV API_URL=http://0.0.0.0:7860/chat
 
 # Set PYTHONPATH
 ENV PYTHONPATH=./
 
-# Make start.sh executable
-RUN chmod +x start.sh
-
-# Run start.sh when the container launches
-CMD ["./start.sh"]
+# Set environment variables
+ENV MPLCONFIGDIR=/tmp/matplotlib
+ENV HF_HOME=/data/huggingface
+ENV XDG_CACHE_HOME=/data/cache
+
+# Create a volume for data
+VOLUME /data
+
+# Set permissions for the /data directory and create necessary subdirectories
+RUN mkdir -p /data/checkpoint /data/cache /data/huggingface && \
+    chown -R 1000:1000 /data && \
+    chmod -R 777 /data
+
+# Install Flask
+RUN pip install flask
+
+# Copy the HTML demo file
+COPY webui/omni_html_demo.html .
+
+# Run the Flask app to serve the HTML demo
+CMD ["python", "serve_html.py"]
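The `mkdir -p` step above can be mirrored from Python when preparing a data volume outside Docker; a hedged sketch (the `prepare_data_dirs` helper is our naming, not part of the repo):

```python
import os
import tempfile

def prepare_data_dirs(root):
    # Create the same subdirectories the Dockerfile provisions under the
    # /data volume: checkpoint, cache, and huggingface.
    for sub in ('checkpoint', 'cache', 'huggingface'):
        os.makedirs(os.path.join(root, sub), exist_ok=True)
    return sorted(os.listdir(root))

root = tempfile.mkdtemp()
print(prepare_data_dirs(root))  # ['cache', 'checkpoint', 'huggingface']
```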
README.md CHANGED
@@ -1,124 +1,125 @@
 ---
 title: Omni Docker
 emoji: 🦀
 colorFrom: green
 colorTo: red
 sdk: docker
 pinned: false
 ---
 
 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 # Mini-Omni
 
 <p align="center"><strong style="font-size: 18px;">
 Mini-Omni: Language Models Can Hear, Talk While Thinking in Streaming
 </strong>
 </p>
 
 <p align="center">
 🤗 <a href="https://huggingface.co/gpt-omni/mini-omni">Hugging Face</a> | 📖 <a href="https://github.com/gpt-omni/mini-omni">Github</a>
 | 📑 <a href="https://arxiv.org/abs/2408.16725">Technical report</a>
 </p>
 
 Mini-Omni is an open-source multimodal large language model that can **hear, talk while thinking**. Featuring real-time end-to-end speech input and **streaming audio output** conversational capabilities.
 
 <p align="center">
     <img src="data/figures/frameworkv3.jpg" width="100%"/>
 </p>
 
 
 ## Features
 
 ✅ **Real-time speech-to-speech** conversational capabilities. No extra ASR or TTS models required.
 
 ✅ **Talking while thinking**, with the ability to generate text and audio at the same time.
 
 ✅ **Streaming audio output** capabilities.
 
 ✅ With "Audio-to-Text" and "Audio-to-Audio" **batch inference** to further boost the performance.
 
 ## Demo
 
 NOTE: need to unmute first.
 
 https://github.com/user-attachments/assets/03bdde05-9514-4748-b527-003bea57f118
 
 
 ## Install
 
 Create a new conda environment and install the required packages:
 
 ```sh
 conda create -n omni python=3.10
 conda activate omni
 
 git clone https://github.com/gpt-omni/mini-omni.git
 cd mini-omni
 pip install -r requirements.txt
+pip install PyAudio==0.2.14
 ```
 
 ## Quick start
 
 **Interactive demo**
 
 - start server
 
 NOTE: you need to start the server before running the streamlit or gradio demo with API_URL set to the server address.
 
 ```sh
 sudo apt-get install ffmpeg
 conda activate omni
 cd mini-omni
 python3 server.py --ip '0.0.0.0' --port 60808
 ```
 
 
 - run streamlit demo
 
 NOTE: you need to run streamlit locally with PyAudio installed. For error: `ModuleNotFoundError: No module named 'utils.vad'`, please run `export PYTHONPATH=./` first.
 
 ```sh
 pip install PyAudio==0.2.14
 API_URL=http://0.0.0.0:60808/chat streamlit run webui/omni_streamlit.py
 ```
 
 - run gradio demo
 ```sh
 API_URL=http://0.0.0.0:60808/chat python3 webui/omni_gradio.py
 ```
 
 example:
 
 NOTE: need to unmute first. Gradio seems can not play audio stream instantly, so the latency feels a bit longer.
 
 https://github.com/user-attachments/assets/29187680-4c42-47ff-b352-f0ea333496d9
 
 
 **Local test**
 
 ```sh
 conda activate omni
 cd mini-omni
 # test run the preset audio samples and questions
 python inference.py
 ```
 
 ## Common issues
 
 - Error: `ModuleNotFoundError: No module named 'utils.xxxx'`
 
 Answer: run `export PYTHONPATH=./` first.
 
 ## Acknowledgements
 
 - [Qwen2](https://github.com/QwenLM/Qwen2/) as the LLM backbone.
 - [litGPT](https://github.com/Lightning-AI/litgpt/) for training and inference.
 - [whisper](https://github.com/openai/whisper/) for audio encoding.
 - [snac](https://github.com/hubertsiuzdak/snac/) for audio decoding.
 - [CosyVoice](https://github.com/FunAudioLLM/CosyVoice) for generating synthetic speech.
 - [OpenOrca](https://huggingface.co/datasets/Open-Orca/OpenOrca) and [MOSS](https://github.com/OpenMOSS/MOSS/tree/main) for alignment.
 
 ## Star History
 
 [![Star History Chart](https://api.star-history.com/svg?repos=gpt-omni/mini-omni&type=Date)](https://star-history.com/#gpt-omni/mini-omni&Date)
inference.py CHANGED
@@ -7,6 +7,8 @@ from litgpt import Tokenizer
 from litgpt.utils import (
     num_parameters,
 )
+import matplotlib
+matplotlib.use('Agg')  # Use a non-GUI backend
 from litgpt.generate.base import (
     generate_AA,
     generate_ASR,
@@ -347,8 +349,8 @@ def T1_T2(fabric, input_ids, model, text_tokenizer, step):
 
 
 def load_model(ckpt_dir, device):
-    snacmodel = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").eval().to(device)
-    whispermodel = whisper.load_model("small").to(device)
+    snacmodel = SNAC.from_pretrained("hubertsiuzdak/snac_24khz", cache_dir="/data/cache/snac").eval().to(device)
+    whispermodel = whisper.load_model("small", download_root="/data/cache/whisper").to(device)
     text_tokenizer = Tokenizer(ckpt_dir)
     fabric = L.Fabric(devices=1, strategy="auto")
     config = Config.from_file(ckpt_dir + "/model_config.yaml")
@@ -367,12 +369,12 @@ def load_model(ckpt_dir, device):
 
 def download_model(ckpt_dir):
     repo_id = "gpt-omni/mini-omni"
-    snapshot_download(repo_id, local_dir=ckpt_dir, revision="main")
+    snapshot_download(repo_id, local_dir=ckpt_dir, revision="main", cache_dir="/data/huggingface")
 
 
 class OmniInference:
 
-    def __init__(self, ckpt_dir='./checkpoint', device='cuda:0'):
+    def __init__(self, ckpt_dir='/data/checkpoint', device='cuda:0'):
         self.device = device
         if not os.path.exists(ckpt_dir):
             print(f"checkpoint directory {ckpt_dir} not found, downloading from huggingface")
@@ -508,7 +510,7 @@ class OmniInference:
 def test_infer():
     device = "cuda:0"
     out_dir = f"./output/{get_time_str()}"
-    ckpt_dir = f"./checkpoint"
+    ckpt_dir = f"/data/checkpoint"
     if not os.path.exists(ckpt_dir):
         print(f"checkpoint directory {ckpt_dir} not found, downloading from huggingface")
         download_model(ckpt_dir)
requirements.txt CHANGED
@@ -6,8 +6,14 @@ snac==1.2.0
 soundfile==0.12.1
 openai-whisper==20231117
 tokenizers==0.15.2
-streamlit==1.32.2
-PyAudio==0.2.14
+torch==2.2.1
+torchvision==0.17.1
+torchaudio==2.2.1
+litgpt==0.4.3
+snac==1.2.0
+soundfile==0.12.1
+openai-whisper==20231117
+tokenizers==0.15.2
 pydub==0.25.1
 onnxruntime==1.17.1
 numpy==1.26.4
serve_html.py ADDED
@@ -0,0 +1,70 @@
+import torch
+torch.set_float32_matmul_precision('high')
+
+from flask import Flask, send_from_directory, request, Response
+import os
+import base64
+import numpy as np
+from inference import OmniInference
+import io
+
+app = Flask(__name__)
+
+# Initialize OmniInference
+try:
+    print("Initializing OmniInference...")
+    omni = OmniInference()
+    print("OmniInference initialized successfully.")
+except Exception as e:
+    print(f"Error initializing OmniInference: {str(e)}")
+    raise
+
+@app.route('/')
+def serve_html():
+    return send_from_directory('.', 'webui/omni_html_demo.html')
+
+@app.route('/chat', methods=['POST'])
+def chat():
+    try:
+        audio_data = request.json['audio']
+        if not audio_data:
+            return "No audio data received", 400
+
+        # Check if the audio_data contains the expected base64 prefix
+        if ',' in audio_data:
+            audio_bytes = base64.b64decode(audio_data.split(',')[1])
+        else:
+            audio_bytes = base64.b64decode(audio_data)
+
+        # Save audio to a temporary file
+        temp_audio_path = 'temp_audio.wav'
+        with open(temp_audio_path, 'wb') as f:
+            f.write(audio_bytes)
+
+        # Generate response using OmniInference
+        try:
+            response_generator = omni.run_AT_batch_stream(temp_audio_path)
+
+            # Concatenate all audio chunks
+            all_audio = b''
+            for audio_chunk in response_generator:
+                all_audio += audio_chunk
+
+            # Clean up temporary file
+            os.remove(temp_audio_path)
+
+            return Response(all_audio, mimetype='audio/wav')
+        except Exception as inner_e:
+            print(f"Error in OmniInference processing: {str(inner_e)}")
+            return f"An error occurred during audio processing: {str(inner_e)}", 500
+        finally:
+            # Ensure temporary file is removed even if an error occurs
+            if os.path.exists(temp_audio_path):
+                os.remove(temp_audio_path)
+
+    except Exception as e:
+        print(f"Error in chat endpoint: {str(e)}")
+        return f"An error occurred: {str(e)}", 500
+
+if __name__ == '__main__':
+    app.run(host='0.0.0.0', port=7860)
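The `/chat` endpoint accepts either a bare base64 string or a full data URL in the `audio` field; that prefix handling can be exercised in isolation (`decode_audio_field` is an illustrative name, not a function in the repo):

```python
import base64

def decode_audio_field(audio_data):
    # Mirror the endpoint's logic: strip an optional
    # "data:audio/...;base64," prefix before decoding.
    if ',' in audio_data:
        audio_data = audio_data.split(',')[1]
    return base64.b64decode(audio_data)

raw = b'RIFF fake wav payload'
encoded = base64.b64encode(raw).decode()
assert decode_audio_field(encoded) == raw
assert decode_audio_field('data:audio/wav;base64,' + encoded) == raw
```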
server.py CHANGED
@@ -2,21 +2,17 @@ import flask
 import base64
 import tempfile
 import traceback
-import torch
 from flask import Flask, Response, stream_with_context
 from inference import OmniInference
 
 
 class OmniChatServer(object):
     def __init__(self, ip='0.0.0.0', port=60808, run_app=True,
-                 ckpt_dir='./checkpoint', device=None) -> None:
+                 ckpt_dir='./checkpoint', device='cuda:0') -> None:
         server = Flask(__name__)
         # CORS(server, resources=r"/*")
         # server.config["JSON_AS_ASCII"] = False
 
-        if device is None:
-            device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
-
         self.client = OmniInference(ckpt_dir, device)
         self.client.warm_up()
 
@@ -50,8 +46,9 @@ def create_app():
     return server.server
 
 
-def serve(ip='0.0.0.0', port=60808, device=None):
-    OmniChatServer(ip, port=port, run_app=True, device=device)
+def serve(ip='0.0.0.0', port=60808, device='cuda:0'):
+
+    OmniChatServer(ip, port=port, run_app=True, device=device)
 
 
 if __name__ == "__main__":
webui/index.html DELETED
@@ -1,258 +0,0 @@
-<!DOCTYPE html>
-<html lang="en">
-<head>
-    <meta charset="UTF-8">
-    <meta name="viewport" content="width=device-width, initial-scale=1.0">
-    <title>Mini-Omni Chat Demo</title>
-    <style>
-        body {
-            background-color: black;
-            color: white;
-            font-family: Arial, sans-serif;
-        }
-        #chat-container {
-            height: 300px;
-            overflow-y: auto;
-            border: 1px solid #444;
-            padding: 10px;
-            margin-bottom: 10px;
-        }
-        #status-message {
-            margin-bottom: 10px;
-        }
-        button {
-            margin-right: 10px;
-        }
-    </style>
-</head>
-<body>
-    <div id="svg-container"></div>
-    <div id="chat-container"></div>
-    <div id="status-message">Current status: idle</div>
-    <button id="start-button">Start</button>
-    <button id="stop-button" disabled>Stop</button>
-    <main>
-        <p id="current-status">Current status: idle</p>
-    </main>
-</body>
-<script>
-    // Load the SVG
-    const svgContainer = document.getElementById('svg-container');
-    const svgContent = `
-        <svg width="800" height="600" viewBox="0 0 800 600" xmlns="http://www.w3.org/2000/svg">
-            <ellipse id="left-eye" cx="340" cy="200" rx="20" ry="20" fill="white"/>
-            <circle id="left-pupil" cx="340" cy="200" r="8" fill="black"/>
-            <ellipse id="right-eye" cx="460" cy="200" rx="20" ry="20" fill="white"/>
-            <circle id="right-pupil" cx="460" cy="200" r="8" fill="black"/>
-            <path id="upper-lip" d="M 300 300 C 350 284, 450 284, 500 300" stroke="white" stroke-width="10" fill="none"/>
-            <path id="lower-lip" d="M 300 300 C 350 316, 450 316, 500 300" stroke="white" stroke-width="10" fill="none"/>
-        </svg>`;
-    svgContainer.innerHTML = svgContent;
-
-    // Set up audio context
-    const audioContext = new (window.AudioContext || window.webkitAudioContext)();
-    const analyser = audioContext.createAnalyser();
-    analyser.fftSize = 256;
-
-    // Animation variables
-    let isAudioPlaying = false;
-    let lastBlinkTime = 0;
-    let eyeMovementOffset = { x: 0, y: 0 };
-
-    // Chat variables
-    let mediaRecorder;
-    let audioChunks = [];
-    let isRecording = false;
-    const API_URL = 'http://127.0.0.1:60808/chat';
-
-    // Idle eye animation function
-    function animateIdleEyes(timestamp) {
-        const leftEye = document.getElementById('left-eye');
-        const rightEye = document.getElementById('right-eye');
-        const leftPupil = document.getElementById('left-pupil');
-        const rightPupil = document.getElementById('right-pupil');
-        const baseEyeX = { left: 340, right: 460 };
-        const baseEyeY = 200;
-
-        // Blink effect
-        const blinkInterval = 4000 + Math.random() * 2000; // Random blink interval between 4-6 seconds
-        if (timestamp - lastBlinkTime > blinkInterval) {
-            leftEye.setAttribute('ry', '2');
-            rightEye.setAttribute('ry', '2');
-            leftPupil.setAttribute('ry', '0.8');
-            rightPupil.setAttribute('ry', '0.8');
-            setTimeout(() => {
-                leftEye.setAttribute('ry', '20');
-                rightEye.setAttribute('ry', '20');
-                leftPupil.setAttribute('ry', '8');
-                rightPupil.setAttribute('ry', '8');
-            }, 150);
-            lastBlinkTime = timestamp;
-        }
-
-        // Subtle eye movement
-        const movementSpeed = 0.001;
-        eyeMovementOffset.x = Math.sin(timestamp * movementSpeed) * 6;
-        eyeMovementOffset.y = Math.cos(timestamp * movementSpeed * 1.3) * 1; // Reduced vertical movement
-
-        leftEye.setAttribute('cx', baseEyeX.left + eyeMovementOffset.x);
-        leftEye.setAttribute('cy', baseEyeY + eyeMovementOffset.y);
-        rightEye.setAttribute('cx', baseEyeX.right + eyeMovementOffset.x);
-        rightEye.setAttribute('cy', baseEyeY + eyeMovementOffset.y);
-        leftPupil.setAttribute('cx', baseEyeX.left + eyeMovementOffset.x);
-        leftPupil.setAttribute('cy', baseEyeY + eyeMovementOffset.y);
-        rightPupil.setAttribute('cx', baseEyeX.right + eyeMovementOffset.x);
-        rightPupil.setAttribute('cy', baseEyeY + eyeMovementOffset.y);
-    }
-
-    // Main animation function
-    function animate(timestamp) {
-        if (isAudioPlaying) {
-            const dataArray = new Uint8Array(analyser.frequencyBinCount);
-            analyser.getByteFrequencyData(dataArray);
-
-            // Calculate the average amplitude in the speech frequency range
-            const speechRange = dataArray.slice(5, 80); // Adjust based on your needs
-            const averageAmplitude = speechRange.reduce((a, b) => a + b) / speechRange.length;
-
-            // Normalize the amplitude (0-1 range)
-            const normalizedAmplitude = averageAmplitude / 255;
-
-            // Animate mouth
-            const upperLip = document.getElementById('upper-lip');
-            const lowerLip = document.getElementById('lower-lip');
-            const baseY = 300;
-            const maxMovement = 60;
-            const newUpperY = baseY - normalizedAmplitude * maxMovement;
-            const newLowerY = baseY + normalizedAmplitude * maxMovement;
-
-            // Adjust control points for more natural movement
-            const upperControlY1 = newUpperY - 8;
-            const upperControlY2 = newUpperY - 8;
-            const lowerControlY1 = newLowerY + 8;
-            const lowerControlY2 = newLowerY + 8;
-
-            upperLip.setAttribute('d', `M 300 ${baseY} C 350 ${upperControlY1}, 450 ${upperControlY2}, 500 ${baseY}`);
-            lowerLip.setAttribute('d', `M 300 ${baseY} C 350 ${lowerControlY1}, 450 ${lowerControlY2}, 500 ${baseY}`);
-
-            // Animate eyes
-            const leftEye = document.getElementById('left-eye');
-            const rightEye = document.getElementById('right-eye');
-            const leftPupil = document.getElementById('left-pupil');
-            const rightPupil = document.getElementById('right-pupil');
-            const baseEyeY = 200;
-            const maxEyeMovement = 10;
-            const newEyeY = baseEyeY - normalizedAmplitude * maxEyeMovement;
-
-            leftEye.setAttribute('cy', newEyeY);
-            rightEye.setAttribute('cy', newEyeY);
-            leftPupil.setAttribute('cy', newEyeY);
-            rightPupil.setAttribute('cy', newEyeY);
-        } else {
-            animateIdleEyes(timestamp);
-        }
-
-        requestAnimationFrame(animate);
-    }
-
-    // Start animation
-    animate();
-
-    // Chat functions
-    function startRecording() {
-        navigator.mediaDevices.getUserMedia({ audio: true })
-            .then(stream => {
-                mediaRecorder = new MediaRecorder(stream);
-                mediaRecorder.ondataavailable = event => {
-                    audioChunks.push(event.data);
-                };
-                mediaRecorder.onstop = sendAudioToServer;
-                mediaRecorder.start();
-                isRecording = true;
-                updateStatus('Recording...');
-                document.getElementById('start-button').disabled = true;
-                document.getElementById('stop-button').disabled = false;
-            })
-            .catch(error => {
-                console.error('Error accessing microphone:', error);
-                updateStatus('Error: ' + error.message);
-            });
-    }
-
-    function stopRecording() {
-        if (mediaRecorder && isRecording) {
-            mediaRecorder.stop();
-            isRecording = false;
-            updateStatus('Processing...');
-            document.getElementById('start-button').disabled = false;
-            document.getElementById('stop-button').disabled = true;
-        }
-    }
-
-    function sendAudioToServer() {
-        const audioBlob = new Blob(audioChunks, { type: 'audio/wav' });
-        const reader = new FileReader();
-        reader.readAsDataURL(audioBlob);
-        reader.onloadend = function() {
-            const base64Audio = reader.result.split(',')[1];
-            fetch(API_URL, {
-                method: 'POST',
-                headers: {
-                    'Content-Type': 'application/json',
-                },
-                body: JSON.stringify({ audio: base64Audio }),
-            })
-            .then(response => response.blob())
-            .then(blob => {
-                const audioUrl = URL.createObjectURL(blob);
-                playResponseAudio(audioUrl);
-                updateChatHistory('User', 'Audio message sent');
-                updateChatHistory('Assistant', 'Audio response received');
-            })
-            .catch(error => {
-                console.error('Error:', error);
-                updateStatus('Error: ' + error.message);
-            });
-        };
-        audioChunks = [];
-    }
-
-    function playResponseAudio(audioUrl) {
-        const audio = new Audio(audioUrl);
-        audio.onloadedmetadata = () => {
-            const source = audioContext.createMediaElementSource(audio);
-            source.connect(analyser);
-            analyser.connect(audioContext.destination);
-        };
-        audio.onplay = () => {
-            isAudioPlaying = true;
-            updateStatus('Playing response...');
-        };
-        audio.onended = () => {
-            isAudioPlaying = false;
-            updateStatus('Idle');
-        };
-        audio.play();
-    }
-
-    function updateChatHistory(role, message) {
-        const chatContainer = document.getElementById('chat-container');
-        const messageElement = document.createElement('p');
-        messageElement.textContent = `${role}: ${message}`;
-        chatContainer.appendChild(messageElement);
-        chatContainer.scrollTop = chatContainer.scrollHeight;
-    }
-
-    function updateStatus(status) {
-        document.getElementById('status-message').textContent = status;
-        document.getElementById('current-status').textContent = 'Current status: ' + status;
-    }
-
-    // Event listeners
-    document.getElementById('start-button').addEventListener('click', startRecording);
-    document.getElementById('stop-button').addEventListener('click', stopRecording);
-
-    // Initialize
-    updateStatus('Idle');
-</script>
-</html>
webui/omni_html_demo.html CHANGED
@@ -21,7 +21,7 @@
21
  <audio id="audioPlayback" controls style="display:none;"></audio>
22
 
23
  <script>
24
  <audio id="audioPlayback" controls style="display:none;"></audio>

  <script>
- const API_URL = 'http://127.0.0.1:60808/chat';
+ const API_URL = '/chat';
  const recordButton = document.getElementById('recordButton');
  const chatHistory = document.getElementById('chatHistory');
  const audioPlayback = document.getElementById('audioPlayback');
@@ -86,12 +86,13 @@
  }
  });

- const audioResponse = new Response(stream);
- const audioBlob = await audioResponse.blob();
- audioPlayback.src = URL.createObjectURL(audioBlob);
- audioPlayback.play();
-
- updateChatHistory('AI', URL.createObjectURL(audioBlob));
+ const responseBlob = await new Response(stream).blob();
+ const audioUrl = URL.createObjectURL(responseBlob);
+ updateChatHistory('AI', audioUrl);
+
+ // Play the audio response
+ const audio = new Audio(audioUrl);
+ audio.play();
  } else {
  console.error('API response not ok:', response.status);
  updateChatHistory('AI', 'Error in API response');
@@ -99,7 +100,11 @@
  };
  } catch (error) {
  console.error('Error sending audio to API:', error);
- updateChatHistory('AI', 'Error communicating with the server');
+ if (error.name === 'TypeError' && error.message === 'Failed to fetch') {
+ updateChatHistory('AI', 'Error: Unable to connect to the server. Please ensure the server is running and accessible.');
+ } else {
+ updateChatHistory('AI', 'Error communicating with the server: ' + error.message);
+ }
  }
  }
 
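The new catch branch distinguishes an unreachable server (a `TypeError` with message `Failed to fetch` in browsers) from other failures. A minimal Python sketch of the same classification — `describe_fetch_error` is a hypothetical helper mirroring the JS logic, not part of the repo:

```python
def describe_fetch_error(error_name: str, error_message: str) -> str:
    """Mirror of the HTML demo's catch branch: a TypeError with
    'Failed to fetch' means the browser never reached the server;
    any other error is reported with its message."""
    if error_name == "TypeError" and error_message == "Failed to fetch":
        return ("Error: Unable to connect to the server. "
                "Please ensure the server is running and accessible.")
    return "Error communicating with the server: " + error_message
```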
webui/omni_streamlit.py CHANGED
@@ -1,257 +1,134 @@
- import streamlit as st
- import wave
-
- # from ASR import recognize
- import requests
- import pyaudio
- import numpy as np
- import base64
- import io
- import os
- import time
- import tempfile
- import librosa
- import traceback
- from pydub import AudioSegment
- from utils.vad import get_speech_timestamps, collect_chunks, VadOptions
-
-
- API_URL = os.getenv("API_URL", "http://127.0.0.1:60808/chat")
-
- # recording parameters
- IN_FORMAT = pyaudio.paInt16
- IN_CHANNELS = 1
- IN_RATE = 24000
- IN_CHUNK = 1024
- IN_SAMPLE_WIDTH = 2
- VAD_STRIDE = 0.5
-
- # playing parameters
- OUT_FORMAT = pyaudio.paInt16
- OUT_CHANNELS = 1
- OUT_RATE = 24000
- OUT_SAMPLE_WIDTH = 2
- OUT_CHUNK = 5760
-
-
- # Initialize chat history
- if "messages" not in st.session_state:
-     st.session_state.messages = []
-
-
- def run_vad(ori_audio, sr):
-     _st = time.time()
-     try:
-         audio = np.frombuffer(ori_audio, dtype=np.int16)
-         audio = audio.astype(np.float32) / 32768.0
-         sampling_rate = 16000
-         if sr != sampling_rate:
-             audio = librosa.resample(audio, orig_sr=sr, target_sr=sampling_rate)
-
-         vad_parameters = {}
-         vad_parameters = VadOptions(**vad_parameters)
-         speech_chunks = get_speech_timestamps(audio, vad_parameters)
-         audio = collect_chunks(audio, speech_chunks)
-         duration_after_vad = audio.shape[0] / sampling_rate
-
-         if sr != sampling_rate:
-             # resample to original sampling rate
-             vad_audio = librosa.resample(audio, orig_sr=sampling_rate, target_sr=sr)
-         else:
-             vad_audio = audio
-         vad_audio = np.round(vad_audio * 32768.0).astype(np.int16)
-         vad_audio_bytes = vad_audio.tobytes()
-
-         return duration_after_vad, vad_audio_bytes, round(time.time() - _st, 4)
-     except Exception as e:
-         msg = f"[asr vad error] audio_len: {len(ori_audio)/(sr*2):.3f} s, trace: {traceback.format_exc()}"
-         print(msg)
-         return -1, ori_audio, round(time.time() - _st, 4)
-
-
- def warm_up():
-     frames = b"\x00\x00" * 1024 * 2  # 1024 frames of 2 bytes each
-     dur, frames, tcost = run_vad(frames, 16000)
-     print(f"warm up done, time_cost: {tcost:.3f} s")
-
-
- def save_tmp_audio(audio_bytes):
-     with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmpfile:
-         file_name = tmpfile.name
-         audio = AudioSegment(
-             data=audio_bytes,
-             sample_width=OUT_SAMPLE_WIDTH,
-             frame_rate=OUT_RATE,
-             channels=OUT_CHANNELS,
-         )
-         audio.export(file_name, format="wav")
-     return file_name
-
-
- def speaking(status):
-
-     # Initialize PyAudio
-     p = pyaudio.PyAudio()
-
-     # Open PyAudio stream
-     stream = p.open(
-         format=OUT_FORMAT, channels=OUT_CHANNELS, rate=OUT_RATE, output=True
-     )
-
-     audio_buffer = io.BytesIO()
-     wf = wave.open(audio_buffer, "wb")
-     wf.setnchannels(IN_CHANNELS)
-     wf.setsampwidth(IN_SAMPLE_WIDTH)
-     wf.setframerate(IN_RATE)
-     total_frames = b"".join(st.session_state.frames)
-     dur = len(total_frames) / (IN_RATE * IN_CHANNELS * IN_SAMPLE_WIDTH)
-     status.warning(f"Speaking... recorded audio duration: {dur:.3f} s")
-     wf.writeframes(total_frames)
-
-     with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmpfile:
-         with open(tmpfile.name, "wb") as f:
-             f.write(audio_buffer.getvalue())
-         file_name = tmpfile.name
-     with st.chat_message("user"):
-         st.audio(file_name, format="audio/wav", loop=False, autoplay=False)
-         st.session_state.messages.append(
-             {"role": "assistant", "content": file_name, "type": "audio"}
-         )
-
-     st.session_state.frames = []
-
-     audio_bytes = audio_buffer.getvalue()
-     base64_encoded = str(base64.b64encode(audio_bytes), encoding="utf-8")
-     files = {"audio": base64_encoded}
-     output_audio_bytes = b""
-     with requests.post(API_URL, json=files, stream=True) as response:
-         try:
-             for chunk in response.iter_content(chunk_size=OUT_CHUNK):
-                 if chunk:
-                     # Convert chunk to numpy array
-                     output_audio_bytes += chunk
-                     audio_data = np.frombuffer(chunk, dtype=np.int8)
-                     # Play audio
-                     stream.write(audio_data)
-         except Exception as e:
-             st.error(f"Error during audio streaming: {e}")
-
-     out_file = save_tmp_audio(output_audio_bytes)
-     with st.chat_message("assistant"):
-         st.audio(out_file, format="audio/wav", loop=False, autoplay=False)
-         st.session_state.messages.append(
-             {"role": "assistant", "content": out_file, "type": "audio"}
-         )
-
-     wf.close()
-     # Close PyAudio stream and terminate PyAudio
-     stream.stop_stream()
-     stream.close()
-     p.terminate()
-     st.session_state.speaking = False
-     st.session_state.recording = True
-
-
- def recording(status):
-     audio = pyaudio.PyAudio()
-
-     stream = audio.open(
-         format=IN_FORMAT,
-         channels=IN_CHANNELS,
-         rate=IN_RATE,
-         input=True,
-         frames_per_buffer=IN_CHUNK,
-     )
-
-     temp_audio = b""
-     vad_audio = b""
-
-     start_talking = False
-     last_temp_audio = None
-     st.session_state.frames = []
-
-     while st.session_state.recording:
-         status.success("Listening...")
-         audio_bytes = stream.read(IN_CHUNK)
-         temp_audio += audio_bytes
-
-         if len(temp_audio) > IN_SAMPLE_WIDTH * IN_RATE * IN_CHANNELS * VAD_STRIDE:
-             dur_vad, vad_audio_bytes, time_vad = run_vad(temp_audio, IN_RATE)
-
-             print(f"duration_after_vad: {dur_vad:.3f} s, time_vad: {time_vad:.3f} s")
-
-             if dur_vad > 0.2 and not start_talking:
-                 if last_temp_audio is not None:
-                     st.session_state.frames.append(last_temp_audio)
-                 start_talking = True
-             if start_talking:
-                 st.session_state.frames.append(temp_audio)
-             if dur_vad < 0.1 and start_talking:
-                 st.session_state.recording = False
-                 print(f"speech end detected. excit")
-             last_temp_audio = temp_audio
-             temp_audio = b""
-
-     stream.stop_stream()
-     stream.close()
-
-     audio.terminate()
-
-
- def main():
-
-     st.title("Chat Mini-Omni Demo")
-     status = st.empty()
-
-     if "warm_up" not in st.session_state:
-         warm_up()
-         st.session_state.warm_up = True
-     if "start" not in st.session_state:
-         st.session_state.start = False
-     if "recording" not in st.session_state:
-         st.session_state.recording = False
-     if "speaking" not in st.session_state:
-         st.session_state.speaking = False
-     if "frames" not in st.session_state:
-         st.session_state.frames = []
-
-     if not st.session_state.start:
-         status.warning("Click Start to chat")
-
-     start_col, stop_col, _ = st.columns([0.2, 0.2, 0.6])
-     start_button = start_col.button("Start", key="start_button")
-     # stop_button = stop_col.button("Stop", key="stop_button")
-     if start_button:
-         time.sleep(1)
-         st.session_state.recording = True
-         st.session_state.start = True
-
-     for message in st.session_state.messages:
-         with st.chat_message(message["role"]):
-             if message["type"] == "msg":
-                 st.markdown(message["content"])
-             elif message["type"] == "img":
-                 st.image(message["content"], width=300)
-             elif message["type"] == "audio":
-                 st.audio(
-                     message["content"], format="audio/wav", loop=False, autoplay=False
-                 )
-
-     while st.session_state.start:
-         if st.session_state.recording:
-             recording(status)
-
-         if not st.session_state.recording and st.session_state.start:
-             st.session_state.speaking = True
-             speaking(status)
-
-         # if stop_button:
-         #     status.warning("Stopped, click Start to chat")
-         #     st.session_state.start = False
-         #     st.session_state.recording = False
-         #     st.session_state.frames = []
-         #     break
-
-
- if __name__ == "__main__":
-     main()
 
+ import streamlit as st
+ import numpy as np
+ import requests
+ import base64
+ import tempfile
+ import os
+ import time
+ import traceback
+ import librosa
+ from pydub import AudioSegment
+ from streamlit_webrtc import webrtc_streamer, WebRtcMode, RTCConfiguration
+ import av
+ from utils.vad import get_speech_timestamps, collect_chunks, VadOptions
+
+ API_URL = os.getenv("API_URL", "http://127.0.0.1:60808/chat")
+
+ # Initialize chat history
+ if "messages" not in st.session_state:
+     st.session_state.messages = []
+
+ def run_vad(audio, sr):
+     _st = time.time()
+     try:
+         audio = audio.astype(np.float32) / 32768.0
+         sampling_rate = 16000
+         if sr != sampling_rate:
+             audio = librosa.resample(audio, orig_sr=sr, target_sr=sampling_rate)
+
+         vad_parameters = {}
+         vad_parameters = VadOptions(**vad_parameters)
+         speech_chunks = get_speech_timestamps(audio, vad_parameters)
+         audio = collect_chunks(audio, speech_chunks)
+         duration_after_vad = audio.shape[0] / sampling_rate
+
+         if sr != sampling_rate:
+             # resample to original sampling rate
+             vad_audio = librosa.resample(audio, orig_sr=sampling_rate, target_sr=sr)
+         else:
+             vad_audio = audio
+         vad_audio = np.round(vad_audio * 32768.0).astype(np.int16)
+         vad_audio_bytes = vad_audio.tobytes()
+
+         return duration_after_vad, vad_audio_bytes, round(time.time() - _st, 4)
+     except Exception as e:
+         msg = f"[asr vad error] audio_len: {len(audio)/(sr):.3f} s, trace: {traceback.format_exc()}"
+         print(msg)
+         return -1, audio.tobytes(), round(time.time() - _st, 4)
+
+ def save_tmp_audio(audio_bytes):
+     with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmpfile:
+         file_name = tmpfile.name
+         audio = AudioSegment(
+             data=audio_bytes,
+             sample_width=2,
+             frame_rate=16000,
+             channels=1,
+         )
+         audio.export(file_name, format="wav")
+     return file_name
+
+ def main():
+     st.title("Chat Mini-Omni Demo")
+     status = st.empty()
+
+     if "audio_buffer" not in st.session_state:
+         st.session_state.audio_buffer = []
+
+     webrtc_ctx = webrtc_streamer(
+         key="speech-to-text",
+         mode=WebRtcMode.SENDONLY,
+         audio_receiver_size=1024,
+         rtc_configuration=RTCConfiguration(
+             {"iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]}
+         ),
+         media_stream_constraints={"video": False, "audio": True},
+     )
+
+     if webrtc_ctx.audio_receiver:
+         while True:
+             try:
+                 audio_frame = webrtc_ctx.audio_receiver.get_frame(timeout=1)
+                 sound_chunk = np.frombuffer(audio_frame.to_ndarray(), dtype="int16")
+                 st.session_state.audio_buffer.extend(sound_chunk)
+
+                 if len(st.session_state.audio_buffer) >= 16000:
+                     duration_after_vad, vad_audio_bytes, vad_time = run_vad(
+                         np.array(st.session_state.audio_buffer), 16000
+                     )
+                     st.session_state.audio_buffer = []
+                     if duration_after_vad > 0:
+                         st.session_state.messages.append(
+                             {"role": "user", "content": "User audio"}
+                         )
+                         file_name = save_tmp_audio(vad_audio_bytes)
+                         st.audio(file_name, format="audio/wav")
+
+                         response = requests.post(API_URL, data=vad_audio_bytes)
+                         assistant_audio_bytes = response.content
+                         assistant_file_name = save_tmp_audio(assistant_audio_bytes)
+                         st.audio(assistant_file_name, format="audio/wav")
+                         st.session_state.messages.append(
+                             {"role": "assistant", "content": "Assistant response"}
+                         )
+             except Exception as e:
+                 print(f"Error in audio processing: {e}")
+                 break
+
+     if st.button("Process Audio"):
+         if st.session_state.audio_buffer:
+             duration_after_vad, vad_audio_bytes, vad_time = run_vad(
+                 np.array(st.session_state.audio_buffer), 16000
+             )
+             st.session_state.messages.append({"role": "user", "content": "User audio"})
+             file_name = save_tmp_audio(vad_audio_bytes)
+             st.audio(file_name, format="audio/wav")
+
+             response = requests.post(API_URL, data=vad_audio_bytes)
+             assistant_audio_bytes = response.content
+             assistant_file_name = save_tmp_audio(assistant_audio_bytes)
+             st.audio(assistant_file_name, format="audio/wav")
+             st.session_state.messages.append(
+                 {"role": "assistant", "content": "Assistant response"}
+             )
+             st.session_state.audio_buffer = []
+
+     if st.session_state.messages:
+         for message in st.session_state.messages:
+             if message["role"] == "user":
+                 st.write(f"User: {message['content']}")
+             else:
+                 st.write(f"Assistant: {message['content']}")
+
+ if __name__ == "__main__":
+     main()
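
Both the removed and the new `run_vad` convert int16 PCM to floats by dividing by 32768.0 and convert back with `np.round(... * 32768.0)`. A standalone, numpy-free sketch of that scaling roundtrip (helper names are illustrative, not from the repo):

```python
def int16_to_float(samples):
    # Same scaling run_vad applies: int16 -> float in [-1.0, 1.0)
    return [s / 32768.0 for s in samples]

def float_to_int16(values):
    # Inverse step: scale back, round, and clamp to the int16 range
    return [max(-32768, min(32767, round(v * 32768.0))) for v in values]
```

Because the forward and inverse scales match, the roundtrip preserves every int16 sample exactly; the clamp only matters for floats produced elsewhere (e.g. resampling overshoot).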