import os
import logging
from typing import List, Dict
import torch
import gradio as gr
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from langchain.llms import HuggingFacePipeline
from langchain_community.document_loaders import PyPDFLoader
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
import spaces
# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
# Get the Hugging Face token from environment variables
hf_token = os.environ.get('HUGGINGFACE_TOKEN') or os.environ.get('HF_TOKEN')
if not hf_token:
    logger.error("No Hugging Face token found in environment variables")
    logger.error("Please set either HUGGINGFACE_TOKEN or HF_TOKEN in your Space settings")
    raise ValueError("Missing Hugging Face token. Please configure it in the Space settings under Repository Secrets.")
# Constants
MODEL_NAME = "meta-llama/Llama-2-7b-chat-hf"
KNOWLEDGE_BASE_DIR = "."
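# The knowledge base is expected to sit in the Space's root directory: DocumentLoader
# below only picks up PDF files whose names mention "valencia" or "fislac".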
class DocumentLoader:
    """Class to manage PDF document loading."""

    @staticmethod
    def load_pdfs(directory_path: str) -> List:
        """Load the knowledge-base PDFs from directory_path and tag each page with metadata."""
        documents = []
        pdf_files = [
            f for f in os.listdir(directory_path)
            if f.endswith('.pdf') and
            ('valencia' in f.lower() or 'fislac' in f.lower())
        ]
        if not pdf_files:
            logger.warning(f"No matching PDF files found in {directory_path}")
            return documents

        for pdf_file in pdf_files:
            pdf_path = os.path.join(directory_path, pdf_file)
            try:
                loader = PyPDFLoader(pdf_path)
                pdf_documents = loader.load()
                for doc in pdf_documents:
                    doc.metadata.update({
                        'title': pdf_file,
                        'type': 'technical' if 'valencia' in pdf_file.lower() else 'qa',
                        'language': 'en',
                        'page': doc.metadata.get('page', 0)
                    })
                    documents.append(doc)
                logger.info(f"Document {pdf_file} loaded successfully")
            except Exception as e:
                logger.error(f"Error loading {pdf_file}: {str(e)}")
        return documents
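# Each loaded page carries metadata of the following shape (values illustrative):
#   {'title': 'valencia_2022.pdf', 'type': 'technical', 'language': 'en', 'page': 3}
# The 'type' field decides which splitter TextProcessor applies below.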
class TextProcessor:
    """Class to process and split text into chunks."""

    def __init__(self):
        self.technical_splitter = RecursiveCharacterTextSplitter(
            chunk_size=800,
            chunk_overlap=200,
            separators=["\n\n", "\n", ". ", " ", ""],
            length_function=len
        )
        self.qa_splitter = RecursiveCharacterTextSplitter(
            chunk_size=500,
            chunk_overlap=100,
            separators=["\n\n", "\n", ". ", " ", ""],
            length_function=len
        )

    def process_documents(self, documents: List) -> List:
        if not documents:
            logger.warning("No documents to process")
            return []
        processed_chunks = []
        for doc in documents:
            splitter = self.technical_splitter if doc.metadata['type'] == 'technical' else self.qa_splitter
            chunks = splitter.split_documents([doc])
            processed_chunks.extend(chunks)
        logger.info(f"Documents processed into {len(processed_chunks)} chunks")
        return processed_chunks
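# Chunking choices above: 'technical' documents (the Valencia et al. paper) are split into
# 800-character chunks with 200 characters of overlap, while QA-style documentation uses
# 500/100, presumably so retrieval keeps longer methodological passages intact.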
class RAGSystem:
    """Main RAG system class."""

    def __init__(self, model_name: str = MODEL_NAME):
        self.model_name = model_name
        self.embeddings = None
        self.vector_store = None
        self.qa_chain = None
        self.tokenizer = None
        self.model = None

    def initialize_system(self):
        """Initialize the complete RAG system."""
        try:
            logger.info("Starting RAG system initialization...")

            # Load and process documents
            loader = DocumentLoader()
            documents = loader.load_pdfs(KNOWLEDGE_BASE_DIR)
            if not documents:
                raise ValueError("No documents were loaded. Please check the PDF files in the root directory.")

            processor = TextProcessor()
            processed_chunks = processor.process_documents(documents)
            if not processed_chunks:
                raise ValueError("No chunks were created from the documents.")

            # Initialize embeddings
            logger.info("Initializing embeddings...")
            self.embeddings = HuggingFaceEmbeddings(
                model_name="intfloat/multilingual-e5-large",
                model_kwargs={'device': 'cuda' if torch.cuda.is_available() else 'cpu'},
                encode_kwargs={'normalize_embeddings': True}
            )

            # Create vector store
            logger.info("Creating vector store...")
            self.vector_store = FAISS.from_documents(
                processed_chunks,
                self.embeddings
            )

            # Initialize LLM
            logger.info("Initializing language model...")
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_name,
                token=hf_token,
                trust_remote_code=True
            )
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                token=hf_token,
                torch_dtype=torch.float16,
                trust_remote_code=True,
                device_map="auto"
            )

            # Create generation pipeline. do_sample=True is needed for temperature and top_p
            # to take effect; device_map is not repeated here because the model above is
            # already dispatched with device_map="auto".
            logger.info("Creating generation pipeline...")
            pipe = pipeline(
                "text-generation",
                model=self.model,
                tokenizer=self.tokenizer,
                max_new_tokens=512,
                do_sample=True,
                temperature=0.1,
                top_p=0.95,
                repetition_penalty=1.15
            )
            llm = HuggingFacePipeline(pipeline=pipe)

            # Create prompt template
            prompt_template = """
            Context: {context}

            Based on the context above, please provide a clear and concise answer to the following question.
            If the information is not in the context, explicitly state so.

            Question: {question}
            """
            PROMPT = PromptTemplate(
                template=prompt_template,
                input_variables=["context", "question"]
            )

            # Set up QA chain ("stuff" packs all retrieved chunks into a single prompt)
            logger.info("Setting up QA chain...")
            self.qa_chain = RetrievalQA.from_chain_type(
                llm=llm,
                chain_type="stuff",
                retriever=self.vector_store.as_retriever(
                    search_kwargs={"k": 6}
                ),
                return_source_documents=True,
                chain_type_kwargs={"prompt": PROMPT}
            )
            logger.info("RAG system initialized successfully")
        except Exception as e:
            logger.error(f"Error during RAG system initialization: {str(e)}")
            raise
    def generate_response(self, question: str) -> Dict:
        """Generate a response for a given question."""
        try:
            result = self.qa_chain({"query": question})
            response = {
                'answer': result['result'],
                'sources': []
            }
            for doc in result['source_documents']:
                source = {
                    'title': doc.metadata.get('title', 'Unknown'),
                    'content': doc.page_content[:200] + "..." if len(doc.page_content) > 200 else doc.page_content,
                    'metadata': doc.metadata
                }
                response['sources'].append(source)
            return response
        except Exception as e:
            logger.error(f"Error generating response: {str(e)}")
            raise
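# generate_response returns a dict of the form (values illustrative):
#   {'answer': '...generated text...',
#    'sources': [{'title': 'some_file.pdf', 'content': 'first 200 chars...', 'metadata': {...}}, ...]}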
def process_response(user_input: str, chat_history: List) -> List:
    """Process user input and append the (question, answer) pair to the chat history."""
    try:
        response = rag_system.generate_response(user_input)

        # Clean and format the answer
        answer = response['answer']
        if "Answer:" in answer:
            answer = answer.split("Answer:")[-1].strip()

        # List up to three distinct source documents
        sources = set([source['title'] for source in response['sources'][:3]])
        if sources:
            answer += "\n\n📚 Sources consulted:\n" + "\n".join([f"• {source}" for source in sources])

        chat_history.append((user_input, answer))
        return chat_history
    except Exception as e:
        logger.error(f"Error in process_response: {str(e)}")
        error_message = f"Sorry, an error occurred: {str(e)}"
        chat_history.append((user_input, error_message))
        return chat_history
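# chat_history is the list of (user_message, bot_message) tuples rendered by the
# gr.Chatbot component below in its tuple format; process_response appends one pair per turn.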
# Initialize RAG system
logger.info("Initializing RAG system...")
try:
    rag_system = RAGSystem()
    rag_system.initialize_system()
    logger.info("RAG system initialization completed")
except Exception as e:
    logger.error(f"Failed to initialize RAG system: {str(e)}")
    raise
# Create Gradio interface
try:
    logger.info("Creating Gradio interface...")
    with gr.Blocks(css="div.gradio-container {background-color: #f0f2f6}") as demo:
        gr.HTML("""
        <div style="text-align: center; max-width: 800px; margin: 0 auto; padding: 20px;">
            <h1 style="color: #2d333a;">FislacBot</h1>
            <p style="color: #4a5568;">
                AI Assistant specialized in fiscal analysis and FISLAC documentation
            </p>
        </div>
        """)

        chatbot = gr.Chatbot(
            show_label=False,
            container=True,
            height=500,
            bubble_full_width=True,
            show_copy_button=True,
            scale=2
        )

        with gr.Row():
            message = gr.Textbox(
                placeholder="Type your question here...",
                show_label=False,
                container=False,
                scale=8,
                autofocus=True
            )
            clear = gr.Button("🗑️ Clear", size="sm", scale=1)

        # Suggested questions
        gr.HTML('<p style="color: #2d333a; font-weight: bold; margin: 20px 0 10px 0;">💡 Suggested questions:</p>')
        with gr.Row():
            suggestion1 = gr.Button("What is FISLAC?", scale=1)
            suggestion2 = gr.Button("What are the main modules of FISLAC?", scale=1)
        with gr.Row():
            suggestion3 = gr.Button("What macroeconomic variables are relevant for advanced economies?", scale=1)
            suggestion4 = gr.Button("How does fiscal risk compare between emerging and advanced countries?", scale=1)

        # Footer
        gr.HTML("""
        <div style="text-align: center; max-width: 800px; margin: 20px auto; padding: 20px;
                    background-color: #f8f9fa; border-radius: 10px;">
            <div style="margin-bottom: 15px;">
                <h3 style="color: #2d333a;">About this assistant</h3>
                <p style="color: #666; font-size: 14px;">
                    This bot uses RAG (Retrieval-Augmented Generation) technology combining:
                </p>
                <ul style="list-style: none; color: #666; font-size: 14px;">
                    <li>🔹 LLM engine: Llama-2-7b-chat-hf</li>
                    <li>🔹 Embeddings: multilingual-e5-large</li>
                    <li>🔹 Vector store: FAISS</li>
                </ul>
            </div>
            <div style="border-top: 1px solid #ddd; padding-top: 15px;">
                <p style="color: #666; font-size: 14px;">
                    <strong>Current Knowledge Base:</strong><br>
                    • Valencia et al. (2022) - "Assessing macro-fiscal risk for Latin American and Caribbean countries"<br>
                    • FISLAC Technical Documentation
                </p>
            </div>
            <div style="border-top: 1px solid #ddd; margin-top: 15px; padding-top: 15px;">
                <p style="color: #666; font-size: 14px;">
                    Created by <a href="https://www.linkedin.com/in/camilo-vega-169084b1/"
                    target="_blank" style="color: #2196F3; text-decoration: none;">Camilo Vega</a>,
                    AI Consultant 🤖
                </p>
            </div>
        </div>
        """)

        # Configure event handlers
        def submit(user_input, chat_history):
            return process_response(user_input, chat_history)

        message.submit(submit, [message, chatbot], [chatbot])
        clear.click(lambda: None, None, chatbot)

        # Suggested-question buttons: a gr.Button used as an input passes its label text,
        # so clicking a button submits that question to the chat.
        for btn in [suggestion1, suggestion2, suggestion3, suggestion4]:
            btn.click(submit, [btn, chatbot], [chatbot])

    logger.info("Gradio interface created successfully")
    demo.launch()
except Exception as e:
    logger.error(f"Error in Gradio interface creation: {str(e)}")
    raise