import gradio as gr
import os
from dotenv import load_dotenv
from groq import Groq
import PyPDF2
import hashlib
import pickle
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
import logging

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Load environment variables from a .env file, if present
load_dotenv()

# Initialize Groq client (expects GROQ_API_KEY in the environment)
client = Groq(api_key=os.getenv("GROQ_API_KEY"))

# List of available models
MODELS = [
    "llama3-8b-8192",
    "gemma-7b-it",
    "gemma2-9b-it",
    "llama-3.1-70b-versatile",
    "llama-3.1-8b-instant",
    "llama-guard-3-8b",
    "llama3-70b-8192",
    "llama3-groq-70b-8192-tool-use-preview",
    "llama3-groq-8b-8192-tool-use-preview",
    "mixtral-8x7b-32768"
]
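
# Extract text from every page of an open PDF file object using PyPDF2.
# Returns an empty string on failure so downstream steps degrade gracefully.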
def process_pdf(file):
    try:
        pdf_reader = PyPDF2.PdfReader(file)
        text = ""
        for page in pdf_reader.pages:
            # extract_text() can return None for image-only pages
            text += (page.extract_text() or "") + "\n"
        return text
    except Exception as e:
        logging.error(f"Error processing PDF: {str(e)}")
        return ""
def split_into_chunks(text, chunk_size=1500, overlap=100):
    words = text.split()
    chunks = []
    for i in range(0, len(words), chunk_size - overlap):
        chunk = ' '.join(words[i:i + chunk_size])
        chunks.append(chunk)
    return chunks
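
# Cache chunking results on disk, keyed by the MD5 hash of the PDF bytes,
# so re-uploading the same document skips extraction and chunking.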
def get_or_create_chunks(file_path):
    if file_path is None:
        return []
    try:
        with open(file_path, 'rb') as file:
            file_hash = hashlib.md5(file.read()).hexdigest()
        cache_file = f"cache/{file_hash}_chunks.pkl"
        if os.path.exists(cache_file):
            with open(cache_file, 'rb') as f:
                return pickle.load(f)
        with open(file_path, 'rb') as file:
            text = process_pdf(file)
        chunks = split_into_chunks(text)
        os.makedirs('cache', exist_ok=True)
        with open(cache_file, 'wb') as f:
            pickle.dump(chunks, f)
        return chunks
    except Exception as e:
        logging.error(f"Error in get_or_create_chunks: {str(e)}")
        return []
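
# Rank chunks against the query using TF-IDF vectors and cosine similarity,
# returning the top_k most similar chunks to use as retrieval context.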
def find_most_relevant_chunks(query, chunks, top_k=2):
    # No chunks (e.g. an empty or unreadable PDF) means nothing to rank
    if not chunks:
        return []
    try:
        vectorizer = TfidfVectorizer().fit(chunks + [query])
        chunk_vectors = vectorizer.transform(chunks)
        query_vector = vectorizer.transform([query])
        similarities = cosine_similarity(query_vector, chunk_vectors)[0]
        top_indices = similarities.argsort()[-top_k:][::-1]
        return [chunks[i] for i in top_indices]
    except Exception as e:
        logging.error(f"Error in find_most_relevant_chunks: {str(e)}")
        return []
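
# Core handler: retrieve the most relevant chunks, prepend them to the
# user's question, and ask the selected Groq model for an answer. Returns
# the updated chat history plus an empty string to clear the input box.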
def chat_with_pdf(pdf_file, model, prompt, history):
    history = history or []
    if pdf_file is None:
        history.append((prompt, "Please upload a PDF file first."))
        return history, ""
    try:
        chunks = get_or_create_chunks(pdf_file.name)
        relevant_chunks = find_most_relevant_chunks(prompt, chunks)
        context = "\n\n".join(relevant_chunks)
        enhanced_prompt = f"Context: {context}\n\nQuestion: {prompt}\n\nAnswer:"
        chat_completion = client.chat.completions.create(
            messages=[
                {"role": "system", "content": "You are a helpful assistant for answering questions about the given PDF content."},
                {"role": "user", "content": enhanced_prompt}
            ],
            model=model,
            max_tokens=2048,
            temperature=0.7
        )
        response = chat_completion.choices[0].message.content
        history.append((prompt, response))
        return history, ""
    except Exception as e:
        logging.error(f"Error in chat_with_pdf: {str(e)}")
        error_message = "An error occurred while processing your request. Please try again."
        history.append((prompt, error_message))
        return history, ""
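
# Reset the file upload, the chat history, and the question textbox.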
def clear_history():
    return None, [], ""
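
# Gradio UI: file upload and model picker on the left, chat on the right.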
with gr.Blocks() as demo:
    gr.Markdown("# ChatWithPDF")
    with gr.Row():
        with gr.Column(scale=1):
            pdf_file = gr.File(label="Upload PDF", file_types=[".pdf"])
            model = gr.Dropdown(choices=MODELS, label="Select Model", value=MODELS[0])
            clear = gr.Button("Clear Chat History")
        with gr.Column(scale=2):
            chatbot = gr.Chatbot()
            msg = gr.Textbox(label="Ask a question about your PDF")
            submit = gr.Button("Submit")
    submit.click(chat_with_pdf,
                 inputs=[pdf_file, model, msg, chatbot],
                 outputs=[chatbot, msg])
    # Pressing Enter in the textbox triggers the same handler
    msg.submit(chat_with_pdf,
               inputs=[pdf_file, model, msg, chatbot],
               outputs=[chatbot, msg])
    clear.click(clear_history, outputs=[pdf_file, chatbot, msg])

if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True
    )