rayl-aoit committed on
Commit
291564a
1 Parent(s): be5dfc2

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +110 -0
app.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import uuid
4
+ import json
5
+ import librosa
6
+ import os
7
+ import tempfile
8
+ import soundfile as sf
9
+ import scipy.io.wavfile as wav
10
+
11
+ from transformers import pipeline, VitsModel, AutoTokenizer, set_seed
12
+ from nemo.collections.asr.models import EncDecMultiTaskModel
13
+
14
# --- Constants ---
SAMPLE_RATE = 16000  # Hz; the rate the Canary ASR model expects

# --- Load the ASR model (NVIDIA Canary 1B, multi-task speech model) ---
canary_model = EncDecMultiTaskModel.from_pretrained('nvidia/canary-1b')

# Update decode params: greedy decoding (beam size 1) for speed.
decode_cfg = canary_model.cfg.decoding
decode_cfg.beam.beam_size = 1
canary_model.change_decoding_strategy(decode_cfg)

# --- Load the TTS model (Meta MMS English VITS) and its tokenizer ---
tts_model = VitsModel.from_pretrained("facebook/mms-tts-eng")
tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-eng")
28
+
29
+ # Function to convert audio to text using ASR
30
+ def transcribe(audio_filepath):
31
+ if audio_filepath is None:
32
+ raise gr.Error("Please provide some input audio.")
33
+
34
+ utt_id = uuid.uuid4()
35
+ with tempfile.TemporaryDirectory() as tmpdir:
36
+ # Convert to 16 kHz
37
+ data, sr = librosa.load(audio_filepath, sr=None, mono=True)
38
+ if sr != SAMPLE_RATE:
39
+ data = librosa.resample(data, orig_sr=sr, target_sr=SAMPLE_RATE)
40
+ converted_audio_filepath = os.path.join(tmpdir, f"{utt_id}.wav")
41
+ sf.write(converted_audio_filepath, data, SAMPLE_RATE)
42
+
43
+ # Transcribe audio
44
+ duration = len(data) / SAMPLE_RATE
45
+ manifest_data = {
46
+ "audio_filepath": converted_audio_filepath,
47
+ "taskname": "asr",
48
+ "source_lang": "en",
49
+ "target_lang": "en",
50
+ "pnc": "no",
51
+ "answer": "predict",
52
+ "duration": str(duration),
53
+ }
54
+ manifest_filepath = os.path.join(tmpdir, f"{utt_id}.json")
55
+ with open(manifest_filepath, 'w') as fout:
56
+ fout.write(json.dumps(manifest_data))
57
+
58
+ if duration < 40:
59
+ transcription = canary_model.transcribe(manifest_filepath)[0]
60
+ else:
61
+ transcription = get_buffered_pred_feat_multitaskAED(
62
+ frame_asr,
63
+ canary_model.cfg.preprocessor,
64
+ model_stride_in_secs,
65
+ canary_model.device,
66
+ manifest=manifest_filepath,
67
+ )[0].text
68
+
69
+ return transcription
70
+
71
+ # Function to convert text to speech using TTS
72
+ def gen_speech(text):
73
+ set_seed(555) # Make it deterministic
74
+ input_text = tts_tokenizer(text, return_tensors="pt")
75
+ with torch.no_grad():
76
+ outputs = tts_model(**input_text)
77
+ waveform_np = outputs.waveform[0].cpu().numpy()
78
+ output_file = f"{str(uuid.uuid4())}.wav"
79
+ wav.write(output_file, rate=tts_model.config.sampling_rate, data=waveform_np)
80
+ return output_file
81
+
82
+ # Root function for Gradio interface
83
+ def start_process(audio_filepath):
84
+ transcription = transcribe(audio_filepath)
85
+ print("Done transcribing")
86
+ translation = "working in progress"
87
+ audio_output_filepath = gen_speech(transcription)
88
+ print("Done speaking")
89
+ return transcription, translation, audio_output_filepath
90
+
91
+
92
+ # Create Gradio interface
93
+ playground = gr.Blocks()
94
+
95
+ with playground:
96
+ with gr.Row():
97
+ with gr.Column():
98
+ input_audio = gr.Audio(sources=["microphone"], type="filepath", label="Input Audio")
99
+ transcipted_text = gr.Textbox(label="Transcription")
100
+ with gr.Column():
101
+ translated_speech = gr.Audio(type="filepath", label="Generated Speech")
102
+ translated_text = gr.Textbox(label="Translation")
103
+
104
+ with gr.Row():
105
+ with gr.Column():
106
+ submit_button = gr.Button(value="Start Process", variant="primary")
107
+ with gr.Column():
108
+ clear_button = gr.ClearButton(components=[input_audio, transcipted_text, translated_speech, translated_text], value="Clear")
109
+
110
+ submit_button.click(start_process, inputs=[input_audio], outputs=[transcipted_text, translated_speech, translated_text])