import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from sentence_transformers import SentenceTransformer
from PyPDF2 import PdfReader
import numpy as np
import torch
class RAGChatbot:
    def __init__(self,
                 model_name="facebook/opt-350m",
                 embedding_model="all-MiniLM-L6-v2"):
        # Initialize tokenizer and model
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
        # Optional 8-bit quantization (requires bitsandbytes); left disabled here:
        # self.bnb_config = BitsAndBytesConfig(
        #     load_in_8bit=True,        # Enable 8-bit loading
        #     llm_int8_threshold=6.0,   # Threshold for mixed-precision computation
        # )
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16,
            device_map="auto"
        )
        # Initialize embedding model
        self.embedding_model = SentenceTransformer(embedding_model)
        # Initialize document storage
        self.documents = []
        self.embeddings = []
    def extract_text_from_pdf(self, pdf_path):
        reader = PdfReader(pdf_path)
        text = ""
        for page in reader.pages:
            # extract_text() may return None for image-only pages
            text += (page.extract_text() or "") + "\n"
        return text
    def load_documents(self, file_paths):
        if not file_paths:
            return "No files uploaded"
        self.documents = []
        self.embeddings = []
        for file_path in file_paths:
            if file_path.endswith('.pdf'):
                text = self.extract_text_from_pdf(file_path)
            else:
                with open(file_path, 'r', encoding='utf-8') as f:
                    text = f.read()
            # Split text into fixed-size 500-character chunks (no overlap)
            chunks = [text[i:i + 500] for i in range(0, len(text), 500)]
            self.documents.extend(chunks)
        # Generate embeddings for all chunks at once
        self.embeddings = self.embedding_model.encode(self.documents)
        return f"Loaded {len(self.documents)} text chunks from {len(file_paths)} files"
    def retrieve_relevant_context(self, query, top_k=3):
        if not self.documents:
            return "No documents loaded"
        # Generate query embedding
        query_embedding = self.embedding_model.encode([query])[0]
        # Calculate cosine similarity between the query and every chunk
        similarities = np.dot(self.embeddings, query_embedding) / (
            np.linalg.norm(self.embeddings, axis=1) * np.linalg.norm(query_embedding)
        )
        # Get the top-k most similar chunks, best match first
        top_indices = similarities.argsort()[-top_k:][::-1]
        return " ".join([self.documents[i] for i in top_indices])
    def generate_response(self, query, context):
        # Construct the prompt, truncating the context to ~100 words to
        # leave room in the model's context window
        truncated_context = " ".join(context.split()[:100])
        full_prompt = f"Context: {truncated_context}\n\nQuestion: {query}\n\nAnswer:"
        # Generate response
        inputs = self.tokenizer(full_prompt, return_tensors="pt",
                                padding=True, truncation=True).to(self.model.device)
        outputs = self.model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_new_tokens=128,
            pad_token_id=self.tokenizer.eos_token_id,
            repetition_penalty=1.0
        )
        response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        # Return only the text after the final "Answer:" marker
        return response.split("Answer:")[-1].strip()
    def chat(self, query, history):
        if not query:
            return history, ""
        try:
            # Retrieve relevant context
            context = self.retrieve_relevant_context(query)
            # Generate response
            response = self.generate_response(query, context)
            # Append to history using the messages format
            updated_history = history + [
                {"role": "user", "content": query},
                {"role": "assistant", "content": response}
            ]
            return updated_history, ""
        except Exception as e:
            error_response = f"An error occurred: {str(e)}"
            updated_history = history + [
                {"role": "user", "content": query},
                {"role": "assistant", "content": error_response}
            ]
            return updated_history, ""
# Create Gradio interface
def create_interface():
    rag_chatbot = RAGChatbot()
    with gr.Blocks() as demo:
        gr.Markdown("# Ask your PDF!")
        with gr.Row():
            file_input = gr.File(label="Upload Documents", file_count="multiple", type="filepath")
            load_btn = gr.Button("Load Documents")
        status_output = gr.Textbox(label="Load Status")
        chatbot = gr.Chatbot(type="messages")  # Use the messages-style history format
        msg = gr.Textbox(label="Enter your query")
        submit_btn = gr.Button("Send")
        clear_btn = gr.Button("Clear Chat")
        # Event handlers
        load_btn.click(
            rag_chatbot.load_documents,
            inputs=[file_input],
            outputs=[status_output]
        )
        submit_btn.click(
            rag_chatbot.chat,
            inputs=[msg, chatbot],
            outputs=[chatbot, msg]
        )
        msg.submit(
            rag_chatbot.chat,
            inputs=[msg, chatbot],
            outputs=[chatbot, msg]
        )
        clear_btn.click(lambda: (None, ""), None, [chatbot, msg])
    return demo
# Launch the app
if __name__ == "__main__":
    demo = create_interface()
    demo.launch()
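# Minimal sketch of programmatic use without the UI (hypothetical file path):
#   bot = RAGChatbot()
#   print(bot.load_documents(["notes.pdf"]))
#   ctx = bot.retrieve_relevant_context("What is this document about?")
#   print(bot.generate_response("What is this document about?", ctx))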