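"""Gradio app for automatic speech recognition of Indic languages, with
optional English translation plus summarization and question answering
over the transcript."""
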
import os
import json
import logging
import sys

import gradio as gr
import torch
from transformers import pipeline

import spaces
from chunkedTranscriber import ChunkedTranscriber

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler(sys.stdout)]
)
logger = logging.getLogger(__name__)


def load_qa_model():
    """Load the quantized Llama 3.1 instruct model used for question answering."""
    try:
        qa_pipeline = pipeline(
            "text-generation",
            model="hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4",
            model_kwargs={"torch_dtype": torch.bfloat16},
            device_map="auto",
            token=os.getenv("HF_TOKEN")  # use_auth_token is deprecated in recent transformers
        )
        return qa_pipeline
    except Exception as e:
        logger.error(f"Failed to load Q&A model: {str(e)}")
        return None


def load_summarization_model():
    """Load summarization model"""
    try:
        summarizer = pipeline(
            "summarization",
            model="sshleifer/distilbart-cnn-12-6",
            device=0 if torch.cuda.is_available() else -1
        )
        return summarizer
    except Exception as e:
        logger.error(f"Failed to load summarization model: {str(e)}")
        return None


@spaces.GPU(duration=120)  # request a ZeroGPU slot for up to 120 seconds
def process_audio(audio_file, translate=False):
    """Transcribe an audio file in overlapping chunks, optionally translating it."""
    transcriber = ChunkedTranscriber(chunk_size=5, overlap=1)
    # Honor the checkbox value instead of hard-coding translate=True.
    results = transcriber.transcribe_audio(audio_file, translate=translate)
    return results
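

# NOTE (assumption): the transcriber's result is expected to reach the UI as a
# JSON string shaped like '[{"text": "...", "translated": "..."}, ...]' --
# summarize_text() and answer_question() below parse it with json.loads() and
# join the "translated" fields.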
@spaces.GPU(duration=120)
def summarize_text(text):
    """Summarize the translated transcription shown in the output box."""
    try:
        summarizer = load_summarization_model()
        if summarizer is None:
            return "Summarization model could not be loaded."
        logger.info("Successfully loaded summarization model")

        # Join the translated chunks into one passage before summarizing.
        data = json.loads(text)
        full_text = ''.join(item['translated'] for item in data if 'translated' in item)

        logger.info(f"\n\nWorking on text:\n{full_text}")
        summary = summarizer(full_text, max_length=150, min_length=50, do_sample=False)[0]['summary_text']
        return summary
    except Exception as e:
        logger.error(f"Summarization failed: {str(e)}")
        return "Error occurred during summarization."


@spaces.GPU(duration=120)
def answer_question(context, question):
    """Answer questions about the transcribed text."""
    try:
        qa_pipeline = load_qa_model()
        if qa_pipeline is None:
            return "Q&A model could not be loaded."
        if not question:
            return "Please enter a question."

        # The context arrives as the JSON string from the ASR output box, so
        # parse it the same way summarize_text does.
        data = json.loads(context)
        translated_context = ''.join(item['translated'] for item in data if 'translated' in item)

        messages = [
            {"role": "system", "content": "You are a helpful assistant who can answer questions based on the given context."},
            {"role": "user", "content": f"Context: {translated_context}\n\nQuestion: {question}"}
        ]

        response = qa_pipeline(messages, max_new_tokens=256)[0]['generated_text'][-1]['content']
        return response
    except Exception as e:
        logger.error(f"Q&A failed: {str(e)}")
        return f"Error occurred during Q&A process: {str(e)}"
with gr.Blocks() as iface:
    gr.Markdown("# Automatic Speech Recognition for Indic Languages")

    with gr.Row():
        with gr.Column():
            audio_input = gr.Audio(type="filepath")
            translate_checkbox = gr.Checkbox(label="Enable Translation")
            process_button = gr.Button("Process Audio")

        with gr.Column():
            asr_result = gr.Textbox(label="Output")

    with gr.Row():
        with gr.Column():
            summarize_button = gr.Button("Summarize")
            summary_output = gr.Textbox(label="Summary", lines=3)

        with gr.Column():
            question_input = gr.Textbox(label="Ask a question about the transcription")
            answer_button = gr.Button("Get Answer")
            answer_output = gr.Textbox(label="Answer", lines=3)

    process_button.click(
        process_audio,
        inputs=[audio_input, translate_checkbox],
        outputs=[asr_result]
    )

    summarize_button.click(
        summarize_text,
        inputs=[asr_result],
        outputs=[summary_output]
    )

    answer_button.click(
        answer_question,
        inputs=[asr_result, question_input],
        outputs=[answer_output]
    )

    gr.Markdown(f"""
    ## System Information
    - Device: {"CUDA" if torch.cuda.is_available() else "CPU"}
    - CUDA Available: {"Yes" if torch.cuda.is_available() else "No"}

    ## Features
    - Automatic language detection
    - High-quality transcription using MMS
    - Optional translation to English
    - Text summarization
    - Question answering
    """)


if __name__ == "__main__":
    iface.launch()