batlahiya commited on
Commit
d2407ee
·
verified ·
1 Parent(s): 1f18304

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -30
app.py CHANGED
@@ -4,7 +4,7 @@ import os
4
  import re
5
  from pathlib import Path
6
  from unidecode import unidecode
7
- from tqdm import tqdm
8
  from langchain_community.document_loaders import PyPDFLoader
9
  from langchain.text_splitter import RecursiveCharacterTextSplitter
10
  from langchain_community.vectorstores import Chroma
@@ -22,7 +22,7 @@ import threading
22
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
23
 
24
  # Predefined values
25
- predefined_pdf = "t6.pdf"
26
  predefined_llm = "meta-llama/Llama-2-7b-hf" # Use a smaller model for faster responses
27
 
28
  def load_doc(list_file_path, chunk_size, chunk_overlap):
@@ -101,7 +101,6 @@ def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db):
101
  return_messages=True
102
  )
103
  retriever = vector_db.as_retriever()
104
-
105
  print("Defining retrieval chain...")
106
  qa_chain = ConversationalRetrievalChain.from_llm(
107
  llm,
@@ -121,9 +120,7 @@ def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db):
121
 
122
  # Define the conversation function with callback (non-blocking)
123
  @spaces.GPU()
124
-
125
-
126
- def conversation(message):
127
  global qa_chain # Assuming qa_chain is a global variable
128
 
129
  # Model definition (ensure it's accessible within the function)
@@ -133,29 +130,37 @@ def conversation(message):
133
 
134
  max_new_tokens = 64 # Define max_new_tokens here
135
 
136
- def generate_chunks(message, max_new_tokens):
137
- max_chunk_length = 512 # Adjust this value based on your model and memory constraints
138
-
139
- # Split the message into chunks
140
- chunks = [message[i:i+max_chunk_length] for i in range(0, len(message), max_chunk_length)]
141
-
142
- outputs = []
143
- for chunk in chunks:
144
- input_ids = tokenizer(chunk, return_tensors="pt")["input_ids"]
145
- generated_chunk = model.generate(input_ids=input_ids, max_new_tokens=max_new_tokens) # ... other generation arguments
146
- outputs.append(generated_chunk[0]['generated_text']) # Assuming generated text is in the first element
147
-
148
- return "".join(outputs)
149
-
150
- # Generate response with progress bar
151
- with tqdm(total=len(message) // max_chunk_length + 1) as pbar:
152
- generated_response = generate_chunks(message, max_new_tokens)
153
- pbar.update()
154
-
155
- if generated_response:
156
- yield generated_response
157
- else:
158
- yield "No response generated." # Provide a fallback message
 
 
 
 
 
 
 
 
159
 
160
  # Launch the Gradio interface with share option
161
  interface = gr.Interface(
@@ -164,5 +169,7 @@ interface = gr.Interface(
164
  outputs="text", # Text output for streaming
165
  title="Conversational AI with Retrieval",
166
  description="Ask me anything about the uploaded PDF document!",
 
167
  )
168
- interface.launch(share=True)
 
 
4
  import re
5
  from pathlib import Path
6
  from unidecode import unidecode
7
+
8
  from langchain_community.document_loaders import PyPDFLoader
9
  from langchain.text_splitter import RecursiveCharacterTextSplitter
10
  from langchain_community.vectorstores import Chroma
 
22
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
23
 
24
  # Predefined values
25
+ predefined_pdf = "t6.pdf" # Replace with your PDF filepath
26
  predefined_llm = "meta-llama/Llama-2-7b-hf" # Use a smaller model for faster responses
27
 
28
  def load_doc(list_file_path, chunk_size, chunk_overlap):
 
101
  return_messages=True
102
  )
103
  retriever = vector_db.as_retriever()
 
104
  print("Defining retrieval chain...")
105
  qa_chain = ConversationalRetrievalChain.from_llm(
106
  llm,
 
120
 
121
  # Define the conversation function with callback (non-blocking)
122
  @spaces.GPU()
123
+ def conversation(message, max_chunk_length=512): # Define max_chunk_length as an argument
 
 
124
  global qa_chain # Assuming qa_chain is a global variable
125
 
126
  # Model definition (ensure it's accessible within the function)
 
130
 
131
  max_new_tokens = 64 # Define max_new_tokens here
132
 
133
+ def generate_chunks(message, max_new_tokens, callback):
134
+ # ... rest of the generate_chunks function ... (unchanged)
135
+
136
+ def handle_response(response):
137
+ if response:
138
+ yield response
139
+ else:
140
+ yield "No response generated." # Provide a fallback message
141
+
142
+ # Start generation in a thread with callback for response
143
+ thread = threading.Thread(target=generate_chunks, args=(message, max_new_tokens, handle_response))
144
+ thread.start()
145
+
146
+ # Yield a placeholder message initially
147
+ yield "Generating response..."
148
+
149
+ # Load or create the document database (adjust as needed)
150
+ pdf_filepath = predefined_pdf
151
+ collection_name = create_collection_name(pdf_filepath)
152
+ if os.path.exists(collection_name):
153
+ vector_db = load_db()
154
+ vector_db.connect(collection_name)
155
+ print("Loaded document database from:", collection_name)
156
+ else:
157
+ print("Creating document database...")
158
+ doc_splits = load_doc([pdf_filepath], chunk_size=4096, chunk_overlap=512)
159
+ vector_db = create_db(doc_splits, collection_name)
160
+ print("Document database created:", collection_name)
161
+
162
+ # Initialize the LLM conversation chain
163
+ qa_chain = initialize_llmchain(predefined_llm, temperature=0.7, max_tokens=64, top_k=50, vector_db=vector_db) # Adjust parameters as needed
164
 
165
  # Launch the Gradio interface with share option
166
  interface = gr.Interface(
 
169
  outputs="text", # Text output for streaming
170
  title="Conversational AI with Retrieval",
171
  description="Ask me anything about the uploaded PDF document!",
172
+ arguments=[("max_chunk_length", int)], # Pass max_chunk_length as an argument
173
  )
174
+ interface.launch(share=True) # Set share=True to create a public link
175
+