csukuangfj committed
Commit 2e7292b · Parent: a73e88b

first commit

Files changed (3)
  1. README.md +7 -2
  2. app.py +297 -0
  3. model.py +114 -0
README.md CHANGED
@@ -4,10 +4,15 @@ emoji: 📈
  colorFrom: blue
  colorTo: purple
  sdk: gradio
- sdk_version: 4.26.0
+ sdk_version: 4.14.0
+ python_version: 3.8.9
  app_file: app.py
  pinned: false
  license: apache-2.0
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ Please see
+
+ https://k2-fsa.github.io/sherpa/onnx/audio-tagging/index.html
+
+ for more information.
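For local experimentation along the lines of that page, here is a minimal sketch that mirrors the model setup in model.py from this commit (the repo and file names are the ones used below; the sherpa-onnx calls are the same ones model.py relies on):

```python
# Minimal local sketch, mirroring model.py from this commit.
import sherpa_onnx
from huggingface_hub import hf_hub_download

repo_id = "k2-fsa/sherpa-onnx-zipformer-audio-tagging-2024-04-09"
model = hf_hub_download(repo_id, "model.int8.onnx")
labels = hf_hub_download(repo_id, "class_labels_indices.csv")

config = sherpa_onnx.AudioTaggingConfig(
    model=sherpa_onnx.AudioTaggingModelConfig(
        zipformer=sherpa_onnx.OfflineZipformerAudioTaggingModelConfig(model=model),
        num_threads=1,
        provider="cpu",
    ),
    labels=labels,
    top_k=5,
)
tagger = sherpa_onnx.AudioTagging(config)
```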
app.py ADDED
@@ -0,0 +1,297 @@
+ #!/usr/bin/env python3
+ #
+ # Copyright 2022-2024 Xiaomi Corp. (authors: Fangjun Kuang)
+ #
+ # See LICENSE for clarification regarding multiple authors
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ # References:
+ # https://gradio.app/docs/#dropdown
+
+ import logging
+ import os
+ import tempfile
+ import time
+ import urllib.request
+ from datetime import datetime
+
+ import gradio as gr
+ import soundfile as sf
+
+ from examples import examples
+ from model import decode, get_pretrained_model, models
+
+
+ def convert_to_wav(in_filename: str) -> str:
+     """Convert the input audio file to a wave file"""
+     out_filename = in_filename + ".wav"
+     logging.info(f"Converting '{in_filename}' to '{out_filename}'")
+
+     # -ar 16000 -ac 1: resample to 16 kHz mono; -y: overwrite the output file
+     _ = os.system(
+         f"ffmpeg -hide_banner -i '{in_filename}' -ar 16000 -ac 1 '{out_filename}' -y"
+     )
+
+     return out_filename
+
+
+ def build_html_output(s: str, style: str = "result_item_success"):
+     return f"""
+     <div class='result'>
+         <div class='result_item {style}'>
+         {s}
+         </div>
+     </div>
+     """
+
+
+ def process_url(
+     repo_id: str,
+     url: str,
+ ):
+     logging.info(f"Processing URL: {url}")
+     with tempfile.NamedTemporaryFile() as f:
+         try:
+             urllib.request.urlretrieve(url, f.name)
+
+             return process(
+                 in_filename=f.name,
+                 repo_id=repo_id,
+             )
+         except Exception as e:
+             logging.info(str(e))
+             return "", build_html_output(str(e), "result_item_error")
+
+
+ def process_uploaded_file(
+     repo_id: str,
+     in_filename: str,
+ ):
+     if in_filename is None or in_filename == "":
+         return "", build_html_output(
+             "Please first upload a file and then click "
+             'the button "Submit for recognition"',
+             "result_item_error",
+         )
+
+     logging.info(f"Processing uploaded file: {in_filename}")
+     try:
+         return process(
+             in_filename=in_filename,
+             repo_id=repo_id,
+         )
+     except Exception as e:
+         logging.info(str(e))
+         return "", build_html_output(str(e), "result_item_error")
+
+
+ def process_microphone(
+     repo_id: str,
+     in_filename: str,
+ ):
+     if in_filename is None or in_filename == "":
+         return "", build_html_output(
+             "Please first click 'Record from microphone', speak, "
+             "click 'Stop recording', and then "
+             "click the button 'Submit for recognition'",
+             "result_item_error",
+         )
+
+     logging.info(f"Processing microphone: {in_filename}")
+     try:
+         return process(
+             in_filename=in_filename,
+             repo_id=repo_id,
+         )
+     except Exception as e:
+         logging.info(str(e))
+         return "", build_html_output(str(e), "result_item_error")
+
+
+ def process(
+     repo_id: str,
+     in_filename: str,
+ ):
+     logging.info(f"repo_id: {repo_id}")
+     logging.info(f"in_filename: {in_filename}")
+
+     filename = convert_to_wav(in_filename)
+
+     now = datetime.now()
+     date_time = now.strftime("%Y-%m-%d %H:%M:%S.%f")
+     logging.info(f"Started at {date_time}")
+
+     start = time.time()
+
+     tagger = get_pretrained_model(repo_id)
+
+     events = decode(tagger, filename)
+
+     date_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")
+     end = time.time()
+
+     info = sf.info(filename)
+     duration = info.duration
+
+     elapsed = end - start
+     # RTF (real-time factor) = processing time / audio duration;
+     # values below 1 mean faster than real time.
+     rtf = elapsed / duration
+
+     logging.info(f"Finished at {date_time}. Elapsed: {elapsed:.3f} s")
+
+     info = f"""
+     Wave duration  : {duration:.3f} s <br/>
+     Processing time: {elapsed:.3f} s <br/>
+     RTF: {elapsed:.3f}/{duration:.3f} = {rtf:.3f} <br/>
+     """
+     if rtf > 1:
+         info += (
+             "<br/>We are loading the model for the first run. "
+             "Please run again to measure the real RTF.<br/>"
+         )
+
+     logging.info(info)
+     logging.info(f"\nrepo_id: {repo_id}\nDetected events: {events}")
+
+     # Convert the decoded events into the dict format that gr.Dataframe
+     # expects; each sherpa_onnx.AudioEvent exposes .name and .prob.
+     events = {
+         "headers": ["Event name", "Probability"],
+         "data": [[e.name, f"{e.prob:.3f}"] for e in events],
+     }
+
+     return events, build_html_output(info)
+
+
+ title = "# Audio tagging with [Next-gen Kaldi](https://github.com/k2-fsa)"
+ description = """
+ This space shows how to do audio tagging with [Next-gen Kaldi](https://github.com/k2-fsa).
+
+ It is running on a machine with 2 vCPUs and 16 GB RAM within a docker container provided by Hugging Face.
+
+ See more information by visiting the following links:
+
+ - <https://github.com/k2-fsa/sherpa-onnx>
+
+ If you want to deploy it locally, please see
+ <https://k2-fsa.github.io/sherpa/onnx>
+ """
+
+ # css style is copied from
+ # https://huggingface.co/spaces/alphacep/asr/blob/main/app.py#L113
+ css = """
+ .result {display:flex;flex-direction:column}
+ .result_item {padding:15px;margin-bottom:8px;border-radius:15px;width:100%}
+ .result_item_success {background-color:mediumaquamarine;color:white;align-self:start}
+ .result_item_error {background-color:#ff7070;color:white;align-self:start}
+ """
+
+
+ demo = gr.Blocks(css=css)
+
+
+ with demo:
+     gr.Markdown(title)
+     model_choices = list(models.keys())
+
+     model_dropdown = gr.Dropdown(
+         choices=model_choices,
+         label="Select a model",
+         value=model_choices[0],
+     )
+
+     with gr.Tabs():
+         with gr.TabItem("Upload from disk"):
+             uploaded_file = gr.Audio(
+                 sources=["upload"],  # Choose between "microphone", "upload"
+                 type="filepath",
+                 label="Upload from disk",
+             )
+             upload_button = gr.Button("Submit for recognition")
+             uploaded_output = gr.Dataframe(label="Detected events")
+             uploaded_html_info = gr.HTML(label="Info")
+
+             gr.Examples(
+                 examples=examples,
+                 inputs=[
+                     model_dropdown,
+                     uploaded_file,
+                 ],
+                 outputs=[uploaded_output, uploaded_html_info],
+                 fn=process_uploaded_file,
+             )
+
+         with gr.TabItem("Record from microphone"):
+             microphone = gr.Audio(
+                 sources=["microphone"],  # Choose between "microphone", "upload"
+                 type="filepath",
+                 label="Record from microphone",
+             )
+
+             record_button = gr.Button("Submit for recognition")
+             recorded_output = gr.Dataframe(label="Detected events")
+             recorded_html_info = gr.HTML(label="Info")
+
+             gr.Examples(
+                 examples=examples,
+                 inputs=[
+                     model_dropdown,
+                     microphone,
+                 ],
+                 outputs=[recorded_output, recorded_html_info],
+                 fn=process_microphone,
+             )
+
+         with gr.TabItem("From URL"):
+             url_textbox = gr.Textbox(
+                 max_lines=1,
+                 placeholder="URL to an audio file",
+                 label="URL",
+                 interactive=True,
+             )
+
+             url_button = gr.Button("Submit for recognition")
+             url_output = gr.Dataframe(label="Detected events")
+             url_html_info = gr.HTML(label="Info")
+
+     upload_button.click(
+         process_uploaded_file,
+         inputs=[
+             model_dropdown,
+             uploaded_file,
+         ],
+         outputs=[uploaded_output, uploaded_html_info],
+     )
+
+     record_button.click(
+         process_microphone,
+         inputs=[
+             model_dropdown,
+             microphone,
+         ],
+         outputs=[recorded_output, recorded_html_info],
+     )
+
+     url_button.click(
+         process_url,
+         inputs=[
+             model_dropdown,
+             url_textbox,
+         ],
+         outputs=[url_output, url_html_info],
+     )
+
+     gr.Markdown(description)
+
+ if __name__ == "__main__":
+     formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+     logging.basicConfig(format=formatter, level=logging.INFO)
+
+     demo.launch()
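For quick checks outside the browser, process() can also be called directly. A sketch, assuming ffmpeg is on PATH, the examples module that app.py imports is present, and a local test.wav exists (all three are assumptions, not part of this commit):

```python
# Headless smoke test for app.process(); "test.wav" is a hypothetical file.
from app import process

events, html = process(
    repo_id="k2-fsa/sherpa-onnx-zipformer-audio-tagging-2024-04-09",
    in_filename="test.wav",
)
print(events["data"])  # [[event name, probability], ...]
```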
model.py ADDED
@@ -0,0 +1,114 @@
+ # Copyright 2022-2024 Xiaomi Corp. (authors: Fangjun Kuang)
+ #
+ # See LICENSE for clarification regarding multiple authors
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import wave
+ from functools import lru_cache
+ from typing import List, Tuple
+
+ import numpy as np
+ import sherpa_onnx
+ from huggingface_hub import hf_hub_download
+
+ sample_rate = 16000
+
+
+ def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]:
+     """
+     Args:
+       wave_filename:
+         Path to a wave file. It should be single channel and each sample
+         should be 16-bit. Its sample rate does not need to be 16 kHz.
+     Returns:
+       Return a tuple containing:
+        - A 1-D array of dtype np.float32 containing the samples, which are
+          normalized to the range [-1, 1].
+        - The sample rate of the wave file.
+     """
+
+     with wave.open(wave_filename) as f:
+         assert f.getnchannels() == 1, f.getnchannels()
+         assert f.getsampwidth() == 2, f.getsampwidth()  # it is in bytes
+         num_samples = f.getnframes()
+         samples = f.readframes(num_samples)
+         samples_int16 = np.frombuffer(samples, dtype=np.int16)
+         samples_float32 = samples_int16.astype(np.float32)
+
+         samples_float32 = samples_float32 / 32768
+         return samples_float32, f.getframerate()
+
+
+ def decode(
+     tagger: sherpa_onnx.AudioTagging,
+     filename: str,
+     top_k: int = -1,
+ ) -> List[sherpa_onnx.AudioEvent]:
+     s = tagger.create_stream()
+     samples, sample_rate = read_wave(filename)
+     s.accept_waveform(sample_rate, samples)
+     events = tagger.compute(s, top_k)
+     return events
+
+
+ def _get_nn_model_filename(
+     repo_id: str,
+     filename: str,
+     subfolder: str = ".",
+ ) -> str:
+     nn_model_filename = hf_hub_download(
+         repo_id=repo_id,
+         filename=filename,
+         subfolder=subfolder,
+     )
+     return nn_model_filename
+
+
+ # Cache the constructed taggers so each model is downloaded and
+ # initialized at most once per process.
+ @lru_cache(maxsize=8)
+ def get_pretrained_model(repo_id: str) -> sherpa_onnx.AudioTagging:
+     assert repo_id in (
+         "k2-fsa/sherpa-onnx-zipformer-small-audio-tagging-2024-04-15",
+         "k2-fsa/sherpa-onnx-zipformer-audio-tagging-2024-04-09",
+     ), repo_id
+
+     model = _get_nn_model_filename(
+         repo_id=repo_id,
+         filename="model.int8.onnx",
+     )
+
+     labels = _get_nn_model_filename(
+         repo_id=repo_id,
+         filename="class_labels_indices.csv",
+     )
+
+     config = sherpa_onnx.AudioTaggingConfig(
+         model=sherpa_onnx.AudioTaggingModelConfig(
+             zipformer=sherpa_onnx.OfflineZipformerAudioTaggingModelConfig(
+                 model=model,
+             ),
+             num_threads=1,
+             debug=True,
+             provider="cpu",
+         ),
+         labels=labels,
+         top_k=5,
+     )
+     return sherpa_onnx.AudioTagging(config)
+
+
+ models = {
+     "k2-fsa/sherpa-onnx-zipformer-audio-tagging-2024-04-09": get_pretrained_model,
+     "k2-fsa/sherpa-onnx-zipformer-small-audio-tagging-2024-04-15": get_pretrained_model,
+ }
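model.py exposes two entry points that compose directly. A short usage sketch (test.wav is a stand-in for any 16-bit mono wave file; the .name/.prob fields are the ones app.py reads from each returned AudioEvent):

```python
# Usage sketch for model.py; "test.wav" is a hypothetical local file.
from model import decode, get_pretrained_model

tagger = get_pretrained_model(
    "k2-fsa/sherpa-onnx-zipformer-audio-tagging-2024-04-09"
)
for event in decode(tagger, "test.wav"):
    print(event.name, event.prob)
```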