Update app.py
app.py CHANGED

@@ -4,7 +4,7 @@ import os
 import re
 from pathlib import Path
 from unidecode import unidecode
-
+
 from langchain_community.document_loaders import PyPDFLoader
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain_community.vectorstores import Chroma

@@ -22,7 +22,7 @@ import threading
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
 # Predefined values
-predefined_pdf = "t6.pdf"
+predefined_pdf = "t6.pdf"  # Replace with your PDF filepath
 predefined_llm = "meta-llama/Llama-2-7b-hf"  # Use a smaller model for faster responses
 
 def load_doc(list_file_path, chunk_size, chunk_overlap):

@@ -101,7 +101,6 @@ def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db):
         return_messages=True
     )
     retriever = vector_db.as_retriever()
-
     print("Defining retrieval chain...")
     qa_chain = ConversationalRetrievalChain.from_llm(
         llm,

@@ -121,9 +120,7 @@ def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db):
 
 # Define the conversation function with callback (non-blocking)
 @spaces.GPU()
-
-
-def conversation(message):
+def conversation(message, max_chunk_length=512):  # Define max_chunk_length as an argument
     global qa_chain  # Assuming qa_chain is a global variable
 
     # Model definition (ensure it's accessible within the function)

@@ -133,29 +130,37 @@ def conversation(message):
 
     max_new_tokens = 64  # Define max_new_tokens here
 
-    def generate_chunks(message, max_new_tokens):
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    def generate_chunks(message, max_new_tokens, callback):
+        # ... rest of the generate_chunks function ... (unchanged)
+
+    def handle_response(response):
+        if response:
+            yield response
+        else:
+            yield "No response generated."  # Provide a fallback message
+
+    # Start generation in a thread with callback for response
+    thread = threading.Thread(target=generate_chunks, args=(message, max_new_tokens, handle_response))
+    thread.start()
+
+    # Yield a placeholder message initially
+    yield "Generating response..."
+
+# Load or create the document database (adjust as needed)
+pdf_filepath = predefined_pdf
+collection_name = create_collection_name(pdf_filepath)
+if os.path.exists(collection_name):
+    vector_db = load_db()
+    vector_db.connect(collection_name)
+    print("Loaded document database from:", collection_name)
+else:
+    print("Creating document database...")
+    doc_splits = load_doc([pdf_filepath], chunk_size=4096, chunk_overlap=512)
+    vector_db = create_db(doc_splits, collection_name)
+    print("Document database created:", collection_name)
+
+# Initialize the LLM conversation chain
+qa_chain = initialize_llmchain(predefined_llm, temperature=0.7, max_tokens=64, top_k=50, vector_db=vector_db)  # Adjust parameters as needed
 
 # Launch the Gradio interface with share option
 interface = gr.Interface(

@@ -164,5 +169,7 @@ interface = gr.Interface(
     outputs="text",  # Text output for streaming
     title="Conversational AI with Retrieval",
     description="Ask me anything about the uploaded PDF document!",
+    arguments=[("max_chunk_length", int)],  # Pass max_chunk_length as an argument
 )
-interface.launch(share=True)
+interface.launch(share=True)  # Set share=True to create a public link
+
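The `conversation` function in this change starts `generate_chunks` on a background thread and immediately yields a placeholder, with `handle_response` intended to receive the generated text. If the goal is to stream partial output to the Gradio textbox, a common pattern is to have the worker push chunks onto a `queue.Queue` and drain it from the generator. The sketch below is a minimal illustration of that pattern, not the app's actual implementation; it assumes `generate_chunks` calls a per-chunk callback, and names such as `token_queue`, `on_chunk`, and `worker` are invented for the example.

```python
import queue
import threading

def conversation(message, max_chunk_length=512):
    """Stream a response by draining chunks produced on a worker thread."""
    max_new_tokens = 64
    token_queue = queue.Queue()  # worker thread pushes text chunks here

    def on_chunk(chunk):
        # Per-chunk callback assumed to be invoked by generate_chunks
        token_queue.put(chunk)

    def worker():
        try:
            generate_chunks(message, max_new_tokens, on_chunk)
        finally:
            token_queue.put(None)  # sentinel: generation is finished

    threading.Thread(target=worker, daemon=True).start()

    partial = ""
    while True:
        chunk = token_queue.get()  # blocks until the worker produces something
        if chunk is None:
            break
        partial += chunk
        yield partial  # Gradio re-renders the output with each yield

    if not partial:
        yield "No response generated."  # fallback when nothing was produced
```

With this shape, the interim "Generating response..." placeholder becomes unnecessary, since the first real chunk is yielded to the interface as soon as the worker produces it.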
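Gradio's `gr.Interface` passes one value per input component to the wrapped function, so if `max_chunk_length` should be adjustable from the UI, it would normally be declared as a second input component. A minimal sketch under that assumption; the `Textbox`/`Number` components and their labels are illustrative choices, not taken from this app:

```python
import gradio as gr

interface = gr.Interface(
    fn=conversation,  # generator function; Gradio streams each yielded value
    inputs=[
        gr.Textbox(label="Message"),                                  # -> message
        gr.Number(value=512, precision=0, label="Max chunk length"),  # -> max_chunk_length
    ],
    outputs="text",
    title="Conversational AI with Retrieval",
    description="Ask me anything about the uploaded PDF document!",
)

interface.launch(share=True)  # share=True creates a temporary public link
```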