Helw150 committed on
Commit
87930ea
0 Parent(s):
.gitattributes ADDED
@@ -0,0 +1,36 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
+ utils/assets/silero_vad.onnx filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,2 @@
+ checkpoint/
+ __pycache__
README.md ADDED
@@ -0,0 +1,13 @@
+ ---
+ title: Omni Mini
+ emoji: 🌖
+ colorFrom: gray
+ colorTo: green
+ sdk: gradio
+ sdk_version: 5.0.0b1
+ app_file: app.py
+ pinned: false
+ license: mit
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,193 @@
+ import time
+ import traceback
+ from dataclasses import dataclass, field
+
+ import gradio as gr
+ import librosa
+ import numpy as np
+ import soundfile as sf
+ import spaces
+ import torch
+ import xxhash
+ from datasets import Audio
+ from transformers import AutoModel
+
+ from utils.vad import VadOptions, collect_chunks, get_speech_timestamps
+
+ diva_model = AutoModel.from_pretrained(
+     "WillHeld/DiVA-llama-3-v0-8b", trust_remote_code=True
+ )
+
+ resampler = Audio(sampling_rate=16_000)
+
+
+ @spaces.GPU
+ @torch.no_grad
+ def diva_audio(audio_input, do_sample=False, temperature=0.001):
+     sr, y = audio_input
+     # Normalize to [-1, 1], then resample to the model's 16 kHz input rate.
+     y = y.astype(np.float32)
+     y /= np.max(np.abs(y))
+     a = resampler.decode_example(
+         resampler.encode_example({"array": y, "sampling_rate": sr})
+     )
+     yield from diva_model.generate_stream(
+         a["array"], None, do_sample=do_sample, max_new_tokens=256
+     )
+
+
+ def run_vad(ori_audio, sr):
+     _st = time.time()
+     try:
+         audio = ori_audio.astype(np.float32) / 32768.0
+         sampling_rate = 16000
+         if sr != sampling_rate:
+             audio = librosa.resample(audio, orig_sr=sr, target_sr=sampling_rate)
+
+         vad_parameters = VadOptions()
+         speech_chunks = get_speech_timestamps(audio, vad_parameters)
+         audio = collect_chunks(audio, speech_chunks)
+         duration_after_vad = audio.shape[0] / sampling_rate
+
+         if sr != sampling_rate:
+             # resample back to the original sampling rate
+             vad_audio = librosa.resample(audio, orig_sr=sampling_rate, target_sr=sr)
+         else:
+             vad_audio = audio
+         vad_audio = np.round(vad_audio * 32768.0).astype(np.int16)
+         vad_audio_bytes = vad_audio.tobytes()
+
+         return duration_after_vad, vad_audio_bytes, round(time.time() - _st, 4)
+     except Exception:
+         msg = f"[asr vad error] audio_len: {len(ori_audio) / sr:.3f} s, trace: {traceback.format_exc()}"
+         print(msg)
+         return -1, ori_audio, round(time.time() - _st, 4)
+
+
+ def warm_up():
+     frames = np.zeros(1024, dtype=np.int16)  # 1024 16-bit frames of silence
+     dur, frames, tcost = run_vad(frames, 16000)
+     print(f"warm up done, time_cost: {tcost:.3f} s")
+
+
+ warm_up()
+
+
+ @dataclass
+ class AppState:
+     stream: np.ndarray | None = None
+     sampling_rate: int = 0
+     pause_detected: bool = False
+     started_talking: bool = False
+     stopped: bool = False
+     conversation: list = field(default_factory=list)
+
+
+ def determine_pause(audio: np.ndarray, sampling_rate: int, state: AppState) -> bool:
+     """Take in the stream and determine whether a pause happened."""
+     dur_vad, _, time_vad = run_vad(audio, sampling_rate)
+     duration = len(audio) / sampling_rate
+
+     if dur_vad > 0.5 and not state.started_talking:
+         print("started talking")
+         state.started_talking = True
+         return False
+
+     print(f"duration_after_vad: {dur_vad:.3f} s, time_vad: {time_vad:.3f} s")
+
+     # A pause is more than one second of trailing non-speech in the buffer.
+     return (duration - dur_vad) > 1
+
+
+ def process_audio(audio: tuple, state: AppState):
+     if state.stream is None:
+         state.stream = audio[1]
+         state.sampling_rate = audio[0]
+     else:
+         state.stream = np.concatenate((state.stream, audio[1]))
+
+     pause_detected = determine_pause(state.stream, state.sampling_rate, state)
+     state.pause_detected = pause_detected
+
+     if state.pause_detected and state.started_talking:
+         return gr.Audio(recording=False), state
+     return None, state
+
+
+ def response(state: AppState):
+     if not state.pause_detected and not state.started_talking:
+         yield AppState(conversation=state.conversation), state.conversation
+         return
+
+     file_name = f"/tmp/{xxhash.xxh32(bytes(state.stream)).hexdigest()}.wav"
+     sf.write(file_name, state.stream, state.sampling_rate, format="wav")
+
+     state.conversation.append(
+         {"role": "user", "content": {"path": file_name, "mime_type": "audio/wav"}}
+     )
+
+     start = False
+     for resp in diva_audio((state.sampling_rate, state.stream)):
+         if not start:
+             state.conversation.append({"role": "assistant", "content": resp})
+             start = True
+         else:
+             state.conversation[-1]["content"] = resp
+         yield state, state.conversation
+
+     yield AppState(conversation=state.conversation), state.conversation
+
+
+ def start_recording_user(state: AppState):
+     if not state.stopped:
+         return gr.Audio(recording=True)
+
+
+ theme = gr.themes.Soft(
+     primary_hue=gr.themes.Color(
+         c50="#8200007f",
+         c100="#82000019",
+         c200="#82000033",
+         c300="#8200004c",
+         c400="#82000066",
+         c500="#8200007f",
+         c600="#82000099",
+         c700="#820000b2",
+         c800="#820000cc",
+         c900="#820000e5",
+         c950="#820000f2",
+     ),
+     secondary_hue="rose",
+     neutral_hue="stone",
+ )
+
+ with gr.Blocks(theme=theme) as demo:
+     with gr.Row():
+         with gr.Column():
+             input_audio = gr.Audio(
+                 label="Input Audio", sources="microphone", type="numpy"
+             )
+         with gr.Column():
+             chatbot = gr.Chatbot(label="Conversation", type="messages")
+     state = gr.State(value=AppState())
+
+     stream = input_audio.stream(
+         process_audio,
+         [input_audio, state],
+         [input_audio, state],
+         stream_every=0.50,
+         time_limit=30,
+     )
+     respond = input_audio.stop_recording(response, [state], [state, chatbot])
+     respond.then(start_recording_user, [state], [input_audio])
+
+     cancel = gr.Button("Stop Conversation", variant="stop")
+     cancel.click(
+         lambda: (AppState(stopped=True), gr.Audio(recording=False)),
+         None,
+         [state, input_audio],
+         cancels=[respond, stream],
+     )
+
+
+ demo.launch(share=True)
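
For reference, the turn-taking rule that `determine_pause` implements can be exercised offline. The sketch below is a minimal illustration and not part of this commit: it stubs the Silero VAD with a hypothetical amplitude-threshold helper (`fake_vad_duration` and `turn_ended` are invented names) and shows the heuristic in isolation, where a turn starts after more than 0.5 s of detected speech and ends once the buffered audio holds more than 1 s of non-speech.

```python
import numpy as np

SAMPLING_RATE = 16000


def fake_vad_duration(audio: np.ndarray) -> float:
    """Stand-in for run_vad's duration_after_vad: treat samples above a
    small amplitude threshold as speech."""
    return float(np.sum(np.abs(audio) > 0.01)) / SAMPLING_RATE


def turn_ended(stream: np.ndarray, started_talking: bool) -> tuple[bool, bool]:
    dur_vad = fake_vad_duration(stream)
    duration = len(stream) / SAMPLING_RATE
    if dur_vad > 0.5 and not started_talking:
        return False, True  # speech began; keep recording
    # Same rule as determine_pause: end the turn once the buffer
    # contains more than one second of non-speech.
    return started_talking and (duration - dur_vad) > 1, started_talking


speech = 0.1 * np.ones(2 * SAMPLING_RATE, dtype=np.float32)     # 2 s of "speech"
silence = np.zeros(int(1.5 * SAMPLING_RATE), dtype=np.float32)  # 1.5 s of silence

stream, talking = speech, False
ended, talking = turn_ended(stream, talking)   # False: the user started talking
stream = np.concatenate((stream, silence))
ended, talking = turn_ended(stream, talking)   # True: >1 s trailing pause
print(ended)
```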
data/samples/output1.wav ADDED
Binary file (62.2 kB)
data/samples/output2.wav ADDED
Binary file (105 kB)
data/samples/output3.wav ADDED
Binary file (70.4 kB)
data/samples/output4.wav ADDED
Binary file (67.6 kB)
data/samples/output5.wav ADDED
Binary file (115 kB)
requirements.txt ADDED
@@ -0,0 +1,18 @@
+ transformers==4.43.3
+ gradio==5.0.1
+ spaces
+ accelerate
+ peft
+ xxhash
+ datasets
+ transformers_stream_generator
+ einops
+ sentencepiece
+ tiktoken
+ torch==2.3.1
+ torchvision==0.18.1
+ torchaudio==2.3.1
+ soundfile==0.12.1
+ tokenizers==0.19.1
+ librosa==0.10.2.post1
+ onnxruntime==1.19.0