csukuangfj commited on
Commit
09d9587
·
1 Parent(s): 6b31279

small fixes

Browse files
Files changed (2) hide show
  1. app.py +19 -6
  2. model.py +25 -0
app.py CHANGED
@@ -16,6 +16,9 @@
16
  # See the License for the specific language governing permissions and
17
  # limitations under the License.
18
 
 
 
 
19
  import os
20
  import time
21
  from datetime import datetime
@@ -23,9 +26,16 @@ from datetime import datetime
23
  import gradio as gr
24
  import torchaudio
25
 
26
- from model import get_gigaspeech_pre_trained_model, sample_rate
 
 
 
 
27
 
28
- models = {"english": get_gigaspeech_pre_trained_model()}
 
 
 
29
 
30
 
31
  def convert_to_wav(in_filename: str) -> str:
@@ -39,8 +49,9 @@ def convert_to_wav(in_filename: str) -> str:
39
  demo = gr.Blocks()
40
 
41
 
42
- def process(in_filename: str) -> str:
43
  print("in_filename", in_filename)
 
44
  filename = convert_to_wav(in_filename)
45
 
46
  now = datetime.now()
@@ -63,7 +74,7 @@ def process(in_filename: str) -> str:
63
  )
64
  wave = wave[0] # use only the first channel.
65
 
66
- hyp = models["english"].decode_waves([wave])[0]
67
 
68
  date_time = now.strftime("%Y-%m-%d %H:%M:%S.%f")
69
  end = time.time()
@@ -82,6 +93,8 @@ def process(in_filename: str) -> str:
82
 
83
  with demo:
84
  gr.Markdown("Upload audio from disk or record from microphone for recognition")
 
 
85
  with gr.Tabs():
86
  with gr.TabItem("Upload from disk"):
87
  uploaded_file = gr.inputs.Audio(
@@ -110,12 +123,12 @@ with demo:
110
 
111
  upload_button.click(
112
  process,
113
- inputs=uploaded_file,
114
  outputs=uploaded_output,
115
  )
116
  record_button.click(
117
  process,
118
- inputs=microphone,
119
  outputs=recorded_output,
120
  )
121
 
 
16
  # See the License for the specific language governing permissions and
17
  # limitations under the License.
18
 
19
+ # References:
20
+ # https://gradio.app/docs/#dropdown
21
+
22
  import os
23
  import time
24
  from datetime import datetime
 
26
  import gradio as gr
27
  import torchaudio
28
 
29
+ from model import (
30
+ get_gigaspeech_pre_trained_model,
31
+ sample_rate,
32
+ get_wenetspeech_pre_trained_model,
33
+ )
34
 
35
+ models = {
36
+ "Chinese": get_wenetspeech_pre_trained_model(),
37
+ "English": get_gigaspeech_pre_trained_model(),
38
+ }
39
 
40
 
41
  def convert_to_wav(in_filename: str) -> str:
 
49
  demo = gr.Blocks()
50
 
51
 
52
+ def process(in_filename: str, language: str) -> str:
53
  print("in_filename", in_filename)
54
+ print("language", language)
55
  filename = convert_to_wav(in_filename)
56
 
57
  now = datetime.now()
 
74
  )
75
  wave = wave[0] # use only the first channel.
76
 
77
+ hyp = models[language].decode_waves([wave])[0]
78
 
79
  date_time = now.strftime("%Y-%m-%d %H:%M:%S.%f")
80
  end = time.time()
 
93
 
94
  with demo:
95
  gr.Markdown("Upload audio from disk or record from microphone for recognition")
96
+ languages = gr.inputs.Radio(label="Language", choices=list(models.keys()))
97
+
98
  with gr.Tabs():
99
  with gr.TabItem("Upload from disk"):
100
  uploaded_file = gr.inputs.Audio(
 
123
 
124
  upload_button.click(
125
  process,
126
+ inputs=[uploaded_file, language],
127
  outputs=uploaded_output,
128
  )
129
  record_button.click(
130
  process,
131
+ inputs=[microphone, language],
132
  outputs=recorded_output,
133
  )
134
 
model.py CHANGED
@@ -47,3 +47,28 @@ def get_gigaspeech_pre_trained_model():
47
  sample_rate=sample_rate,
48
  device="cpu",
49
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  sample_rate=sample_rate,
48
  device="cpu",
49
  )
50
+
51
+
52
+ @lru_cache(maxsize=1)
53
+ def get_wenetspeech_pre_trained_model():
54
+ nn_model_filename = hf_hub_download(
55
+ repo_id="luomingshuang/icefall_asr_wenetspeech_pruned_transducer_stateless2",
56
+ filename="cpu_jit_epoch_10_avg_2_torch_1.7.1.pt",
57
+ subfolder="exp",
58
+ )
59
+
60
+ token_filename = hf_hub_download(
61
+ repo_id="luomingshuang/icefall_asr_wenetspeech_pruned_transducer_stateless2",
62
+ filename="tokens.txt",
63
+ subfolder="data/lang_char",
64
+ )
65
+
66
+ return OfflineAsr(
67
+ nn_model_filename=nn_model_filename,
68
+ bpe_model_filename=None,
69
+ token_filename=token_filename,
70
+ decoding_method="greedy_search",
71
+ num_active_paths=4,
72
+ sample_rate=sample_rate,
73
+ device="cpu",
74
+ )