CamiloVega committed · verified
Commit ce57b87 · Parent: 8d29369

Create app.py

Files changed (1): app.py (+349, −0)
app.py ADDED
@@ -0,0 +1,349 @@
+ import os
+ import logging
+ from typing import List, Dict
+ import torch
+ import gradio as gr
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain.embeddings import HuggingFaceEmbeddings
+ from langchain.vectorstores import FAISS
+ from langchain.chains import RetrievalQA
+ from langchain.prompts import PromptTemplate
+ from langchain.llms import HuggingFacePipeline
+ from langchain_community.document_loaders import (
+     PyPDFLoader,
+     Docx2txtLoader,
+     CSVLoader,
+     UnstructuredFileLoader
+ )
+ from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
+ import spaces
+ import tempfile
+
+ # Configure logging
+ logging.basicConfig(
+     level=logging.INFO,
+     format='%(asctime)s - %(levelname)s - %(message)s'
+ )
+ logger = logging.getLogger(__name__)
+
+ # Constants
+ MODEL_NAME = "meta-llama/Llama-2-7b-chat-hf"  # gated checkpoint; requires HF access approval
+ SUPPORTED_FORMATS = [".pdf", ".docx", ".doc", ".csv", ".txt"]
+
+ class DocumentLoader:
+     """Enhanced document loader supporting multiple file formats."""
+
+     @staticmethod
+     def load_file(file_path: str) -> List:
+         """Load a single file based on its extension."""
+         ext = os.path.splitext(file_path)[1].lower()
+         try:
+             if ext == '.pdf':
+                 loader = PyPDFLoader(file_path)
+             elif ext in ['.docx', '.doc']:
+                 loader = Docx2txtLoader(file_path)
+             elif ext == '.csv':
+                 loader = CSVLoader(file_path)
+             else:  # fallback for txt and other text files
+                 loader = UnstructuredFileLoader(file_path)
+
+             documents = loader.load()
+
+             # Add metadata
+             for doc in documents:
+                 doc.metadata.update({
+                     'title': os.path.basename(file_path),
+                     'type': 'document',
+                     'format': ext[1:],
+                     'language': 'auto'
+                 })
+
+             logger.info(f"Successfully loaded {file_path}")
+             return documents
+
+         except Exception as e:
+             logger.error(f"Error loading {file_path}: {str(e)}")
+             raise
+
+ class RAGSystem:
+     """Enhanced RAG system with dynamic document loading."""
+
+     def __init__(self, model_name: str = MODEL_NAME):
+         self.model_name = model_name
+         self.embeddings = None
+         self.vector_store = None
+         self.qa_chain = None
+         self.tokenizer = None
+         self.model = None
+         self.is_initialized = False
+
+     def initialize_model(self):
+         """Initialize the base model and tokenizer."""
+         try:
+             logger.info("Initializing language model...")
+
+             # Initialize embeddings
+             self.embeddings = HuggingFaceEmbeddings(
+                 model_name="intfloat/multilingual-e5-large",
+                 model_kwargs={'device': 'cuda' if torch.cuda.is_available() else 'cpu'},
+                 encode_kwargs={'normalize_embeddings': True}
+             )
+
+             # Initialize model and tokenizer
+             self.tokenizer = AutoTokenizer.from_pretrained(
+                 self.model_name,
+                 trust_remote_code=True
+             )
+
+             self.model = AutoModelForCausalLM.from_pretrained(
+                 self.model_name,
+                 torch_dtype=torch.float16,
+                 trust_remote_code=True,
+                 device_map="auto"
+             )
+
+             # Create generation pipeline; the model is already placed by
+             # device_map="auto" above, so device_map is not passed again here
+             pipe = pipeline(
+                 "text-generation",
+                 model=self.model,
+                 tokenizer=self.tokenizer,
+                 max_new_tokens=512,
+                 do_sample=True,  # required for temperature/top_p to take effect
+                 temperature=0.1,
+                 top_p=0.95,
+                 repetition_penalty=1.15
+             )
+
+             self.llm = HuggingFacePipeline(pipeline=pipe)
+             self.is_initialized = True
+
+             logger.info("Model initialization completed")
+
+         except Exception as e:
+             logger.error(f"Error during model initialization: {str(e)}")
+             raise
+
+     def process_documents(self, files: List[tempfile._TemporaryFileWrapper]) -> None:
+         """Process uploaded documents and update the vector store."""
+         try:
+             documents = []
+             for file in files:
+                 docs = DocumentLoader.load_file(file.name)
+                 documents.extend(docs)
+
+             if not documents:
+                 raise ValueError("No documents were successfully loaded.")
+
+             # Process documents
+             text_splitter = RecursiveCharacterTextSplitter(
+                 chunk_size=800,
+                 chunk_overlap=200,
+                 separators=["\n\n", "\n", ". ", " ", ""],
+                 length_function=len
+             )
+
+             chunks = text_splitter.split_documents(documents)
+
+             # Create or update vector store
+             if self.vector_store is None:
+                 self.vector_store = FAISS.from_documents(chunks, self.embeddings)
+             else:
+                 self.vector_store.add_documents(chunks)
+
+             # Initialize QA chain
+             prompt_template = """
+             Context: {context}
+
+             Based on the provided context, please answer the following question clearly and concisely.
+             If the information is not in the context, please say so explicitly.
+
+             Question: {question}
+             """
+
+             PROMPT = PromptTemplate(
+                 template=prompt_template,
+                 input_variables=["context", "question"]
+             )
+
+             self.qa_chain = RetrievalQA.from_chain_type(
+                 llm=self.llm,
+                 chain_type="stuff",
+                 retriever=self.vector_store.as_retriever(
+                     search_kwargs={"k": 6}
+                 ),
+                 return_source_documents=True,
+                 chain_type_kwargs={"prompt": PROMPT}
+             )
+
+             logger.info(f"Successfully processed {len(documents)} documents")
+
+         except Exception as e:
+             logger.error(f"Error processing documents: {str(e)}")
+             raise
+
+     def generate_response(self, question: str) -> Dict:
+         """Generate response for a given question."""
+         if not self.is_initialized or self.qa_chain is None:
+             return {
+                 'answer': "Please upload some documents first before asking questions.",
+                 'sources': []
+             }
+
+         try:
+             result = self.qa_chain({"query": question})
+
+             response = {
+                 'answer': result['result'],
+                 'sources': []
+             }
+
+             for doc in result['source_documents']:
+                 source = {
+                     'title': doc.metadata.get('title', 'Unknown'),
+                     'content': doc.page_content[:200] + "..." if len(doc.page_content) > 200 else doc.page_content,
+                     'metadata': doc.metadata
+                 }
+                 response['sources'].append(source)
+
+             return response
+
+         except Exception as e:
+             logger.error(f"Error generating response: {str(e)}")
+             raise
+
+ @spaces.GPU(duration=60)
+ def process_response(user_input: str, chat_history: List, files: List) -> List:
+     """Process user input and generate a response, returning the updated chat history."""
+     try:
+         if not rag_system.is_initialized:
+             rag_system.initialize_model()
+
+         # Only the first uploaded batch is indexed; later uploads are skipped
+         # once a vector store already exists
+         if files and (rag_system.vector_store is None):
+             rag_system.process_documents(files)
+
+         response = rag_system.generate_response(user_input)
+
+         # Clean and format response
+         answer = response['answer']
+         if "Answer:" in answer:
+             answer = answer.split("Answer:")[-1].strip()
+
+         # Format sources
+         sources = set([source['title'] for source in response['sources'][:3]])
+         if sources:
+             answer += "\n\n📚 Sources consulted:\n" + "\n".join([f"• {source}" for source in sources])
+
+         chat_history.append((user_input, answer))
+         return chat_history
+
+     except Exception as e:
+         logger.error(f"Error in process_response: {str(e)}")
+         error_message = f"Sorry, an error occurred: {str(e)}"
+         chat_history.append((user_input, error_message))
+         return chat_history
+
+ # Initialize RAG system
+ logger.info("Initializing RAG system...")
+ try:
+     rag_system = RAGSystem()
+     logger.info("RAG system created successfully")
+ except Exception as e:
+     logger.error(f"Failed to create RAG system: {str(e)}")
+     raise
+
+ # Create Gradio interface
+ try:
+     logger.info("Creating Gradio interface...")
+     with gr.Blocks(css="div.gradio-container {background-color: #f0f2f6}") as demo:
+         gr.HTML("""
+             <div style="text-align: center; max-width: 800px; margin: 0 auto; padding: 20px;">
+                 <h1 style="color: #2d333a;">📚 DocumentGPT</h1>
+                 <p style="color: #4a5568;">
+                     Your AI Assistant for Document Analysis and Q&A
+                 </p>
+             </div>
+         """)
+
+         with gr.Row():
+             with gr.Column(scale=1):
+                 files = gr.Files(
+                     label="Upload Your Documents",
+                     file_types=SUPPORTED_FORMATS,
+                     file_count="multiple"
+                 )
+                 gr.HTML("""
+                     <div style="font-size: 0.9em; color: #666; margin-top: 0.5em;">
+                         Supported formats: PDF, DOCX, CSV, TXT
+                     </div>
+                 """)
+
+             chatbot = gr.Chatbot(
+                 show_label=False,
+                 container=True,
+                 height=500,
+                 bubble_full_width=True,
+                 show_copy_button=True,
+                 scale=2
+             )
+
+         with gr.Row():
+             message = gr.Textbox(
+                 placeholder="💭 Ask me anything about your documents...",
+                 show_label=False,
+                 container=False,
+                 scale=8,
+                 autofocus=True
+             )
+             clear = gr.Button("🗑️ Clear", size="sm", scale=1)
+
+         # Instructions
+         gr.HTML("""
+             <div style="background-color: #f8f9fa; padding: 15px; border-radius: 10px; margin: 20px 0;">
+                 <h3 style="color: #2d333a; margin-bottom: 10px;">🔍 How to use:</h3>
+                 <ol style="color: #666; margin-left: 20px;">
+                     <li>Upload one or more documents (PDF, DOCX, CSV, or TXT)</li>
+                     <li>Wait for the documents to be processed</li>
+                     <li>Ask questions about your documents</li>
+                     <li>View sources used in the responses</li>
+                 </ol>
+             </div>
+         """)
+
+         # Footer with credits
+         gr.HTML("""
+             <div style="text-align: center; max-width: 800px; margin: 20px auto; padding: 20px;
+                         background-color: #f8f9fa; border-radius: 10px;">
+                 <div style="margin-bottom: 15px;">
+                     <h3 style="color: #2d333a;">⚡ About this assistant</h3>
+                     <p style="color: #666; font-size: 14px;">
+                         This application uses RAG (Retrieval Augmented Generation) technology combining:
+                     </p>
+                     <ul style="list-style: none; color: #666; font-size: 14px;">
+                         <li>🔹 LLM Engine: Llama-2-7b-chat-hf</li>
+                         <li>🔹 Embeddings: multilingual-e5-large</li>
+                         <li>🔹 Vector Store: FAISS</li>
+                     </ul>
+                 </div>
+                 <div style="border-top: 1px solid #ddd; padding-top: 15px;">
+                     <p style="color: #666; font-size: 14px;">
+                         Created by <a href="https://www.linkedin.com/in/camilo-vega-169084b1/"
+                         target="_blank" style="color: #2196F3; text-decoration: none;">Camilo Vega</a>,
+                         AI Professor and Solutions Consultant 🤖
+                     </p>
+                 </div>
+             </div>
+         """)
+
+         # Configure event handlers
+         def submit(user_input, chat_history, files):
+             return process_response(user_input, chat_history, files)
+
+         message.submit(submit, [message, chatbot, files], [chatbot])
+         clear.click(lambda: None, None, chatbot)
+
+     logger.info("Gradio interface created successfully")
+     demo.launch()
+
+ except Exception as e:
+     logger.error(f"Error in Gradio interface creation: {str(e)}")
+     raise
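
A quick way to sanity-check the pipeline is to drive RAGSystem directly, without the Gradio UI. The sketch below is illustrative and not part of the commit: it assumes the two classes above have been copied into a notebook or module of their own (importing app.py as-is would block on the module-level demo.launch()), that the gated meta-llama/Llama-2-7b-chat-hf weights are accessible to your Hugging Face token, and that Upload and "manual.pdf" are hypothetical stand-ins for Gradio's uploaded-file object and a real local document. It also presupposes the Space's stack is installed (transformers, langchain, langchain-community, faiss, gradio, spaces, plus pypdf/docx2txt/unstructured for the loaders).

from huggingface_hub import login

login(token="hf_...")  # placeholder token; the Llama-2 checkpoint is gated

rag = RAGSystem()       # the class defined above, copied into this module
rag.initialize_model()  # loads multilingual-e5 embeddings and the Llama-2 pipeline

class Upload:  # stand-in for Gradio's uploaded-file object,
    def __init__(self, name):  # which exposes the file path as .name
        self.name = name

rag.process_documents([Upload("manual.pdf")])  # hypothetical local file
result = rag.generate_response("What topics does the document cover?")
print(result["answer"])
for src in result["sources"]:
    print("-", src["title"])

This is the same flow process_response runs under the @spaces.GPU decorator; note that once a vector store exists, later uploads are skipped by the vector_store is None guard.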