rayl-aoit committed on
Commit
1180d04
1 Parent(s): 79b79bf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -22
app.py CHANGED
@@ -23,10 +23,8 @@ decode_cfg.beam.beam_size = 1
23
  canary_model.change_decoding_strategy(decode_cfg)
24
 
25
  # load TTS model
26
- tts_model = VitsModel.from_pretrained("facebook/mms-tts-eng")
27
- tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-eng")
28
- tts_fra_model = VitsModel.from_pretrained("facebook/mms-tts-fra")
29
- tts_fra_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-fra")
30
 
31
  # Function to convert audio to text using ASR
32
  def gen_text(audio_filepath, action):
@@ -56,38 +54,49 @@ def gen_text(audio_filepath, action):
56
  manifest_filepath = os.path.join(tmpdir, f"{utt_id}.json")
57
  with open(manifest_filepath, 'w') as fout:
58
  fout.write(json.dumps(manifest_data))
59
-
60
- if duration < 40:
61
- predicted_text = canary_model.transcribe(manifest_filepath)[0]
62
- else:
63
- predicted_text = get_buffered_pred_feat_multitaskAED(
64
- frame_asr,
65
- canary_model.cfg.preprocessor,
66
- model_stride_in_secs,
67
- canary_model.device,
68
- manifest=manifest_filepath,
69
- )[0].text
 
70
 
71
  return predicted_text
72
 
73
  # Function to convert text to speech using TTS
74
- def gen_speech(text):
75
  set_seed(555) # Make it deterministic
76
- input_text = tts_fra_tokenizer(text, return_tensors="pt")
 
 
 
 
 
 
 
 
 
 
77
  with torch.no_grad():
78
- outputs = tts_fra_model(**input_text)
79
  waveform_np = outputs.waveform[0].cpu().numpy()
80
  output_file = f"{str(uuid.uuid4())}.wav"
81
  wav.write(output_file, rate=tts_model.config.sampling_rate, data=waveform_np)
82
  return output_file
83
 
84
  # Root function for Gradio interface
85
- def start_process(audio_filepath):
86
  transcription = gen_text(audio_filepath, "asr")
87
  print("Done transcribing")
88
  translation = gen_text(audio_filepath, "s2t_translation")
89
  print("Done translation")
90
- audio_output_filepath = gen_speech(transcription)
91
  print("Done speaking")
92
  return transcription, translation, audio_output_filepath
93
 
@@ -123,7 +132,7 @@ with playground:
123
  with gr.Column():
124
  submit_button = gr.Button(value="Start Process", variant="primary")
125
  with gr.Column():
126
- clear_button = gr.ClearButton(components=[input_audio, transcipted_text, translated_speech, translated_text, source_lang, target_lang], value="Clear")
127
 
128
  # with gr.Row():
129
  # gr.Examples(
@@ -133,6 +142,6 @@ with playground:
133
  # run_on_click=True, cache_examples=True, fn=start_process
134
  # )
135
 
136
- submit_button.click(start_process, inputs=[input_audio], outputs=[transcipted_text, translated_text, translated_speech])
137
 
138
  playground.launch()
 
23
  canary_model.change_decoding_strategy(decode_cfg)
24
 
25
  # load TTS model
26
+ # tts_model = VitsModel.from_pretrained("facebook/mms-tts-eng")
27
+ # tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-eng")
 
 
28
 
29
  # Function to convert audio to text using ASR
30
  def gen_text(audio_filepath, action):
 
54
  manifest_filepath = os.path.join(tmpdir, f"{utt_id}.json")
55
  with open(manifest_filepath, 'w') as fout:
56
  fout.write(json.dumps(manifest_data))
57
+
58
+ predicted_text = canary_model.transcribe(manifest_filepath)[0]
59
+ # if duration < 40:
60
+ # predicted_text = canary_model.transcribe(manifest_filepath)[0]
61
+ # else:
62
+ # predicted_text = get_buffered_pred_feat_multitaskAED(
63
+ # frame_asr,
64
+ # canary_model.cfg.preprocessor,
65
+ # model_stride_in_secs,
66
+ # canary_model.device,
67
+ # manifest=manifest_filepath,
68
+ # )[0].text
69
 
70
  return predicted_text
71
 
72
  # Function to convert text to speech using TTS
73
+ def gen_speech(text, lang):
74
  set_seed(555) # Make it deterministic
75
+
76
+ if lang=="en":
77
+ model = "facebook/mms-tts-eng"
78
+ elif lang=="fr":
79
+ model = "facebook/mms-tts-fra"
80
+
81
+ # load TTS model
82
+ tts_model = VitsModel.from_pretrained(model)
83
+ tts_tokenizer = AutoTokenizer.from_pretrained(model)
84
+
85
+ input_text = tts_tokenizer(text, return_tensors="pt")
86
  with torch.no_grad():
87
+ outputs = tts_model(**input_text)
88
  waveform_np = outputs.waveform[0].cpu().numpy()
89
  output_file = f"{str(uuid.uuid4())}.wav"
90
  wav.write(output_file, rate=tts_model.config.sampling_rate, data=waveform_np)
91
  return output_file
92
 
93
  # Root function for Gradio interface
94
+ def start_process(audio_filepath, source_lang, target_lang):
95
  transcription = gen_text(audio_filepath, "asr")
96
  print("Done transcribing")
97
  translation = gen_text(audio_filepath, "s2t_translation")
98
  print("Done translation")
99
+ audio_output_filepath = gen_speech(transcription, target_lang)
100
  print("Done speaking")
101
  return transcription, translation, audio_output_filepath
102
 
 
132
  with gr.Column():
133
  submit_button = gr.Button(value="Start Process", variant="primary")
134
  with gr.Column():
135
+ clear_button = gr.ClearButton(components=[input_audio, source_lang, target_lang, transcipted_text, translated_text, translated_speech], value="Clear")
136
 
137
  # with gr.Row():
138
  # gr.Examples(
 
142
  # run_on_click=True, cache_examples=True, fn=start_process
143
  # )
144
 
145
+ submit_button.click(start_process, inputs=[input_audio, source_lang, target_lang], outputs=[transcipted_text, translated_text, translated_speech])
146
 
147
  playground.launch()