import json
import logging
import os
import sys
import traceback

import gradio as gr
import spaces
import torch
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM

from audio_processing import AudioProcessor
from chunkedTranscriber import ChunkedTranscriber
from system_message import SYSTEM_MESSAGE


logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler(sys.stdout)]
)
logger = logging.getLogger(__name__)

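# Q&A model: a 4-bit AWQ-quantized Llama 3.1 8B Instruct checkpoint with long-context
# support, used to answer questions about the transcription. HF_TOKEN is read from the
# environment to authenticate against the Hugging Face Hub.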
def load_qa_model():
    """Load the question-answering model (AWQ-quantized Llama 3.1 8B Instruct) with long-context support."""
    try:
        from transformers import AutoModelForCausalLM, AwqConfig

        model_id = "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4"

        tokenizer = AutoTokenizer.from_pretrained(model_id, token=os.getenv("HF_TOKEN"))

        # Fuse AWQ modules; fuse_max_seq_len should cover the prompt plus generated tokens.
        quantization_config = AwqConfig(
            bits=4,
            fuse_max_seq_len=8192,
            do_fuse=True,
        )

        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.bfloat16,
            low_cpu_mem_usage=True,
            device_map="auto",
            # Dynamic RoPE scaling to extend the usable context window.
            rope_scaling={
                "type": "dynamic",
                "factor": 8.0
            },
            token=os.getenv("HF_TOKEN"),
            quantization_config=quantization_config
        )

        qa_pipeline = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            max_new_tokens=1024,
        )

        return qa_pipeline

    except Exception as e:
        logger.error(f"Failed to load Q&A model: {str(e)}")
        return None

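# Summarization model: DistilBART fine-tuned on CNN/DailyMail. Defined here and listed
# under "Features" in the UI, but not yet wired to any button below.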
def load_summarization_model():
    """Load the summarization model."""
    try:
        summarizer = pipeline(
            "summarization",
            model="sshleifer/distilbart-cnn-12-6",
            device=0 if torch.cuda.is_available() else -1
        )
        return summarizer
    except Exception as e:
        logger.error(f"Failed to load summarization model: {str(e)}")
        return None

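# A minimal usage sketch for the summarizer (illustrative only; `transcript_text` is a
# placeholder, not a variable defined in this file):
#
#     summarizer = load_summarization_model()
#     if summarizer is not None:
#         summary = summarizer(transcript_text, max_length=130, min_length=30, do_sample=False)[0]["summary_text"]
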
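# Audio processing: chunked transcription/translation. @spaces.GPU requests a GPU for
# up to 180 seconds per call on Hugging Face Spaces.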
@spaces.GPU(duration=180)
def process_audio(audio_file, translate=False):
    """Transcribe an audio file in overlapping chunks, optionally translating to English."""
    transcriber = ChunkedTranscriber(chunk_size=5, overlap=1)
    # Honor the translate flag passed in from the UI checkbox.
    _translation, _output = transcriber.transcribe_audio(audio_file, translate=translate)
    return _translation, _output

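# Question answering: builds a chat-style prompt (system message + transcription as
# context + user question) and returns the assistant's reply.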
@spaces.GPU(duration=180)
def answer_question(context, question):
    """Answer a question about the transcribed text."""
    try:
        qa_pipeline = load_qa_model()
        if qa_pipeline is None:
            return "Q&A model could not be loaded."
        if not question:
            return "Please enter a question."

        messages = [
            {"role": "system", "content": SYSTEM_MESSAGE},
            {"role": "user", "content": f"Context: {context}\nQuestion: {question}"}
        ]
        # The pipeline returns the full chat history; the assistant's reply is the last message.
        response = qa_pipeline(messages, max_new_tokens=256)[0]['generated_text']
        logger.info(response)
        return response[-1]['content']
    except Exception as e:
        logger.error(f"Q&A failed: {str(e)}")
        return f"Error occurred during Q&A process: {str(e)}"

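# Gradio UI: audio upload and processing controls alongside the transcription/translation
# outputs, with a Q&A section below.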
with gr.Blocks() as iface:
    gr.Markdown("# Automatic Speech Recognition for Indic Languages")

    with gr.Row():
        with gr.Column():
            audio_input = gr.Audio(type="filepath")
            translate_checkbox = gr.Checkbox(label="Enable Translation")
            process_button = gr.Button("Process Audio")

        with gr.Column():
            full_text_output = gr.Textbox(label="Full Text", lines=5)
            translation_output = gr.Textbox(label="Transcription/Translation", lines=10)

    with gr.Row():
        with gr.Column():
            question_input = gr.Textbox(label="Ask a question about the transcription")
            answer_button = gr.Button("Get Answer")
            answer_output = gr.Textbox(label="Answer", lines=3)

    process_button.click(
        process_audio,
        inputs=[audio_input, translate_checkbox],
        outputs=[translation_output, full_text_output]
    )

    answer_button.click(
        answer_question,
        inputs=[full_text_output, question_input],
        outputs=[answer_output]
    )

    gr.Markdown(f"""
## System Information
- Device: {"CUDA" if torch.cuda.is_available() else "CPU"}
- CUDA Available: {"Yes" if torch.cuda.is_available() else "No"}

## Features
- Automatic language detection
- High-quality transcription using MMS
- Optional translation to English
- Text summarization
- Question answering
""")

if __name__ == "__main__":
    iface.launch()