# capradeepgujaran's picture
# Update app.py
# 1a16755 verified
# raw
# history blame
# 7.37 kB
import os
import tempfile
import gradio as gr
import fitz # PyMuPDF for reading PDF files
import pytesseract
from PIL import Image
import docx # for reading .docx files
from ragchecker import RAGResults, RAGChecker
from ragchecker.metrics import all_metrics
from llama_index.core import VectorStoreIndex, Document
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.llms.openai import OpenAI
from llama_index.core import get_response_synthesizer
from dotenv import load_dotenv
from bert_score import score as bert_score
# Populate os.environ from a local .env file (presumably this is where the
# OpenAI credentials used by the embedding/LLM clients come from — confirm).
load_dotenv()

# OCR of scanned PDF pages relies on the Tesseract binary being on PATH.
# If it lives elsewhere (typically on Windows), point pytesseract at it
# explicitly, e.g.:
# pytesseract.pytesseract.tesseract_cmd = r'/usr/bin/tesseract'

# Shared state for the Gradio callbacks:
#   vector_index - the VectorStoreIndex built by process_upload (None until then)
#   query_log    - every answered query, as RAGChecker "results" records
vector_index, query_log = None, []
def load_pdf_manually(pdf_path):
    """Extract the text of every page of a PDF, OCR-ing pages with no text layer.

    Args:
        pdf_path: Path to the PDF file.

    Returns:
        The text of all pages concatenated in page order (no separator is
        inserted between pages).
    """
    page_texts = []
    # The context manager closes the document even if OCR raises part-way through.
    with fitz.open(pdf_path) as doc:
        for page in doc:
            page_text = page.get_text()
            if not page_text.strip():
                # No extractable text (e.g. a scanned page): render and OCR it.
                # alpha=False guarantees 3-channel samples matching the "RGB" mode.
                pix = page.get_pixmap(alpha=False)
                img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
                page_text = pytesseract.image_to_string(img)
            page_texts.append(page_text)
    return "".join(page_texts)
def load_docx_file(docx_path):
    """Return the body paragraphs of a .docx file, one paragraph per line.

    Only ``Document.paragraphs`` is read, so text that lives in tables is
    not included.
    """
    paragraphs = docx.Document(docx_path).paragraphs
    return '\n'.join(paragraph.text for paragraph in paragraphs)
def load_txt_file(txt_path):
    """Return the full contents of a text file decoded as UTF-8.

    Decoding is strict, so a file that is not valid UTF-8 raises
    ``UnicodeDecodeError``.
    """
    with open(txt_path, encoding='utf-8') as handle:
        contents = handle.read()
    return contents
def load_file_based_on_extension(file_path):
    """Load a document's text, choosing the loader from the file extension.

    Supported extensions are ``.pdf``, ``.docx`` and ``.txt``, matched
    case-insensitively (so ``Report.PDF`` is accepted).

    Args:
        file_path: Path to the file, as a ``str`` or any ``os.PathLike``.

    Returns:
        The extracted text.

    Raises:
        ValueError: If the extension is not one of the supported ones.
    """
    # os.fspath accepts both str and PathLike; the loaders receive a plain str.
    path = os.fspath(file_path)
    lowered = path.lower()
    if lowered.endswith('.pdf'):
        return load_pdf_manually(path)
    if lowered.endswith('.docx'):
        return load_docx_file(path)
    if lowered.endswith('.txt'):
        return load_txt_file(path)
    raise ValueError(f"Unsupported file format: {file_path}")
def process_upload(files):
    """Load the uploaded files and build a fresh vector index from them.

    Each call replaces the module-level ``vector_index`` with a new index built
    only from *files*; previously indexed documents are not kept.

    Files that cannot be loaded are skipped and listed in the status message,
    instead of aborting the whole upload.

    Args:
        files: List of file paths from the Gradio ``File`` component
            (``type="filepath"``), or ``None``/empty when nothing was uploaded.

    Returns:
        A ``(status_message, index)`` tuple. ``index`` is the newly built
        ``VectorStoreIndex`` (also stored in ``vector_index``), or ``None``
        when no document could be indexed.
    """
    global vector_index
    if not files:
        return "No files uploaded.", None
    documents = []
    problems = []
    for file_path in files:
        try:
            text = load_file_based_on_extension(file_path)
        except UnicodeDecodeError as e:
            # UnicodeDecodeError subclasses ValueError; report it as a read
            # failure rather than as an unsupported format.
            problems.append(f"Error processing file {file_path}: {e}")
            continue
        except ValueError as e:
            problems.append(f"Skipping unsupported file: {file_path} ({e})")
            continue
        except Exception as e:
            problems.append(f"Error processing file {file_path}: {e}")
            continue
        documents.append(Document(text=text))
    if not documents:
        return "\n".join(["No valid documents were indexed.", *problems]), None
    embed_model = OpenAIEmbedding(model="text-embedding-3-large")
    vector_index = VectorStoreIndex.from_documents(documents, embed_model=embed_model)
    status = "\n".join([f"Successfully indexed {len(documents)} files.", *problems])
    return status, vector_index
def _evaluate_record(record):
    """Score one logged query record with RAGChecker and BERTScore.

    Args:
        record: A single ``query_log`` entry (keys ``query_id``, ``query``,
            ``gt_answer``, ``response``, ``retrieved_context``).

    Returns:
        RAGChecker's metrics dict extended with a ``"bertscore"`` entry
        (precision/recall/F1 multiplied by 100). If any step fails, the dict
        carries an ``"error"`` message instead of (or alongside) the scores.
    """
    metrics = {}
    try:
        # Evaluate only this record: re-running RAGChecker on the whole history
        # would repeat every earlier LLM extraction/check on each new query.
        rag_results = RAGResults.from_dict({"results": [record]})
        evaluator = RAGChecker(
            extractor_name="openai/gpt-4o-mini",
            checker_name="openai/gpt-4o-mini",
            batch_size_extractor=32,
            batch_size_checker=32,
        )
        evaluator.evaluate(rag_results, all_metrics)
        metrics = rag_results.metrics
        # BERTScore of the answer against the record's ground-truth answer.
        P, R, F1 = bert_score(
            [record["response"]], [record["gt_answer"]], lang="en", verbose=False
        )
        metrics['bertscore'] = {
            "precision": P.mean().item() * 100,
            "recall": R.mean().item() * 100,
            "f1": F1.mean().item() * 100,
        }
    except Exception as e:
        metrics['error'] = f"Error calculating metrics: {e}"
    return metrics


def query_app(query, model_name, use_rag_checker):
    """Answer *query* from the indexed documents and optionally score the answer.

    Every successful answer is appended to ``query_log`` as a RAGChecker
    record.

    Args:
        query: The user's question.
        model_name: OpenAI chat model used to synthesize the answer.
        use_rag_checker: When true, run RAGChecker and BERTScore on this answer.

    Returns:
        ``(answer_text, metrics)``. ``metrics`` is a dict (possibly holding an
        ``"error"`` entry) when *use_rag_checker* is set, otherwise ``None``.
        ``answer_text`` is a human-readable message, and ``metrics`` is
        ``None``, in these cases:

        * the question is blank;
        * nothing has been indexed yet;
        * the query itself fails.
    """
    global vector_index, query_log
    if not query or not query.strip():
        return "Please enter a question.", None
    if vector_index is None:
        return "No documents indexed yet. Please upload documents first.", None
    # Initialize the LLM with the selected model and query the indexed documents.
    llm = OpenAI(model=model_name)
    response_synthesizer = get_response_synthesizer(llm=llm)
    query_engine = vector_index.as_query_engine(
        llm=llm, response_synthesizer=response_synthesizer
    )
    try:
        response = query_engine.query(query)
    except Exception as e:
        return f"Error during query processing: {e}", None
    generated_response = response.response
    record = {
        "query_id": str(len(query_log) + 1),
        "query": query,
        # TODO: replace with a real ground-truth answer when available; the
        # RAGChecker and BERTScore numbers are not meaningful against this.
        "gt_answer": "Placeholder ground truth answer",
        "response": generated_response,
        "retrieved_context": [{"text": node.text} for node in response.source_nodes],
    }
    query_log.append(record)
    if not use_rag_checker:
        return generated_response, None
    return generated_response, _evaluate_record(record)
def main():
    """Build the two-tab Gradio app (upload/index, then ask questions) and launch it."""

    def upload_status_only(files):
        # The upload button has a single output component (the status Textbox),
        # but process_upload returns (status, index); pass through only the status.
        return process_upload(files)[0]

    with gr.Blocks(title="Document Processing App") as demo:
        gr.Markdown("# 📄 Document Processing and Querying App")
        with gr.Tab("📤 Upload Documents"):
            gr.Markdown("### Upload PDF, DOCX, or TXT files to index")
            with gr.Row():
                file_upload = gr.File(label="Upload Files", file_count="multiple", type="filepath")
                upload_button = gr.Button("Upload and Index")
            upload_status = gr.Textbox(label="Status", interactive=False)
            upload_button.click(
                fn=upload_status_only,
                inputs=[file_upload],
                outputs=[upload_status],
                api_name="process_upload",
            )
        with gr.Tab("❓ Ask a Question"):
            gr.Markdown("### Query the indexed documents")
            with gr.Column():
                query_input = gr.Textbox(label="Enter your question", placeholder="Type your question here...")
                model_dropdown = gr.Dropdown(
                    choices=["gpt-3.5-turbo", "gpt-4"],
                    value="gpt-3.5-turbo",
                    label="Select Model"
                )
                rag_checkbox = gr.Checkbox(label="Use RAG Checker", value=True)
                query_button = gr.Button("Ask")
            with gr.Column():
                answer_output = gr.Textbox(label="Answer", interactive=False)
                metrics_output = gr.JSON(label="Metrics")
            query_button.click(
                fn=query_app,
                inputs=[query_input, model_dropdown, rag_checkbox],
                outputs=[answer_output, metrics_output]
            )
        gr.Markdown("""
---
**Note:** Ensure you upload documents before attempting to query. Metrics are calculated only if RAG Checker is enabled.
""")
    # Launch once the Blocks context has closed and the layout is final.
    demo.launch()


if __name__ == "__main__":
    main()