# Fislac_Bot / app.py
# CamiloVega: "Update app.py" (commit e5afc54, verified)
import os
import logging
from typing import List, Dict
import torch
import gradio as gr
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from langchain.llms import HuggingFacePipeline
from langchain_community.document_loaders import PyPDFLoader
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
import spaces
# Logging setup: INFO and above, each record prefixed with timestamp and level.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Hugging Face access token: HUGGINGFACE_TOKEN wins, HF_TOKEN is the fallback.
# Importing this module fails when neither variable holds a non-empty value.
hf_token = os.getenv('HUGGINGFACE_TOKEN') or os.getenv('HF_TOKEN')
if not hf_token:
    logger.error("No Hugging Face token found in environment variables")
    logger.error("Please set either HUGGINGFACE_TOKEN or HF_TOKEN in your Space settings")
    raise ValueError(
        "Missing Hugging Face token. "
        "Please configure it in the Space settings under Repository Secrets."
    )

# Constants
MODEL_NAME = "meta-llama/Llama-2-7b-chat-hf"  # causal LM used for answer generation
KNOWLEDGE_BASE_DIR = "."  # directory scanned for knowledge-base PDFs
class DocumentLoader:
    """Class to manage PDF document loading."""

    @staticmethod
    def load_pdfs(directory_path: str) -> List:
        """Load the knowledge-base PDFs found directly inside a directory.

        A file is selected when its name ends in ``.pdf`` and contains
        ``valencia`` or ``fislac``; both checks ignore case, matching the
        case-insensitive rule used to classify the document type below.
        A file that fails to load is logged and skipped so that one bad PDF
        does not prevent the others from loading.

        Args:
            directory_path: Directory whose entries are scanned (not recursive).

        Returns:
            The documents produced by ``PyPDFLoader`` for every selected file,
            each tagged with ``title``, ``type``, ``language`` and ``page``
            metadata. Empty when nothing matched or every file failed.
        """
        documents = []

        # os.listdir() returns entries in arbitrary order; sorting keeps the
        # document order (and therefore the index contents order) reproducible.
        pdf_files = []
        for name in sorted(os.listdir(directory_path)):
            lowered = name.lower()
            if lowered.endswith('.pdf') and ('valencia' in lowered or 'fislac' in lowered):
                pdf_files.append(name)

        if not pdf_files:
            logger.warning(f"No matching PDF files found in {directory_path}")
            return documents

        for pdf_file in pdf_files:
            pdf_path = os.path.join(directory_path, pdf_file)
            # "valencia" files are treated as technical material, the rest as Q&A.
            doc_type = 'technical' if 'valencia' in pdf_file.lower() else 'qa'
            try:
                pdf_documents = PyPDFLoader(pdf_path).load()
                for doc in pdf_documents:
                    doc.metadata.update({
                        'title': pdf_file,
                        'type': doc_type,
                        'language': 'en',
                        # Keep the loader's page number; default to 0 if absent.
                        'page': doc.metadata.get('page', 0),
                    })
                documents.extend(pdf_documents)
                logger.info(f"Document {pdf_file} loaded successfully")
            except Exception as e:
                # Best effort: report the failure and continue with the next file.
                logger.error(f"Error loading {pdf_file}: {str(e)}")

        return documents
class TextProcessor:
    """Split loaded documents into chunks.

    Two splitters are kept: a larger one for documents whose ``type``
    metadata is ``technical`` and a smaller one for all other documents.
    """

    def __init__(self):
        def build_splitter(size, overlap):
            # Same separator cascade for both splitters; length is measured
            # with len(), i.e. in characters.
            return RecursiveCharacterTextSplitter(
                chunk_size=size,
                chunk_overlap=overlap,
                separators=["\n\n", "\n", ". ", " ", ""],
                length_function=len,
            )

        self.technical_splitter = build_splitter(800, 200)
        self.qa_splitter = build_splitter(500, 100)

    def process_documents(self, documents: List) -> List:
        """Split every document with the splitter that matches its type.

        Args:
            documents: Documents carrying a ``type`` entry in their metadata.

        Returns:
            All resulting chunks, in input order; an empty list when
            ``documents`` is empty.
        """
        if not documents:
            logger.warning("No documents to process")
            return []

        processed_chunks = []
        for document in documents:
            if document.metadata['type'] == 'technical':
                chosen_splitter = self.technical_splitter
            else:
                chosen_splitter = self.qa_splitter
            processed_chunks += chosen_splitter.split_documents([document])

        logger.info(f"Documents processed into {len(processed_chunks)} chunks")
        return processed_chunks
class RAGSystem:
    """Main RAG system class.

    ``initialize_system`` builds the embeddings, the FAISS vector store, the
    language model and the retrieval QA chain. ``generate_response`` answers
    questions with that chain and must only be called after initialization.
    """

    def __init__(self, model_name: str = MODEL_NAME):
        """Record the model name; components are built by ``initialize_system``.

        Args:
            model_name: Model identifier passed to the tokenizer/model loaders.
        """
        self.model_name = model_name
        # All of the following are populated by initialize_system().
        self.embeddings = None
        self.vector_store = None
        self.qa_chain = None
        self.tokenizer = None
        self.model = None

    def initialize_system(self):
        """Initialize complete RAG system.

        Raises:
            ValueError: If no documents were loaded or no chunks were created.
            Exception: Any other failure is logged and re-raised unchanged.
        """
        try:
            logger.info("Starting RAG system initialization...")
            processed_chunks = self._load_chunks()
            self._build_vector_store(processed_chunks)
            llm = self._build_llm()
            self._build_qa_chain(llm)
            logger.info("RAG system initialized successfully")
        except Exception as e:
            logger.error(f"Error during RAG system initialization: {str(e)}")
            raise

    @staticmethod
    def _load_chunks() -> List:
        """Load the knowledge-base PDFs and split them into chunks."""
        documents = DocumentLoader.load_pdfs(KNOWLEDGE_BASE_DIR)
        if not documents:
            raise ValueError("No documents were loaded. Please check the PDF files in the root directory.")

        processed_chunks = TextProcessor().process_documents(documents)
        if not processed_chunks:
            raise ValueError("No chunks were created from the documents.")
        return processed_chunks

    def _build_vector_store(self, processed_chunks: List) -> None:
        """Create the embedding model and index the chunks in FAISS."""
        logger.info("Initializing embeddings...")
        self.embeddings = HuggingFaceEmbeddings(
            model_name="intfloat/multilingual-e5-large",
            model_kwargs={'device': 'cuda' if torch.cuda.is_available() else 'cpu'},
            encode_kwargs={'normalize_embeddings': True},
        )

        logger.info("Creating vector store...")
        self.vector_store = FAISS.from_documents(processed_chunks, self.embeddings)

    def _build_llm(self):
        """Load tokenizer and model and wrap them in a LangChain LLM."""
        logger.info("Initializing language model...")
        # NOTE(review): trust_remote_code=True lets the model repository run
        # its own code on load; confirm it is actually required here.
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model_name,
            token=hf_token,
            trust_remote_code=True,
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_name,
            token=hf_token,
            torch_dtype=torch.float16,
            trust_remote_code=True,
            device_map="auto",
        )

        logger.info("Creating generation pipeline...")
        # NOTE(review): the model above is already loaded with
        # device_map="auto"; confirm repeating it here is needed.
        pipe = pipeline(
            "text-generation",
            model=self.model,
            tokenizer=self.tokenizer,
            max_new_tokens=512,
            temperature=0.1,
            top_p=0.95,
            repetition_penalty=1.15,
            device_map="auto",
        )
        return HuggingFacePipeline(pipeline=pipe)

    def _build_qa_chain(self, llm) -> None:
        """Create the prompt and the retrieval QA chain around ``llm``."""
        # Written as explicit lines so the prompt text does not depend on
        # source-code indentation.
        prompt_template = (
            "\n"
            "Context: {context}\n"
            "Based on the context above, please provide a clear and concise answer to the following question.\n"
            "If the information is not in the context, explicitly state so.\n"
            "Question: {question}\n"
        )
        prompt = PromptTemplate(
            template=prompt_template,
            input_variables=["context", "question"],
        )

        logger.info("Setting up QA chain...")
        self.qa_chain = RetrievalQA.from_chain_type(
            llm=llm,
            chain_type="stuff",
            retriever=self.vector_store.as_retriever(
                search_kwargs={"k": 6}
            ),
            return_source_documents=True,
            chain_type_kwargs={"prompt": prompt},
        )

    def generate_response(self, question: str) -> Dict:
        """Generate response for a given question.

        Args:
            question: The user's question, passed to the chain as ``query``.

        Returns:
            A dict with ``answer`` (the chain's ``result``) and ``sources``, a
            list with one dict per source document holding its ``title``, a
            ``content`` excerpt of at most 200 characters (followed by
            ``...`` when truncated) and the full ``metadata``.

        Raises:
            RuntimeError: If called before ``initialize_system`` succeeded.
            Exception: Any chain failure is logged and re-raised unchanged.
        """
        try:
            if self.qa_chain is None:
                # Without this guard the call below fails with an opaque
                # "'NoneType' object is not callable".
                raise RuntimeError(
                    "RAG system is not initialized; call initialize_system() first."
                )

            result = self.qa_chain({"query": question})

            sources = []
            for doc in result['source_documents']:
                text = doc.page_content
                sources.append({
                    'title': doc.metadata.get('title', 'Unknown'),
                    'content': (text[:200] + "...") if len(text) > 200 else text,
                    'metadata': doc.metadata,
                })

            return {'answer': result['result'], 'sources': sources}
        except Exception as e:
            logger.error(f"Error generating response: {str(e)}")
            raise
@spaces.GPU(duration=60)
def process_response(user_input: str, chat_history: List) -> List:
    """Process user input and generate response.

    The answer from the RAG system is trimmed to the text after the last
    ``Answer:`` marker (when present) and followed by the distinct titles of
    the first three sources. Errors are not raised; they are logged and shown
    to the user as the bot's reply.

    Args:
        user_input: The question typed or selected by the user.
        chat_history: List of ``(user, bot)`` pairs; it is extended in place.
            ``None`` is treated as an empty history.

    Returns:
        The chat history including the new ``(user_input, reply)`` pair.
    """
    if chat_history is None:
        chat_history = []

    try:
        response = rag_system.generate_response(user_input)

        # Clean and format response
        reply = response['answer']
        if "Answer:" in reply:
            reply = reply.split("Answer:")[-1].strip()

        # De-duplicate titles while keeping retrieval order; a set would list
        # them in an arbitrary order that can change between runs.
        titles = list(dict.fromkeys(source['title'] for source in response['sources'][:3]))
        if titles:
            reply += "\n\n📚 Sources consulted:\n" + "\n".join(f"• {title}" for title in titles)
    except Exception as e:
        logger.error(f"Error in process_response: {str(e)}")
        reply = f"Sorry, an error occurred: {str(e)}"

    chat_history.append((user_input, reply))
    return chat_history
# Build the module-level RAG system used by process_response(); any failure
# is logged and re-raised so the application does not start half-initialized.
logger.info("Initializing RAG system...")
try:
    rag_system = RAGSystem()
    rag_system.initialize_system()
except Exception as init_error:
    logger.error(f"Failed to initialize RAG system: {init_error}")
    raise
else:
    logger.info("RAG system initialization completed")
# Create Gradio interface
# Builds the chat UI (header, chatbot, input row, suggested questions, footer),
# wires the event handlers and launches the app. Any failure is logged and
# re-raised. The emoji and bullet characters in the labels/HTML below were
# mis-encoded (mojibake) and have been restored to the intended characters.
try:
    logger.info("Creating Gradio interface...")
    with gr.Blocks(css="div.gradio-container {background-color: #f0f2f6}") as demo:
        # Header
        gr.HTML("""
<div style="text-align: center; max-width: 800px; margin: 0 auto; padding: 20px;">
<h1 style="color: #2d333a;">📊 FislacBot</h1>
<p style="color: #4a5568;">
AI Assistant specialized in fiscal analysis and FISLAC documentation
</p>
</div>
""")

        chatbot = gr.Chatbot(
            show_label=False,
            container=True,
            height=500,
            bubble_full_width=True,
            show_copy_button=True,
            scale=2
        )

        with gr.Row():
            message = gr.Textbox(
                placeholder="💭 Type your question here...",
                show_label=False,
                container=False,
                scale=8,
                autofocus=True
            )
            clear = gr.Button("🗑️ Clear", size="sm", scale=1)

        # Suggested questions
        gr.HTML('<p style="color: #2d333a; font-weight: bold; margin: 20px 0 10px 0;">💡 Suggested questions:</p>')
        with gr.Row():
            suggestion1 = gr.Button("What is FISLAC?", scale=1)
            suggestion2 = gr.Button("What are the main modules of FISLAC?", scale=1)
        with gr.Row():
            suggestion3 = gr.Button("What macroeconomic variables are relevant for advanced economies?", scale=1)
            suggestion4 = gr.Button("How does fiscal risk compare between emerging and advanced countries?", scale=1)

        # Footer
        # NOTE(review): the heading icon was restored as 🔍; the mis-encoded
        # original had lost one byte, so confirm this is the intended emoji.
        gr.HTML("""
<div style="text-align: center; max-width: 800px; margin: 20px auto; padding: 20px;
background-color: #f8f9fa; border-radius: 10px;">
<div style="margin-bottom: 15px;">
<h3 style="color: #2d333a;">🔍 About this assistant</h3>
<p style="color: #666; font-size: 14px;">
This bot uses RAG (Retrieval Augmented Generation) technology combining:
</p>
<ul style="list-style: none; color: #666; font-size: 14px;">
<li>🔹 LLM Engine: Llama-2-7b-chat-hf</li>
<li>🔹 Embeddings: multilingual-e5-large</li>
<li>🔹 Vector Store: FAISS</li>
</ul>
</div>
<div style="border-top: 1px solid #ddd; padding-top: 15px;">
<p style="color: #666; font-size: 14px;">
<strong>Current Knowledge Base:</strong><br>
• Valencia et al. (2022) - "Assessing macro-fiscal risk for Latin American and Caribbean countries"<br>
• FISLAC Technical Documentation
</p>
</div>
<div style="border-top: 1px solid #ddd; margin-top: 15px; padding-top: 15px;">
<p style="color: #666; font-size: 14px;">
Created by <a href="https://www.linkedin.com/in/camilo-vega-169084b1/"
target="_blank" style="color: #2196F3; text-decoration: none;">Camilo Vega</a>,
AI Consultant 🤖
</p>
</div>
</div>
""")

        # Configure event handlers
        def submit(user_input, chat_history):
            """Forward a question and the current history to process_response."""
            return process_response(user_input, chat_history)

        message.submit(submit, [message, chatbot], [chatbot])
        # Returning None to the chatbot output empties the conversation.
        clear.click(lambda: None, None, chatbot)

        # Handle suggested questions: each button is passed as the first input,
        # so its own label is submitted as the question.
        for btn in [suggestion1, suggestion2, suggestion3, suggestion4]:
            btn.click(submit, [btn, chatbot], [chatbot])

    logger.info("Gradio interface created successfully")
    demo.launch()
except Exception as e:
    logger.error(f"Error in Gradio interface creation: {str(e)}")
    raise