import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from sentence_transformers import SentenceTransformer
from PyPDF2 import PdfReader
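# Note: PyPDF2 is in maintenance mode; its maintained successor is `pypdf`,
# whose PdfReader should be a drop-in replacement here.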
import numpy as np
import torch


class RAGChatbot:
    def __init__(self,
                 model_name="facebook/opt-350m",
                 embedding_model="all-MiniLM-L6-v2"):
        # Initialize tokenizer and model
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
        # Optional 8-bit quantization, left disabled (requires bitsandbytes);
        # this is why BitsAndBytesConfig is imported above:
        # self.bnb_config = BitsAndBytesConfig(
        #     load_in_8bit=True,       # Enable 8-bit loading
        #     llm_int8_threshold=6.0,  # Threshold for mixed-precision computation
        # )
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16,
            device_map="auto"
        )
        # Initialize embedding model
        self.embedding_model = SentenceTransformer(embedding_model)
        # Initialize document storage
        self.documents = []
        self.embeddings = []

    def extract_text_from_pdf(self, pdf_path):
        reader = PdfReader(pdf_path)
        text = ""
        for page in reader.pages:
            # extract_text() returns None for image-only pages; guard against it
            text += (page.extract_text() or "") + "\n"
        return text

    def load_documents(self, file_paths):
        self.documents = []
        self.embeddings = []
        for file_path in file_paths:
            if file_path.endswith('.pdf'):
                text = self.extract_text_from_pdf(file_path)
            else:
                with open(file_path, 'r', encoding='utf-8') as f:
                    text = f.read()
            # Split text into fixed-size 500-character chunks
            chunks = [text[i:i + 500] for i in range(0, len(text), 500)]
            self.documents.extend(chunks)
        # Generate embeddings for all chunks in a single batch
        self.embeddings = self.embedding_model.encode(self.documents)
        return f"Loaded {len(self.documents)} text chunks from {len(file_paths)} files"

    def retrieve_relevant_context(self, query, top_k=3):
        if not self.documents:
            return "No documents loaded"
        # Generate query embedding
        query_embedding = self.embedding_model.encode([query])[0]
        # Calculate cosine similarities
        similarities = np.dot(self.embeddings, query_embedding) / (
            np.linalg.norm(self.embeddings, axis=1) * np.linalg.norm(query_embedding)
        )
        # Get the top-k most similar chunks, best match first
        top_indices = similarities.argsort()[-top_k:][::-1]
        return " ".join([self.documents[i] for i in top_indices])

    def generate_response(self, query, context):
        # Construct the prompt, truncating the context to the first 100 words
        # so it fits comfortably in the model's context window
        truncated_context = " ".join(context.split()[:100])
        full_prompt = f"Context: {truncated_context}\n\nQuestion: {query}\n\nAnswer:"
        # Generate response
        tokens = self.tokenizer(
            full_prompt, return_tensors="pt", padding=True, truncation=True
        ).to(self.model.device)
        outputs = self.model.generate(
            tokens.input_ids,
            attention_mask=tokens.attention_mask,
            max_new_tokens=128,
            pad_token_id=self.tokenizer.eos_token_id,
            repetition_penalty=1.0,
        )
        response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        # The decoded text echoes the prompt; keep only what follows "Answer:"
        return response.split("Answer:")[-1].strip()
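
    # Note: with these arguments generate() uses greedy decoding; passing
    # do_sample=True (optionally with a temperature) would trade determinism
    # for more varied answers.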

    def chat(self, query, history):
        if not query:
            return history, ""
        try:
            # Retrieve relevant context
            context = self.retrieve_relevant_context(query)
            # Generate response
            response = self.generate_response(query, context)
            # Append to history using the messages format
            updated_history = history + [
                {"role": "user", "content": query},
                {"role": "assistant", "content": response}
            ]
            return updated_history, ""
        except Exception as e:
            error_response = f"An error occurred: {str(e)}"
            updated_history = history + [
                {"role": "user", "content": query},
                {"role": "assistant", "content": error_response}
            ]
            return updated_history, ""


# Create Gradio interface
def create_interface():
    rag_chatbot = RAGChatbot()
    with gr.Blocks() as demo:
        gr.Markdown("# Ask your PDF!")
        with gr.Row():
            file_input = gr.File(label="Upload Documents", file_count="multiple", type="filepath")
            load_btn = gr.Button("Load Documents")
        status_output = gr.Textbox(label="Load Status")
        chatbot = gr.Chatbot(type="messages")  # Use the messages history format
        msg = gr.Textbox(label="Enter your query")
        submit_btn = gr.Button("Send")
        clear_btn = gr.Button("Clear Chat")
        # Event handlers
        load_btn.click(
            rag_chatbot.load_documents,
            inputs=[file_input],
            outputs=[status_output]
        )
        submit_btn.click(
            rag_chatbot.chat,
            inputs=[msg, chatbot],
            outputs=[chatbot, msg]
        )
        msg.submit(
            rag_chatbot.chat,
            inputs=[msg, chatbot],
            outputs=[chatbot, msg]
        )
        clear_btn.click(lambda: (None, ""), None, [chatbot, msg])
    return demo


# Launch the app
if __name__ == "__main__":
    demo = create_interface()
    demo.launch()
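
# Programmatic use without the UI, e.g. for a quick test (hypothetical file
# path):
#   bot = RAGChatbot()
#   bot.load_documents(["notes.pdf"])
#   history, _ = bot.chat("Summarize the document.", [])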