capradeepgujaran committed on
Commit
fae0258
•
1 Parent(s): f3da91c

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +208 -0
app.py ADDED
@@ -0,0 +1,208 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import tempfile
3
+ import gradio as gr
4
+ import fitz # PyMuPDF for reading PDF files
5
+ import pytesseract
6
+ from PIL import Image
7
+ import docx # for reading .docx files
8
+ from ragchecker import RAGResults, RAGChecker
9
+ from ragchecker.metrics import all_metrics
10
+ from llama_index.core import VectorStoreIndex, Document
11
+ from llama_index.embeddings.openai import OpenAIEmbedding
12
+ from llama_index.llms.openai import OpenAI
13
+ from llama_index.core import get_response_synthesizer
14
+ from dotenv import load_dotenv
15
+ from bert_score import score as bert_score
16
+
17
# Pull configuration from a local .env file into the process environment.
load_dotenv()

# Tesseract OCR location: an explicit path is only needed on Windows.
# On Linux-based systems (like Hugging Face Spaces), Tesseract is usually
# available via apt, so this can normally stay commented out. Uncomment and
# adjust if necessary.
# pytesseract.pytesseract.tesseract_cmd = r'/usr/bin/tesseract'

# Module-level state shared by the upload and query handlers.
vector_index = None  # set by process_upload once documents are indexed
query_log = []  # queries and results accumulated for RAGChecker
28
+
29
+ # Function to handle PDF and OCR for scanned PDFs
30
def load_pdf_manually(pdf_path):
    """Extract the text of a PDF, using OCR for pages without extractable text.

    Args:
        pdf_path: Path of the PDF file to read.

    Returns:
        The text of every page, concatenated in page order.
    """
    page_texts = []
    # The context manager closes the document even when text extraction or
    # OCR raises; previously the handle was never closed.
    with fitz.open(pdf_path) as doc:
        for page in doc:
            page_text = page.get_text()

            # If no text (i.e., scanned PDF), render the page and OCR it.
            if not page_text.strip():
                pix = page.get_pixmap()
                img = Image.frombytes("RGB", (pix.width, pix.height), pix.samples)
                page_text = pytesseract.image_to_string(img)

            page_texts.append(page_text)

    # Join once at the end instead of growing a string page by page.
    return "".join(page_texts)
45
+
46
+ # Function to handle .docx files
47
def load_docx_file(docx_path):
    """Return the paragraph text of a .docx file, one paragraph per line.

    Args:
        docx_path: Path of the .docx file to read.

    Returns:
        The text of each paragraph in document order, joined with newlines.
    """
    word_document = docx.Document(docx_path)
    paragraph_texts = [paragraph.text for paragraph in word_document.paragraphs]
    return '\n'.join(paragraph_texts)
53
+
54
+ # Function to handle .txt files
55
def load_txt_file(txt_path):
    """Return the full contents of a UTF-8 encoded text file.

    Args:
        txt_path: Path of the text file to read.

    Returns:
        The decoded file contents, without a leading byte-order mark.

    Raises:
        UnicodeDecodeError: If the file is not valid UTF-8.
    """
    # "utf-8-sig" decodes plain UTF-8 exactly like "utf-8" but also drops a
    # leading BOM, which would otherwise end up as "\ufeff" in the indexed text.
    with open(txt_path, 'r', encoding='utf-8-sig') as f:
        return f.read()
58
+
59
+ # General function to load a file based on its extension
60
def load_file_based_on_extension(file_path):
    """Load a document's text with the loader that matches its file extension.

    Args:
        file_path: Path of the file to load. ``.pdf``, ``.docx`` and ``.txt``
            are supported; the extension is matched case-insensitively.

    Returns:
        The text extracted from the file.

    Raises:
        ValueError: If the extension is not one of the supported formats.
    """
    # Compare on a lowercased copy so "REPORT.PDF" is handled like
    # "report.pdf"; the loaders still receive the original path.
    lowered = file_path.lower()
    if lowered.endswith('.pdf'):
        return load_pdf_manually(file_path)
    if lowered.endswith('.docx'):
        return load_docx_file(file_path)
    if lowered.endswith('.txt'):
        return load_txt_file(file_path)
    raise ValueError(f"Unsupported file format: {file_path}")
69
+
70
+ # Function to process uploaded files and create/update the vector index
71
def _resolve_upload_path(file):
    """Return ``(path, is_temporary)`` locating one uploaded item on disk.

    Path strings and objects whose ``name`` is an existing file are used in
    place. Anything else is treated as a readable binary stream and copied to
    a temporary file that keeps only the original extension, so the
    extension-based loader dispatch still works.
    """
    if isinstance(file, (str, os.PathLike)):
        return os.fspath(file), False

    name = getattr(file, "name", "") or ""
    if isinstance(name, str) and name and os.path.isfile(name):
        return name, False

    # Only the extension is a valid suffix; the original passed the whole
    # name, which may contain directory separators.
    suffix = os.path.splitext(name)[1] if isinstance(name, str) else ""
    tmp = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
    try:
        with tmp:
            tmp.write(file.read())
    except Exception:
        os.unlink(tmp.name)
        raise
    return tmp.name, True


def process_upload(files):
    """Index the uploaded files into a fresh global vector index.

    Each file is loaded independently: a file that is unsupported or fails to
    load is reported in the status message and the remaining files are still
    processed. A successful run replaces any previously built index.

    Args:
        files: Uploaded items, each either a path or a file-like object with
            ``name`` and ``read()``.
            NOTE(review): the exact item type depends on how ``gr.File`` is
            configured - confirm against the installed Gradio version.

    Returns:
        A ``(status_message, index)`` tuple; ``index`` is ``None`` when
        nothing was indexed.
    """
    global vector_index

    if not files:
        return "No files uploaded.", None

    documents = []
    problems = []
    for file in files:
        if isinstance(file, (str, os.PathLike)):
            display_name = os.fspath(file)
        else:
            display_name = getattr(file, "name", repr(file))

        tmp_path = None
        try:
            path, is_temporary = _resolve_upload_path(file)
            if is_temporary:
                tmp_path = path
            text = load_file_based_on_extension(path)
            documents.append(Document(text=text))
        except ValueError as e:
            problems.append(f"Skipping unsupported file: {display_name} ({e})")
        except Exception as e:
            problems.append(f"Error processing file {display_name}: {e}")
        finally:
            # Remove our temporary copy whether or not loading succeeded.
            if tmp_path is not None and os.path.exists(tmp_path):
                os.unlink(tmp_path)

    if not documents:
        details = " ".join(problems)
        return f"No valid documents were indexed. {details}".strip(), None

    try:
        embed_model = OpenAIEmbedding(model="text-embedding-3-large")
        vector_index = VectorStoreIndex.from_documents(documents, embed_model=embed_model)
    except Exception as e:
        # Report failures (e.g. embedding API errors) as a status message,
        # consistent with how per-file errors are surfaced.
        return f"Error building index: {e}", None

    status = f"Successfully indexed {len(documents)} files."
    if problems:
        status = f"{status} {' '.join(problems)}"
    return status, vector_index
97
+
98
+ # Function to handle queries
99
def query_app(query, model_name, use_rag_checker):
    """Answer a question from the indexed documents and optionally score it.

    Args:
        query: The user's question.
        model_name: Name of the OpenAI chat model used to answer.
        use_rag_checker: When true, run RAGChecker over the whole query log
            and add a BERTScore entry for the current answer.

    Returns:
        A ``(answer_or_message, metrics)`` tuple. ``metrics`` is ``None`` when
        evaluation is disabled or the query could not be answered; otherwise
        it is a dict that carries an ``"error"`` entry if evaluation failed.
    """
    # Reject blank input before spending an LLM call on it.
    if not query or not query.strip():
        return "Please enter a question.", None

    if vector_index is None:
        return "No documents indexed yet. Please upload documents first.", None

    try:
        # Building the LLM and engine can fail too (e.g. bad credentials), so
        # it is covered by the same handler as the query itself.
        llm = OpenAI(model=model_name)
        response_synthesizer = get_response_synthesizer(llm=llm)
        query_engine = vector_index.as_query_engine(llm=llm, response_synthesizer=response_synthesizer)
        response = query_engine.query(query)
    except Exception as e:
        return f"Error during query processing: {e}", None

    # No real ground truth is available yet; every metric that compares
    # against it is only meaningful once this placeholder is replaced.
    placeholder_gt = "Placeholder ground truth answer"

    # Log query and generated response
    generated_response = response.response
    query_log.append({
        "query_id": str(len(query_log) + 1),
        "query": query,
        "gt_answer": placeholder_gt,
        "response": generated_response,
        "retrieved_context": [{"text": node.text} for node in response.source_nodes],
    })

    if not use_rag_checker:
        return generated_response, None

    metrics = {}
    try:
        # NOTE(review): this re-evaluates every logged query on each call, so
        # evaluation cost grows with the session - confirm that is intended.
        rag_results = RAGResults.from_dict({"results": query_log})
        evaluator = RAGChecker(
            extractor_name="openai/gpt-4o-mini",
            checker_name="openai/gpt-4o-mini",
            batch_size_extractor=32,
            batch_size_checker=32,
        )
        evaluator.evaluate(rag_results, all_metrics)
        metrics = rag_results.metrics

        # BERTScore of the current answer against the ground truth, reported
        # as percentages alongside the RAGChecker metrics.
        P, R, F1 = bert_score([generated_response], [placeholder_gt], lang="en", verbose=False)
        metrics['bertscore'] = {
            "precision": P.mean().item() * 100,
            "recall": R.mean().item() * 100,
            "f1": F1.mean().item() * 100,
        }
    except Exception as e:
        metrics['error'] = f"Error calculating metrics: {e}"

    return generated_response, metrics
160
+
161
+ # Define the Gradio interface
162
def main():
    """Build the two-tab Gradio UI (upload / ask) and launch it."""
    with gr.Blocks(title="Document Processing App") as demo:
        # Emoji restored from mis-decoded text in the heading and tab labels.
        gr.Markdown("# 📄 Document Processing and Querying App")

        with gr.Tab("📤 Upload Documents"):
            gr.Markdown("### Upload PDF, DOCX, or TXT files to index")
            with gr.Row():
                # NOTE(review): type="file" is kept because process_upload
                # relies on objects with .name/.read(); confirm the installed
                # Gradio version still accepts this value.
                file_upload = gr.File(label="Upload Files", file_count="multiple", type="file")
                upload_button = gr.Button("Upload and Index")
            upload_status = gr.Textbox(label="Status", interactive=False)

            # The second output receives the index returned by process_upload.
            upload_button.click(
                fn=process_upload,
                inputs=[file_upload],
                outputs=[upload_status, gr.State()],
            )

        with gr.Tab("❓ Ask a Question"):
            gr.Markdown("### Query the indexed documents")
            with gr.Column():
                query_input = gr.Textbox(label="Enter your question", placeholder="Type your question here...")
                model_dropdown = gr.Dropdown(
                    choices=["gpt-3.5-turbo", "gpt-4"],
                    value="gpt-3.5-turbo",
                    label="Select Model",
                )
                rag_checkbox = gr.Checkbox(label="Use RAG Checker", value=True)
                query_button = gr.Button("Ask")
            with gr.Column():
                answer_output = gr.Textbox(label="Answer", interactive=False)
                metrics_output = gr.JSON(label="Metrics", interactive=False)

            query_button.click(
                fn=query_app,
                inputs=[query_input, model_dropdown, rag_checkbox],
                outputs=[answer_output, metrics_output],
            )

        # Built without leading indentation so Markdown cannot render the
        # note as an indented code block.
        gr.Markdown(
            "---\n"
            "**Note:** Ensure you upload documents before attempting to query. "
            "Metrics are calculated only if RAG Checker is enabled."
        )

    demo.launch()
206
+
207
# Launch the UI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()