batlahiya committed on
Commit d6e297d · verified · 1 Parent(s): 96005d2

Update app.py

Files changed (1): app.py (+42, −72)
app.py CHANGED
@@ -12,7 +12,7 @@ from langchain.chains import ConversationalRetrievalChain
  from langchain.memory import ConversationBufferMemory

  from langchain_huggingface import HuggingFaceEmbeddings, HuggingFacePipeline
- from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, TextIteratorStreamer
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
  import chromadb
  import torch
  from concurrent.futures import ThreadPoolExecutor
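Side note on the dropped import: TextIteratorStreamer actually ships with transformers, so the streaming path removed further down did not need an external implementation. A minimal sketch of how it is normally used, with gpt2 standing in as a placeholder checkpoint:

from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tokenizer = AutoTokenizer.from_pretrained("gpt2")   # placeholder model, for illustration only
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("What does this PDF cover?", return_tensors="pt")
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

# generate() blocks, so it runs on a worker thread while the caller consumes the stream.
worker = Thread(target=model.generate, kwargs=dict(**inputs, streamer=streamer, max_new_tokens=64))
worker.start()
for new_text in streamer:
    print(new_text, end="", flush=True)             # or yield it from a Gradio generator function
worker.join()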
@@ -86,7 +86,7 @@ def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db):
  model=model,
  tokenizer=tokenizer,
  device_map='auto',
- max_new_tokens=max_tokens,
+ max_new_tokens=max_tokens, # Define max_tokens here
  do_sample=True,
  top_k=top_k,
  num_return_sequences=1,
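For context, these keyword arguments belong to the transformers text-generation pipeline that initialize_llmchain wraps for LangChain. A rough sketch of that wiring, with a small placeholder checkpoint and assumed values for parameters not shown in the hunk:

from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from langchain_huggingface import HuggingFacePipeline

tokenizer = AutoTokenizer.from_pretrained("gpt2")   # placeholder; app.py loads its own model
model = AutoModelForCausalLM.from_pretrained("gpt2")

text_gen = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=64,        # the max_tokens argument of initialize_llmchain
    do_sample=True,
    top_k=3,
    num_return_sequences=1,
)
llm = HuggingFacePipeline(pipeline=text_gen)  # the object the retrieval chain then uses as its LLM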
@@ -119,73 +119,43 @@ def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db):
  print("Initialization complete!")
  return qa_chain

- # TextIteratorStreamer class (likely from another library, not provided)
- # This class is probably responsible for handling chunked text processing
- # for the LLM generation. You'll need to implement this class or use an
- # alternative approach for streaming text generation.
-
- # Load the PDF document and create the vector database (replace with your logic)
- pdf_filepath = predefined_pdf
- doc_splits = load_doc([pdf_filepath], chunk_size=2048, chunk_overlap=512)
- collection_name = create_collection_name(pdf_filepath)
- vector_db = create_db(doc_splits, collection_name)
-
- # Initialize the LLM chain with threading
- qa_chain = initialize_llmchain(predefined_llm, temperature=0.7, max_tokens=64, top_k=1, vector_db=vector_db)
-
- # Check if qa_chain is properly initialized
- if qa_chain is None:
-     print("Failed to initialize the QA chain. Please check the CUDA availability and model paths.")
- else:
-     # Define the conversation function with streaming
-     @spaces.GPU()
-     def conversation(message):
-         global qa_chain  # Assuming qa_chain is a global variable
-
-         tokenizer = AutoTokenizer.from_pretrained(predefined_llm)  # Initialize tokenizer here
-
-         outputs = []
-         generated_response = None  # Initialize a variable to hold the final response
-
-         def generate_chunks():
-             input_ids = tokenizer(message, return_tensors="pt")["input_ids"]
-             streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
-             generate_kwargs = dict(
-                 {"input_ids": input_ids},
-                 streamer=streamer,
-                 max_new_tokens=max_new_tokens,
-                 do_sample=True,
-                 top_p=top_p,
-                 top_k=top_k,
-                 temperature=temperature,
-                 num_beams=1,
-                 repetition_penalty=repetition_penalty,
-             )
-             t = threading.Thread(target=model.generate, kwargs=generate_kwargs)
-             t.start()
-
-             for text in streamer:
-                 outputs.append(text)
-
-             # Wait for the thread to finish and capture the generated text
-             t.join()
-             generated_response = "".join(outputs)
-
-         thread = threading.Thread(target=generate_chunks)
-         thread.start()
-
-         # If the generated response is available, yield it. Otherwise, yield the placeholder.
-         if generated_response:
-             yield generated_response
-         else:
-             yield "Generating response..."
-
-     # Launch the Gradio interface with share option
-     interface = gr.Interface(
-         fn=conversation,
-         inputs="textbox",  # Use a single input textbox
-         outputs="text",  # Text output for streaming
-         title="Conversational AI with Retrieval",
-         description="Ask me anything about the uploaded PDF document!",
-     )
-     interface.launch(share=True)  # Set share=True to create a public link
+ # Define the conversation function with streaming (modified approach)
+ @spaces.GPU()
+ def conversation(message):
+     global qa_chain  # Assuming qa_chain is a global variable
+
+     tokenizer = AutoTokenizer.from_pretrained(predefined_llm)  # Initialize tokenizer here
+     max_new_tokens = 64  # Define max_new_tokens here (or pass it as an argument)
+
+     outputs = []
+     generated_response = None
+
+     def generate_chunks(message, max_new_tokens):
+         max_chunk_length = 512  # Adjust this value based on your model and memory constraints
+
+         # Split the message into chunks
+         chunks = [message[i:i+max_chunk_length] for i in range(0, len(message), max_chunk_length)]
+
+         for chunk in chunks:
+             input_ids = tokenizer(chunk, return_tensors="pt")["input_ids"]
+             generated_chunk = model.generate(input_ids=input_ids, max_new_tokens=max_new_tokens, ...)  # ... other generation arguments
+             outputs.append(generated_chunk[0]['generated_text'])  # Assuming generated text is in the first element
+
+     thread = threading.Thread(target=generate_chunks, args=(message, max_new_tokens))
+     thread.start()
+
+     # If the generated response is available, yield it. Otherwise, yield the placeholder.
+     if generated_response:
+         yield generated_response
+     else:
+         yield "Generating response..."
+
+ # Launch the Gradio interface with share option
+ interface = gr.Interface(
+     fn=conversation,
+     inputs="textbox",  # Use a single input textbox
+     outputs="text",  # Text output for streaming
+     title="Conversational AI with Retrieval",
+     description="Ask me anything about the uploaded PDF document!",
+ )
+ interface.launch(share=True)  # Set share=True to create a public link
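Two caveats about the replacement conversation() in the hunk above: model.generate returns token IDs (the ['generated_text'] key only exists on pipeline() outputs), and generated_response is never assigned before the single yield, so the interface would only ever show the placeholder text. A sketch of what was presumably intended, assuming the same module-level model and tokenizer the committed code relies on:

@spaces.GPU()
def conversation(message):
    # Assumes `model` and `tokenizer` exist at module level, as in the committed code.
    max_new_tokens = 64
    max_chunk_length = 512
    chunks = [message[i:i + max_chunk_length] for i in range(0, len(message), max_chunk_length)]

    yield "Generating response..."                        # immediate placeholder for the UI
    outputs = []
    for chunk in chunks:
        input_ids = tokenizer(chunk, return_tensors="pt")["input_ids"].to(model.device)
        output_ids = model.generate(input_ids=input_ids, max_new_tokens=max_new_tokens, do_sample=True)
        new_tokens = output_ids[0][input_ids.shape[-1]:]  # drop the echoed prompt tokens
        outputs.append(tokenizer.decode(new_tokens, skip_special_tokens=True))
        yield "".join(outputs)                            # stream whatever has been generated so far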
 
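One more observation: neither the old nor the new conversation() ever calls qa_chain, so the retrieval chain built by initialize_llmchain goes unused. If answers are meant to come from the PDF, invoking a ConversationalRetrievalChain (with the ConversationBufferMemory imported at the top tracking history) looks roughly like this:

def answer_from_pdf(message):
    # qa_chain: the ConversationalRetrievalChain returned by initialize_llmchain().
    # With ConversationBufferMemory attached, the chain maintains chat history itself.
    result = qa_chain({"question": message})
    return result["answer"]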