Update app.py
Browse files
app.py
CHANGED
@@ -23,10 +23,8 @@ decode_cfg.beam.beam_size = 1
|
|
23 |
canary_model.change_decoding_strategy(decode_cfg)
|
24 |
|
25 |
# load TTS model
|
26 |
-
tts_model = VitsModel.from_pretrained("facebook/mms-tts-eng")
|
27 |
-
tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-eng")
|
28 |
-
tts_fra_model = VitsModel.from_pretrained("facebook/mms-tts-fra")
|
29 |
-
tts_fra_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-fra")
|
30 |
|
31 |
# Function to convert audio to text using ASR
|
32 |
def gen_text(audio_filepath, action):
|
@@ -56,38 +54,49 @@ def gen_text(audio_filepath, action):
|
|
56 |
manifest_filepath = os.path.join(tmpdir, f"{utt_id}.json")
|
57 |
with open(manifest_filepath, 'w') as fout:
|
58 |
fout.write(json.dumps(manifest_data))
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
|
|
70 |
|
71 |
return predicted_text
|
72 |
|
73 |
# Function to convert text to speech using TTS
|
74 |
-
def gen_speech(text):
|
75 |
set_seed(555) # Make it deterministic
|
76 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
77 |
with torch.no_grad():
|
78 |
-
outputs =
|
79 |
waveform_np = outputs.waveform[0].cpu().numpy()
|
80 |
output_file = f"{str(uuid.uuid4())}.wav"
|
81 |
wav.write(output_file, rate=tts_model.config.sampling_rate, data=waveform_np)
|
82 |
return output_file
|
83 |
|
84 |
# Root function for Gradio interface
|
85 |
-
def start_process(audio_filepath):
|
86 |
transcription = gen_text(audio_filepath, "asr")
|
87 |
print("Done transcribing")
|
88 |
translation = gen_text(audio_filepath, "s2t_translation")
|
89 |
print("Done translation")
|
90 |
-
audio_output_filepath = gen_speech(transcription)
|
91 |
print("Done speaking")
|
92 |
return transcription, translation, audio_output_filepath
|
93 |
|
@@ -123,7 +132,7 @@ with playground:
|
|
123 |
with gr.Column():
|
124 |
submit_button = gr.Button(value="Start Process", variant="primary")
|
125 |
with gr.Column():
|
126 |
-
clear_button = gr.ClearButton(components=[input_audio,
|
127 |
|
128 |
# with gr.Row():
|
129 |
# gr.Examples(
|
@@ -133,6 +142,6 @@ with playground:
|
|
133 |
# run_on_click=True, cache_examples=True, fn=start_process
|
134 |
# )
|
135 |
|
136 |
-
submit_button.click(start_process, inputs=[input_audio], outputs=[transcipted_text, translated_text, translated_speech])
|
137 |
|
138 |
playground.launch()
|
|
|
23 |
canary_model.change_decoding_strategy(decode_cfg)
|
24 |
|
25 |
# load TTS model
|
26 |
+
# tts_model = VitsModel.from_pretrained("facebook/mms-tts-eng")
|
27 |
+
# tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-eng")
|
|
|
|
|
28 |
|
29 |
# Function to convert audio to text using ASR
|
30 |
def gen_text(audio_filepath, action):
|
|
|
54 |
manifest_filepath = os.path.join(tmpdir, f"{utt_id}.json")
|
55 |
with open(manifest_filepath, 'w') as fout:
|
56 |
fout.write(json.dumps(manifest_data))
|
57 |
+
|
58 |
+
predicted_text = canary_model.transcribe(manifest_filepath)[0]
|
59 |
+
# if duration < 40:
|
60 |
+
# predicted_text = canary_model.transcribe(manifest_filepath)[0]
|
61 |
+
# else:
|
62 |
+
# predicted_text = get_buffered_pred_feat_multitaskAED(
|
63 |
+
# frame_asr,
|
64 |
+
# canary_model.cfg.preprocessor,
|
65 |
+
# model_stride_in_secs,
|
66 |
+
# canary_model.device,
|
67 |
+
# manifest=manifest_filepath,
|
68 |
+
# )[0].text
|
69 |
|
70 |
return predicted_text
|
71 |
|
72 |
# Function to convert text to speech using TTS
|
73 |
+
def gen_speech(text, lang):
|
74 |
set_seed(555) # Make it deterministic
|
75 |
+
|
76 |
+
if lang=="en":
|
77 |
+
model = "facebook/mms-tts-eng"
|
78 |
+
elif lang=="fr":
|
79 |
+
model = "facebook/mms-tts-fra"
|
80 |
+
|
81 |
+
# load TTS model
|
82 |
+
tts_model = VitsModel.from_pretrained(model)
|
83 |
+
tts_tokenizer = AutoTokenizer.from_pretrained(model)
|
84 |
+
|
85 |
+
input_text = tts_tokenizer(text, return_tensors="pt")
|
86 |
with torch.no_grad():
|
87 |
+
outputs = tts_model(**input_text)
|
88 |
waveform_np = outputs.waveform[0].cpu().numpy()
|
89 |
output_file = f"{str(uuid.uuid4())}.wav"
|
90 |
wav.write(output_file, rate=tts_model.config.sampling_rate, data=waveform_np)
|
91 |
return output_file
|
92 |
|
93 |
# Root function for Gradio interface
|
94 |
+
def start_process(audio_filepath, source_lang, target_lang):
|
95 |
transcription = gen_text(audio_filepath, "asr")
|
96 |
print("Done transcribing")
|
97 |
translation = gen_text(audio_filepath, "s2t_translation")
|
98 |
print("Done translation")
|
99 |
+
audio_output_filepath = gen_speech(transcription, target_lang)
|
100 |
print("Done speaking")
|
101 |
return transcription, translation, audio_output_filepath
|
102 |
|
|
|
132 |
with gr.Column():
|
133 |
submit_button = gr.Button(value="Start Process", variant="primary")
|
134 |
with gr.Column():
|
135 |
+
clear_button = gr.ClearButton(components=[input_audio, source_lang, target_lang, transcipted_text, translated_text, translated_speech], value="Clear")
|
136 |
|
137 |
# with gr.Row():
|
138 |
# gr.Examples(
|
|
|
142 |
# run_on_click=True, cache_examples=True, fn=start_process
|
143 |
# )
|
144 |
|
145 |
+
submit_button.click(start_process, inputs=[input_audio, source_lang, target_lang], outputs=[transcipted_text, translated_text, translated_speech])
|
146 |
|
147 |
playground.launch()
|