Porjaz's picture
Update app.py
403b777 verified
raw
history blame
7.11 kB
import spaces
import os
# os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import gc
from functools import partial
import gradio as gr
import torch
from speechbrain.inference.interfaces import Pretrained, foreign_class
from transformers import T5Tokenizer, T5ForConditionalGeneration
import librosa
import whisper_timestamped as whisper
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline, Wav2Vec2ForCTC, AutoProcessor
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.backends.cuda.matmul.allow_tf32 = True
def clean_up_memory():
gc.collect()
torch.cuda.empty_cache()
@spaces.GPU(duration=30)
def return_prediction_whisper_mic(mic=None, progress=gr.Progress(), device=device):
progress(0, desc="Транскриптот се генерира")
if mic is not None:
download_path = mic.split(".")[0] + ".txt"
waveform, sr = librosa.load(mic, sr=16000)
# waveform = waveform[:30*sr]
whisper_result = whisper_classifier.classify_file_whisper_mkd(waveform, device)
else:
return "You must provide a mic recording"
recap_result = ""
progress(0.75, desc=" Пост-процесирање на транскриптот")
for k, segment in enumerate(whisper_result):
recap_result += segment[0] + " "
clean_up_memory()
progress(1.0, desc=" Крај на транскрипцијата")
with open(download_path, "w") as f:
f.write(recap_result)
return recap_result, download_path
@spaces.GPU(duration=60)
def return_prediction_whisper_file(file=None, progress=gr.Progress(), device=device):
whisper_result = []
progress(0, desc="  Транскриптот се генерира")
if file is not None:
download_path = file.split(".")[0] + ".txt"
waveform, sr = librosa.load(file, sr=16000)
# waveform = waveform[:3600*sr]
whisper_result = whisper_classifier.classify_file_whisper_mkd(waveform, device)
else:
return "You must provide a mic recording"
recap_result = ""
progress(0.75, desc=" Пост-процесирање на транскриптот")
for k, segment in enumerate(whisper_result):
recap_result += segment[0] + " "
clean_up_memory()
progress(1.0, desc=" Крај на транскрипцијата")
with open(download_path, "w") as f:
f.write(recap_result)
return recap_result, download_path
# Create a partial function with the device pre-applied
return_prediction_whisper_mic_with_device = partial(return_prediction_whisper_mic, device=device)
return_prediction_whisper_file_with_device = partial(return_prediction_whisper_file, device=device)
# Load the ASR models
whisper_classifier = foreign_class(source="Macedonian-ASR/buki-whisper-capitalised-2.0", pymodule_file="custom_interface_app.py", classname="ASR")
whisper_classifier = whisper_classifier.to(device)
whisper_classifier.eval()
# Load the T5 tokenizer and model for restoring capitalization
recap_model_name = "Macedonian-ASR/mt5-restore-capitalization-macedonian"
recap_tokenizer = T5Tokenizer.from_pretrained(recap_model_name)
recap_model = T5ForConditionalGeneration.from_pretrained(recap_model_name, torch_dtype=torch.float16)
recap_model.to(device)
recap_model.eval()
with gr.Blocks() as mic_transcribe_whisper:
def clear_outputs():
return None, "", None
with gr.Row():
audio_input = gr.Audio(sources="microphone", type="filepath", label="Record Audio")
with gr.Row():
transcribe_button = gr.Button("Транскрибирај")
clear_button = gr.Button("Исчисти ги резултатите")
with gr.Row():
output_text = gr.Textbox(label="Транскрипција")
with gr.Row():
download_file = gr.File(label="Зачувај го транскриптот", file_count="single", height=50)
transcribe_button.click(
fn=return_prediction_whisper_mic_with_device,
inputs=[audio_input],
outputs=[output_text, download_file],
)
clear_button.click(
fn=clear_outputs,
inputs=[],
outputs=[audio_input, output_text, download_file],
)
with gr.Blocks() as file_transcribe_whisper:
def clear_outputs():
return {audio_input: None, output_text: "", download_file: None}
with gr.Row():
audio_input = gr.Audio(sources="upload", type="filepath", label="Upload Audio")
with gr.Row():
transcribe_button = gr.Button("Транскрибирај")
clear_button = gr.Button("Исчисти ги резултатите")
with gr.Row():
output_text = gr.Textbox(label="Транскрипција")
with gr.Row():
download_file = gr.File(label="Зачувај го транскриптот", file_count="single", height=50)
transcribe_button.click(
fn=return_prediction_whisper_file_with_device,
inputs=[audio_input],
outputs=[output_text, download_file],
)
clear_button.click(
fn=clear_outputs,
inputs=[],
outputs=[audio_input, output_text, download_file],
)
project_description = '''
<img src="https://i.ibb.co/hYhkkhg/Buki-logo-1.jpg"
alt="Bookie logo"
style="float: right; width: 150px; height: 150px; margin-left: 10px;" />
## Автори:
1. **Дејан Порјазовски**
2. **Илина Јакимовска**
3. **Ордан Чукалиев**
4. **Никола Стиков**
Оваа колаборација е дел од активностите на **Центарот за напредни интердисциплинарни истражувања ([ЦеНИИс](https://ukim.edu.mk/en/centri/centar-za-napredni-interdisciplinarni-istrazhuvanja-ceniis))** при УКИМ.
'''
# Custom CSS
css = """
.gradio-container {
background-color: #f0f0f0; /* Set your desired background color */
}
.custom-markdown p, .custom-markdown li, .custom-markdown h2, .custom-markdown a, .custom-markdown strong {
font-size: 15px !important;
font-family: Arial, sans-serif !important;
color: black !important;
}
button {
color: orange !important;
}
.gradio-container {
background-color: #f3f3f3 !important;
}
"""
transcriber_app = gr.Blocks(css=css, delete_cache=(60, 120))
with transcriber_app:
state = gr.State()
gr.Markdown(project_description, elem_classes="custom-markdown")
gr.TabbedInterface(
[mic_transcribe_whisper, file_transcribe_whisper],
[" Буки-Whisper транскрипција од микрофон", "Буки-Whisper транскрипција од фајл"],
)
state = gr.State(value=[], delete_callback=lambda v: print("STATE DELETED"))
transcriber_app.unload(return_prediction_whisper_mic)
transcriber_app.unload(return_prediction_whisper_file)
# transcriber_app.launch(debug=True, share=True, ssl_verify=False)
if __name__ == "__main__":
transcriber_app.queue()
transcriber_app.launch(share=True)