# ASRfr / app.py
# Committed by Kr08 — last change: "Update app.py" (commit 8669b40, verified)
import os
import json
import gradio as gr
import torch
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
import logging
import traceback
import sys
from audio_processing import AudioProcessor
import spaces
from chunkedTranscriber import ChunkedTranscriber
from system_message import SYSTEM_MESSAGE
# Route INFO-and-above log records to stdout using a timestamped format that
# includes the logger name and level.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
_stdout_handler = logging.StreamHandler(sys.stdout)
logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT, handlers=[_stdout_handler])

# Module-level logger for this app.
logger = logging.getLogger(__name__)
def load_qa_model():
    """Load the Llama 3.1 8B Instruct (AWQ INT4) text-generation pipeline used for Q&A.

    Returns:
        A ``transformers`` text-generation pipeline, or ``None`` if any step of
        loading fails (the failure is logged together with its traceback).
    """
    try:
        # Imported lazily so an import problem goes through the same
        # "log and return None" path as a download/load failure.
        from transformers import AutoModelForCausalLM, AwqConfig

        model_id = "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4"
        # Access token from the environment; may be None if HF_TOKEN is unset.
        hf_token = os.getenv("HF_TOKEN")

        tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_token)

        quantization_config = AwqConfig(
            bits=4,
            # Max sequence length for the fused AWQ modules (relevant because
            # do_fuse=True); this is a model-side setting, not a tokenizer one.
            fuse_max_seq_len=8192,
            do_fuse=True,
        )

        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.bfloat16,
            low_cpu_mem_usage=True,
            device_map="auto",
            # NOTE(review): this replaces the checkpoint's own rope_scaling with
            # a "dynamic" scheme (factor 8.0); presumably a workaround for an
            # older transformers release — confirm it is still needed.
            rope_scaling={
                "type": "dynamic",
                "factor": 8.0,
            },
            token=hf_token,
            quantization_config=quantization_config,
        )

        return pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            max_new_tokens=1024,  # default generation cap; callers may override per call
        )
    except Exception as e:
        # logger.exception keeps the traceback that str(e) alone would discard.
        logger.exception("Failed to load Q&A model: %s", e)
        return None
# def load_qa_model():
# """Load question-answering model"""
# try:
# model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
# qa_pipeline = pipeline(
# "text-generation",
# model="hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4",
# model_kwargs={"torch_dtype": torch.bfloat16},
# device_map="auto",
# use_auth_token=os.getenv("HF_TOKEN")
# )
# return qa_pipeline
# except Exception as e:
# logger.error(f"Failed to load Q&A model: {str(e)}")
# return None
def load_summarization_model():
    """Load the DistilBART CNN summarization pipeline.

    The pipeline is placed on device 0 when CUDA is available, otherwise on CPU.

    Returns:
        A ``transformers`` summarization pipeline, or ``None`` if loading fails
        (the failure is logged together with its traceback).
    """
    try:
        # transformers' pipeline convention: device=0 is the first GPU, -1 is CPU.
        device = 0 if torch.cuda.is_available() else -1
        return pipeline(
            "summarization",
            model="sshleifer/distilbart-cnn-12-6",
            device=device,
        )
    except Exception as e:
        # logger.exception keeps the traceback that str(e) alone would discard.
        logger.exception("Failed to load summarization model: %s", e)
        return None
@spaces.GPU(duration=180)
def process_audio(audio_file, translate=False):
    """Transcribe an uploaded audio file, optionally translating it.

    Args:
        audio_file: Path of the uploaded audio (the UI uses
            ``gr.Audio(type="filepath")``); falsy when nothing was uploaded.
        translate: State of the "Enable Translation" checkbox, forwarded to
            ``ChunkedTranscriber.transcribe_audio``.

    Returns:
        The two values returned by ``ChunkedTranscriber.transcribe_audio``, in
        the same order.

    Raises:
        gr.Error: If no file was provided or transcription fails, so the
            message is shown in the Gradio UI instead of a generic error.
    """
    if not audio_file:
        raise gr.Error("Please upload an audio file first.")
    try:
        transcriber = ChunkedTranscriber(chunk_size=5, overlap=1)
        # Honor the UI checkbox instead of always translating.
        translation, output = transcriber.transcribe_audio(audio_file, translate=translate)
    except Exception as e:
        logger.exception("Audio processing failed: %s", e)
        raise gr.Error(f"Processing failed: {str(e)}") from e
    return translation, output
# try:
# processor = AudioProcessor()
# language_segments, final_segments = processor.process_audio(audio_file, translate)
# # Format output
# transcription = ""
# full_text = ""
# # Add language detection information
# for segment in language_segments:
# transcription += f"Language: {segment['language']}\n"
# transcription += f"Time: {segment['start']:.2f}s - {segment['end']:.2f}s\n\n"
# # Add transcription/translation information
# transcription += "Transcription with language detection:\n\n"
# for segment in final_segments:
# transcription += f"[{segment['start']:.2f}s - {segment['end']:.2f}s] ({segment['language']}):\n"
# transcription += f"Original: {segment['text']}\n"
# if translate and 'translated' in segment:
# transcription += f"Translated: {segment['translated']}\n"
# full_text += segment['translated'] + " "
# else:
# full_text += segment['text'] + " "
# transcription += "\n"
# return transcription, full_text
# except Exception as e:
# logger.error(f"Audio processing failed: {str(e)}")
# raise gr.Error(f"Processing failed: {str(e)}")
# @spaces.GPU(duration=180)
# def summarize_text(text):
# """Summarize text"""
# try:
# summarizer = load_summarization_model()
# if summarizer is None:
# return "Summarization model could not be loaded."
# logger.info("Successfully loaded summarization Model")
# # logger.info(f"\n\n {text}\n")
# summary = summarizer(text, max_length=150, min_length=50, do_sample=False)[0]['summary_text']
# return summary
# except Exception as e:
# logger.error(f"Summarization failed: {str(e)}")
# return "Error occurred during summarization."
@spaces.GPU(duration=180)
def answer_question(context, question):
    """Answer a free-form question about the transcribed text.

    Args:
        context: Text the answer should be grounded in (the UI passes the
            "Full Text" box).
        question: The user's question.

    Returns:
        The assistant's reply, or a human-readable message when input is
        missing, the model cannot be loaded, or generation fails.
    """
    # Validate cheap inputs first so the 8B model is never loaded just to
    # reject an empty request.
    if not question or not question.strip():
        return "Please enter your Question"
    if not context or not context.strip():
        return "There is no transcription to ask about yet. Please process an audio file first."

    try:
        qa_pipeline = load_qa_model()
        if qa_pipeline is None:
            return "Q&A model could not be loaded."

        messages = [
            {"role": "system", "content": SYSTEM_MESSAGE},
            {"role": "user", "content": f"Context: {context}\n Question: {question}"},
        ]
        # With chat-style input the pipeline returns the whole conversation in
        # 'generated_text'; its last message is the newly generated reply.
        response = qa_pipeline(messages, max_new_tokens=256)[0]['generated_text']
        logger.info(response)
        return response[-1]['content']
    except Exception as e:
        # logger.exception keeps the traceback that str(e) alone would discard.
        logger.exception("Q&A failed: %s", e)
        return f"Error occurred during Q&A process: {str(e)}"
# Create Gradio interface: audio input and controls on the left, text outputs
# on the right, and a question-answering row underneath.
with gr.Blocks() as iface:
    gr.Markdown("# Automatic Speech Recognition for Indic Languages")

    with gr.Row():
        with gr.Column():
            audio_input = gr.Audio(type="filepath")
            translate_checkbox = gr.Checkbox(label="Enable Translation")
            process_button = gr.Button("Process Audio")
        with gr.Column():
            full_text_output = gr.Textbox(label="Full Text", lines=5)
            translation_output = gr.Textbox(label="Transcription/Translation", lines=10)

    with gr.Row():
        # NOTE: the summarization column (Summarize button + Summary box) is
        # disabled together with the summarize_text handler.
        with gr.Column():
            question_input = gr.Textbox(label="Ask a question about the transcription")
            answer_button = gr.Button("Get Answer")
            answer_output = gr.Textbox(label="Answer", lines=3)

    # Set up event handlers.
    # process_audio returns two values; they fill these outputs in this order.
    process_button.click(
        process_audio,
        inputs=[audio_input, translate_checkbox],
        outputs=[translation_output, full_text_output],
    )
    # Q&A is grounded in whatever text is currently in the "Full Text" box.
    answer_button.click(
        answer_question,
        inputs=[full_text_output, question_input],
        outputs=[answer_output],
    )

    # Add system information. Summarization is not listed because its UI and
    # handler are currently disabled.
    # NOTE(review): confirm the "MMS" claim still matches ChunkedTranscriber.
    cuda_available = torch.cuda.is_available()
    gr.Markdown(f"""
## System Information
- Device: {"CUDA" if cuda_available else "CPU"}
- CUDA Available: {"Yes" if cuda_available else "No"}
## Features
- Automatic language detection
- High-quality transcription using MMS
- Optional translation to English
- Question answering
""")
if __name__ == "__main__":
    # Script entry point: start the Gradio server for the interface built above.
    # server_port=None leaves the port choice to Gradio.
    iface.launch(server_port=None)