batlahiya committed
Commit 232e6a8 (verified)
1 Parent(s): dc4c639

Update app.py

Files changed (1)
  1. app.py +36 -16
app.py CHANGED
@@ -15,15 +15,13 @@ from langchain_huggingface import HuggingFaceEmbeddings, HuggingFacePipeline
 from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
 import chromadb
 import torch
-from concurrent.futures import ThreadPoolExecutor
-import threading
 
 # Environment configuration
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
-# Predefined values
-predefined_pdf = "t6.pdf" # Replace with your PDF filepath
-predefined_llm = "meta-llama/Llama-2-7b-hf" # Use a smaller model for faster responses
+# Predefined values (replace with your PDF path and desired LLM)
+pdf_filepath = "your_pdf.pdf"
+llm_model = "meta-llama/Llama-2-7b-hf" # Use a smaller model for faster responses
 
 
 def load_doc(list_file_path, chunk_size, chunk_overlap):
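Note that the predefined model, meta-llama/Llama-2-7b-hf, is a gated checkpoint on the Hub, so loading it normally requires authentication. A minimal sketch of what that loading might look like; the HF_TOKEN variable name and the from_pretrained call here are assumptions, not part of this commit:

# Sketch only: assumes llm_model = "meta-llama/Llama-2-7b-hf" as defined above
# and that an access token is exposed via a (hypothetical) HF_TOKEN env var.
import os
from transformers import AutoTokenizer, AutoModelForCausalLM

hf_token = os.environ.get("HF_TOKEN")
tokenizer = AutoTokenizer.from_pretrained(llm_model, token=hf_token)
model = AutoModelForCausalLM.from_pretrained(
    llm_model,
    token=hf_token,
    device_map="auto",   # requires the accelerate package
    torch_dtype="auto",
)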
@@ -93,7 +91,7 @@ def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db):
         model=model,
         tokenizer=tokenizer,
         device_map='auto',
-        max_new_tokens=max_tokens, # Define max_tokens here
+        max_new_tokens=max_tokens,
         do_sample=True,
         top_k=top_k,
         num_return_sequences=1,
@@ -109,9 +107,14 @@ def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db):
     )
 
     retriever = vector_db.as_retriever()
-
     print("Defining retrieval chain...")
-    qa_chain = ConversationalRetrievalChain.from_llm(llm.encode("What is the weather like today?"), memory=memory, retriever=retriever) # Initial prompt to prime the memory
+    qa_chain = ConversationalRetrievalChain.from_llm(
+        llm=llm,
+        memory=memory,
+        retriever=retriever
+    )
+
+    llm(llm.encode("What is the weather like today?"), memory=memory, retriever=retriever) # Initial prompt to prime the memory
     return qa_chain
 
 # Load the model in a separate thread
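Passing llm, memory, and retriever as keyword arguments fixes the from_llm() call, but the added "priming" line still invokes the LangChain HuggingFacePipeline with encoded tensors plus memory/retriever keywords it does not take; a pipeline-backed LLM is called with a plain prompt string, and conversational state lives in the chain. A hedged sketch of priming through the chain instead, where llm and retriever are assumed to be the objects built earlier in initialize_llmchain:

# Sketch, not part of the commit: build the chain, then route the priming
# question through it so the ConversationBufferMemory records the exchange.
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory

memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
qa_chain = ConversationalRetrievalChain.from_llm(llm=llm, retriever=retriever, memory=memory)

result = qa_chain({"question": "What is the weather like today?"})  # primes the memory
print(result["answer"])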
@@ -132,7 +135,7 @@ def conversation(message, max_chunk_length=512):
 
     # Model loading handled by `initialize_llmchain` (called once)
 
-    max_new_tokens = 64 # Define max_tokens here (moved from `conversation`)
+    max_new_tokens = 64 # Define max_tokens here
 
     def generate_chunks(message, max_new_tokens):
         """
@@ -140,10 +143,27 @@ def conversation(message, max_chunk_length=512):
         each chunk using the LLM model, and returns the generated response.
         """
 
-        # Adjust max_chunk_length based on your model and memory constraints
-        # ... rest of the generate_chunks function ... (unchanged)
+        responses = [] # List to store individual chunk responses
+
+        # Loop through the message in chunks
+        for i in range(0, len(message), max_new_tokens):
+            chunk = message[i:i+max_new_tokens] # Extract the current chunk
 
-        return response # Return the generated response
+            # Encode the chunk for the LLM model
+            encoded_chunk = tokenizer.encode(chunk, return_tensors="pt")
+
+            try:
+                # Generate response using the LLM model
+                response = llm(encoded_chunk)[0]["generated_text"]
+                responses.append(response) # Add response to the list
+            except Exception as e:
+                print(f"Error generating response for chunk: {chunk}")
+                # Handle error (e.g., return a fallback message)
+
+        # Combine individual responses into a final response
+        final_response = " ".join(responses)
+
+        return final_response
 
     def handle_response(response):
         if response:
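The new generate_chunks body slices the message by characters while using the token budget (max_new_tokens) as the stride, and it re-encodes each chunk before calling the LLM even though a text-generation pipeline accepts raw strings. For reference, a simpler variant looks roughly like the sketch below; generator stands for the transformers pipeline created in initialize_llmchain, and max_chars is a hypothetical name for a character-based chunk size kept separate from the token budget:

# Sketch under stated assumptions; not the code in this commit.
def generate_chunks_sketch(generator, message, max_chars=512, max_new_tokens=64):
    responses = []
    for i in range(0, len(message), max_chars):
        chunk = message[i:i + max_chars]  # character-based slice
        out = generator(chunk, max_new_tokens=max_new_tokens, do_sample=True)
        responses.append(out[0]["generated_text"])  # pipeline returns a list of dicts
    return " ".join(responses)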
@@ -170,8 +190,9 @@
     # Yield the final response (if any)
     # yield response # Removed as generation happens within the thread
 
-# Load or create the document database (adjust as needed)
-pdf_filepath = predefined_pdf
+
+# Load or create the document database
+pdf_filepath = pdf_filepath
 collection_name = create_collection_name(pdf_filepath)
 if os.path.exists(collection_name):
     vector_db = load_db()
@@ -184,10 +205,9 @@ else:
     print("Document database created:", collection_name)
 
 # Initialize the LLM conversation chain (model loaded in separate thread)
-qa_chain = initialize_llmchain(predefined_llm, temperature=0.7, max_tokens=64, top_k=50, vector_db=vector_db)
+qa_chain = initialize_llmchain(llm_model, temperature=0.7, max_tokens=max_new_tokens, top_k=50, vector_db=vector_db)
 
 # Launch the Gradio interface with share option
-# (Consider removing 'arguments' if there's a version incompatibility with gradio)
 interface = gr.Interface(
     fn=conversation,
     inputs="textbox", # Use a single input textbox
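The gr.Interface call is truncated by the hunk's context window; a minimal text-in/text-out wiring around conversation() would look roughly like this, where the output component and the share flag are assumptions not shown in the diff:

# Sketch only: output component and launch options are guesses.
interface = gr.Interface(
    fn=conversation,
    inputs="textbox",   # single input textbox
    outputs="textbox",  # generated answer (assumed)
)
interface.launch(share=True)  # share flag assumed from the comment above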
 