File size: 4,782 Bytes
5a0c0b3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
import gradio as gr
import os
from dotenv import load_dotenv
from groq import Groq
import PyPDF2
import hashlib
import pickle
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
import logging

# Configure the root logger once at import time: INFO and above, with a
# timestamp and level prefix on every record.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Load environment variables via python-dotenv. This must run before the
# os.getenv() call below so that a key supplied through dotenv is visible.
load_dotenv()

# Module-level Groq client shared by every request handler. os.getenv()
# returns None when GROQ_API_KEY is unset.
# NOTE(review): how Groq() behaves with api_key=None is not visible here — confirm.
client = Groq(api_key=os.getenv("GROQ_API_KEY"))

# Model identifiers offered in the "Select Model" dropdown. The first entry
# is used as the dropdown's default value (MODELS[0]).
MODELS = [
    "llama3-8b-8192",
    "gemma-7b-it",
    "gemma2-9b-it",
    "llama-3.1-70b-versatile",
    "llama-3.1-8b-instant",
    "llama-guard-3-8b",
    "llama3-70b-8192",
    "llama3-groq-70b-8192-tool-use-preview",
    "llama3-groq-8b-8192-tool-use-preview",
    "mixtral-8x7b-32768"
]

def process_pdf(file):
    """Extract the text of every page of a PDF.

    Args:
        file: The PDF to read, passed straight to ``PyPDF2.PdfReader``
            (the caller in this file supplies a file opened in ``'rb'`` mode).

    Returns:
        The text of all pages in page order, each page followed by a
        newline. Returns ``""`` if the PDF cannot be read or extraction
        fails; the error is logged rather than raised.
    """
    try:
        pdf_reader = PyPDF2.PdfReader(file)
        # ``or ""`` keeps a page that yields no text (empty/None) from
        # aborting the whole document; ``join`` avoids rebuilding the
        # accumulated string once per page.
        return "".join(
            (page.extract_text() or "") + "\n" for page in pdf_reader.pages
        )
    except Exception as e:
        logging.error(f"Error processing PDF: {str(e)}")
        return ""

def split_into_chunks(text, chunk_size=1500, overlap=100):
    """Split ``text`` into overlapping chunks of whitespace-separated words.

    Consecutive chunks start ``chunk_size - overlap`` words apart, so each
    chunk repeats the last ``overlap`` words of the previous one.

    Args:
        text: The text to split. Words are obtained with ``str.split()``,
            so the original whitespace is not preserved.
        chunk_size: Maximum number of words per chunk. Must be positive.
        overlap: Number of words shared between neighbouring chunks.
            Must satisfy ``0 <= overlap < chunk_size``.

    Returns:
        A list of chunks, each a single-space-joined string. Empty or
        whitespace-only text yields ``[]``.

    Raises:
        ValueError: If ``chunk_size`` or ``overlap`` is out of range.
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be a positive integer")
    if not 0 <= overlap < chunk_size:
        # overlap >= chunk_size would make the stride zero or negative.
        raise ValueError("overlap must satisfy 0 <= overlap < chunk_size")

    words = text.split()
    stride = chunk_size - overlap
    chunks = []
    for start in range(0, len(words), stride):
        end = start + chunk_size
        chunks.append(' '.join(words[start:end]))
        # Once a chunk reaches the last word, any further chunk would only
        # repeat the tail of this one, so stop here.
        if end >= len(words):
            break
    return chunks

def get_or_create_chunks(file_path):
    """Return the text chunks for a PDF, using an on-disk cache.

    The cache key is the MD5 digest of the file's bytes, so the same
    document is only parsed once regardless of its path. MD5 is used here
    purely as a content fingerprint, not for security.

    Args:
        file_path: Path to the PDF, or ``None``.

    Returns:
        The list of chunks produced by ``split_into_chunks``. Returns ``[]``
        when ``file_path`` is ``None``, when no text could be extracted, or
        when any error occurs (the error is logged, not raised).
    """
    if file_path is None:
        return []

    try:
        with open(file_path, 'rb') as file:
            file_hash = hashlib.md5(file.read()).hexdigest()
            cache_file = f"cache/{file_hash}_chunks.pkl"

            if os.path.exists(cache_file):
                try:
                    # SECURITY NOTE: pickle.load can execute arbitrary code.
                    # This is only acceptable while the cache directory is
                    # written exclusively by this application.
                    with open(cache_file, 'rb') as f:
                        return pickle.load(f)
                except Exception as e:
                    # Unpickling a damaged file can raise many exception
                    # types; treat all of them as a cache miss and rebuild.
                    logging.warning(f"Ignoring unreadable cache file {cache_file}: {str(e)}")

            # Reuse the already-open handle instead of opening the file twice.
            file.seek(0)
            text = process_pdf(file)

        chunks = split_into_chunks(text)

        # Do not cache an empty result: it usually means extraction failed,
        # and caching it would make that failure permanent for this file.
        if chunks:
            os.makedirs('cache', exist_ok=True)
            with open(cache_file, 'wb') as f:
                pickle.dump(chunks, f)

        return chunks
    except Exception as e:
        logging.error(f"Error in get_or_create_chunks: {str(e)}")
        return []

def find_most_relevant_chunks(query, chunks, top_k=2):
    """Return the chunks most similar to ``query`` by TF-IDF cosine similarity.

    Args:
        query: The user's question.
        chunks: Sequence of text chunks to rank.
        top_k: Maximum number of chunks to return.

    Returns:
        Up to ``top_k`` chunks, most similar first. Returns ``[]`` when
        there are no chunks, when ``top_k`` is not positive, or when
        vectorization fails (the error is logged, not raised).
    """
    # Handle the trivial cases up front. In particular, a slice of
    # ``[-0:]`` selects *every* element, so top_k == 0 would otherwise
    # return all chunks instead of none.
    if not chunks or top_k <= 0:
        return []

    try:
        chunks = list(chunks)
        # The query is included when fitting so its terms are part of the
        # vocabulary used to score the chunks.
        vectorizer = TfidfVectorizer().fit(chunks + [query])
        chunk_vectors = vectorizer.transform(chunks)
        query_vector = vectorizer.transform([query])
        similarities = cosine_similarity(query_vector, chunk_vectors)[0]
        # argsort is ascending: take the last top_k indices and reverse
        # them so the best match comes first.
        top_indices = similarities.argsort()[-top_k:][::-1]
        return [chunks[i] for i in top_indices]
    except Exception as e:
        logging.error(f"Error in find_most_relevant_chunks: {str(e)}")
        return []

def chat_with_pdf(pdf_file, model, prompt, history):
    """Answer ``prompt`` about the uploaded PDF and record the exchange.

    The chunks of the PDF most relevant to the prompt are sent as context
    to the selected Groq model, and the (prompt, answer) pair is appended
    to the chat history.

    Args:
        pdf_file: The uploaded file, either an object exposing the path as
            ``.name`` or a plain path string, or ``None`` if nothing has
            been uploaded.
        model: Model identifier passed to the Groq chat-completions API.
        prompt: The user's question.
        history: List of ``(user_message, bot_message)`` tuples, or ``None``.

    Returns:
        A 3-tuple ``(history, history, "")``: the updated history twice
        (the click handler lists the chatbot as an output twice) and an
        empty string that clears the question box. Errors are logged and
        reported to the user as a chat message rather than raised.
    """
    # Tolerate a missing history so the appends below cannot fail.
    if history is None:
        history = []

    if pdf_file is None:
        # Report the problem as a chat entry. Returning a bare string for
        # the chatbot output would not be a valid history, and the second
        # output would overwrite it so the user would never see it.
        history.append((prompt, "Please upload a PDF file first."))
        return history, history, ""

    try:
        # Accept both an upload object (path in ``.name``) and a path string.
        pdf_path = getattr(pdf_file, "name", pdf_file)
        chunks = get_or_create_chunks(pdf_path)
        relevant_chunks = find_most_relevant_chunks(prompt, chunks)
        context = "\n\n".join(relevant_chunks)
        enhanced_prompt = f"Context: {context}\n\nQuestion: {prompt}\n\nAnswer:"

        chat_completion = client.chat.completions.create(
            messages=[
                {"role": "system", "content": "You are a helpful assistant for answering questions about the given PDF content."},
                {"role": "user", "content": enhanced_prompt}
            ],
            model=model,
            max_tokens=2048,
            temperature=0.7
        )
        response = chat_completion.choices[0].message.content
        history.append((prompt, response))
        return history, history, ""
    except Exception as e:
        # Keep the UI responsive: log the details, show a generic message.
        logging.error(f"Error in chat_with_pdf: {str(e)}")
        error_message = "An error occurred while processing your request. Please try again."
        history.append((prompt, error_message))
        return history, history, ""

def clear_history():
    """Produce the reset values for the clear button's three outputs.

    Returns:
        A 3-tuple ``(None, [], "")``, in the order the values are returned
        to the caller.
    """
    cleared_file = None
    cleared_chat = []
    cleared_text = ""
    return cleared_file, cleared_chat, cleared_text

# Build the UI. Components are created in the order they should appear,
# and the names bound here are used by the event wiring further down.
with gr.Blocks() as demo:
    gr.Markdown("# ChatWithPDF")

    with gr.Row():
        # Left column (scale=1): PDF upload, model choice, and the reset button.
        with gr.Column(scale=1):
            pdf_file = gr.File(label="Upload PDF", file_types=[".pdf"])
            model = gr.Dropdown(choices=MODELS, label="Select Model", value=MODELS[0])
            clear = gr.Button("Clear Chat History")

        # Right column (scale=2): the conversation and the question input.
        with gr.Column(scale=2):
            chatbot = gr.Chatbot()
            msg = gr.Textbox(label="Ask a question about your PDF")
            submit = gr.Button("Submit")

    # chat_with_pdf returns three values, so three outputs are listed;
    # chatbot appears twice and msg receives the value that clears the box.
    # NOTE(review): the duplicated chatbot output looks redundant — confirm
    # before changing, since the handler's return arity must change with it.
    submit.click(chat_with_pdf,
                 inputs=[pdf_file, model, msg, chatbot],
                 outputs=[chatbot, chatbot, msg])
    # clear_history takes no inputs; its three return values reset the
    # file upload, the chat, and the question box respectively.
    clear.click(clear_history, outputs=[pdf_file, chatbot, msg])

if __name__ == "__main__":
    # Gather the launch settings in one place so they are easy to review.
    launch_options = {
        "server_name": "0.0.0.0",  # listen on all network interfaces
        "server_port": 7860,
        "share": True,
        # Private Gradio option; per the original author's note, this
        # disables the default Gradio frontend.
        # NOTE(review): confirm against the installed Gradio version.
        "_frontend": False,
    }
    demo.launch(**launch_options)