batlahiya committed (verified)
Commit 7b0eff8 · 1 Parent(s): b5386e9

Update app.py

Files changed (1): app.py (+127 -114)
app.py CHANGED
@@ -15,8 +15,9 @@ from langchain_huggingface import HuggingFaceEmbeddings, HuggingFacePipeline
  from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
  import chromadb
  import torch
- from concurrent.futures import ThreadPoolExecutor
- import threading

  # Environment configuration
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
@@ -25,145 +26,157 @@ os.environ["TOKENIZERS_PARALLELISM"] = "false"
  predefined_pdf = "t6.pdf"  # Replace with your PDF filepath
  predefined_llm = "meta-llama/Llama-2-7b-hf"  # Use a smaller model for faster responses

  def load_doc(list_file_path, chunk_size, chunk_overlap):
-     loaders = [PyPDFLoader(x) for x in list_file_path]
-     pages = []
-     for loader in loaders:
-         pages.extend(loader.load())
-     text_splitter = RecursiveCharacterTextSplitter(
-         chunk_size=chunk_size,
-         chunk_overlap=chunk_overlap)
-     doc_splits = text_splitter.split_documents(pages)
-     return doc_splits

  def create_db(splits, collection_name):
-     embedding = HuggingFaceEmbeddings()
-     new_client = chromadb.EphemeralClient()
-     vectordb = Chroma.from_documents(
-         documents=splits,
-         embedding=embedding,
-         client=new_client,
-         collection_name=collection_name,
-     )
-     return vectordb

  def load_db():
-     embedding = HuggingFaceEmbeddings()
-     vectordb = Chroma(
-         embedding_function=embedding)
-     return vectordb

  def create_collection_name(filepath):
-     collection_name = Path(filepath).stem
-     collection_name = collection_name.replace(" ", "-")
-     collection_name = unidecode(collection_name)
-     collection_name = re.sub('[^A-Za-z0-9]+', '-', collection_name)
-     collection_name = collection_name[:50]
-     if len(collection_name) < 3:
-         collection_name = collection_name + 'xyz'
-     if not collection_name[0].isalnum():
-         collection_name = 'A' + collection_name[1:]
-     if not collection_name[-1].isalnum():
-         collection_name = collection_name[:-1] + 'Z'
-     print('Filepath: ', filepath)
-     print('Collection name: ', collection_name)
-     return collection_name
-
- # Improved `initialize_llmchain` function:
- def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db):
-     if not torch.cuda.is_available():
-         print("CUDA is not available. This demo does not work on CPU.")
-         return None
-
-     def init_llm():
-         print("Initializing HF model and tokenizer...")
-         model = AutoModelForCausalLM.from_pretrained(llm_model, device_map="auto", load_in_4bit=True)
-         tokenizer = AutoTokenizer.from_pretrained(llm_model)
-         tokenizer.use_default_system_prompt = True
-
-         print("Initializing HF pipeline...")
-         hf_pipeline = pipeline(
-             "text-generation",
-             model=model,
-             tokenizer=tokenizer,
-             device_map='auto',
-             max_new_tokens=max_tokens,  # Define max_tokens here
-             do_sample=True,
-             top_k=top_k,
-             num_return_sequences=1,
-             eos_token_id=tokenizer.eos_token_id
-         )
-         llm = HuggingFacePipeline(pipeline=hf_pipeline, model_kwargs={'temperature': temperature})
-
-         print("Defining buffer memory...")
-         memory = ConversationBufferMemory(
              memory_key="chat_history",
              output_key='answer',
              return_messages=True
          )
-         retriever = vector_db.as_retriever()
-
-         print("Defining retrieval chain...")
-         qa_chain = ConversationalRetrievalChain.from_llm(
-             llm,
-             memory=memory,
-             retriever=retriever,
-             chain_type="stuff",
-             return_source_documents=True,
-             verbose=False,
-         )
-         return qa_chain

-     with ThreadPoolExecutor() as executor:
-         future = executor.submit(init_llm)
-         qa_chain = future.result()
-     print("Initialization complete!")
-     return qa_chain

- # Improved `conversation` function:
  @spaces.GPU()
- def conversation(message, max_chunk_length=512):  # Define max_chunk_length as an argument
-     global qa_chain  # Assuming qa_chain is a global variable
-
-     # Model loading moved to `initialize_llmchain` to avoid duplication
-
-     max_new_tokens = 64  # Define max_tokens here (moved from `conversation`)
-
-     def generate_chunks(message, max_new_tokens, callback):
-         """
-         This function splits the message into chunks, generates a response for
-         each chunk using the LLM model, and passes each generated response to
-         the provided callback function.
-         """
-         # Adjust max_chunk_length based on your model and memory constraints
-         # ... rest of the generate_chunks function ... (unchanged)
-
-     def handle_response(response):
-         if response:
-             yield response
-         else:
-             yield "No response generated."  # Provide a fallback message
-
-     # Consider using `asyncio` for non-blocking generation instead of threads;
-     # this would potentially avoid deadlocks. Replace the threading approach
-     # with an appropriate asyncio implementation.
-
-     # Yield a placeholder message initially
-     yield "Generating response..."

  # Load or create the document database (adjust as needed)
  pdf_filepath = predefined_pdf
  collection_name = create_collection_name(pdf_filepath)
  if os.path.exists(collection_name):
-     vector_db = load_db()
-     vector_db.connect(collection_name)
-     print("Loaded document database from:", collection_name)
  else:
-     print("Creating document database...")
-     doc_splits = load_doc([pdf_filepath], chunk_size=4096, chunk_overlap=512)
-     vector_db = create_db(doc_splits, collection_name)
-     print("Document database created:", collection_name)

  # Initialize the LLM conversation chain (model loaded within `initialize_llmchain`)
  qa_chain = initialize_llmchain(predefined_llm, temperature=0.7, max_tokens=64, top_k=50, vector_db=vector_db)  # Adjust parameters as needed
@@ -177,4 +190,4 @@ interface = gr.Interface(
      title="Conversational AI with Retrieval",
      description="Ask me anything about the uploaded PDF document!",
  )
- interface.launch(share=True)  # Set share=True to create a public link
 
@@ -15,8 +15,9 @@ from langchain_huggingface import HuggingFaceEmbeddings, HuggingFacePipeline
  from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
  import chromadb
  import torch
+ import asyncio  # event-loop utilities; asyncio.create_task/await are used in conversation() below
+ import aiohttp  # required for making HTTP requests within the loop (if needed)
+

  # Environment configuration
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
@@ -25,145 +26,157 @@ os.environ["TOKENIZERS_PARALLELISM"] = "false"
  predefined_pdf = "t6.pdf"  # Replace with your PDF filepath
  predefined_llm = "meta-llama/Llama-2-7b-hf"  # Use a smaller model for faster responses

+
  def load_doc(list_file_path, chunk_size, chunk_overlap):
+     loaders = [PyPDFLoader(x) for x in list_file_path]
+     pages = []
+     for loader in loaders:
+         pages.extend(loader.load())
+     text_splitter = RecursiveCharacterTextSplitter(
+         chunk_size=chunk_size,
+         chunk_overlap=chunk_overlap
+     )
+     doc_splits = text_splitter.split_documents(pages)
+     return doc_splits
+
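As a usage note, a minimal call mirroring how load_doc is invoked at the bottom of this file:

    # split the bundled PDF into overlapping chunks (values match the call below)
    splits = load_doc(["t6.pdf"], chunk_size=4096, chunk_overlap=512)
    print(len(splits), "chunks")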
 
  def create_db(splits, collection_name):
+     embedding = HuggingFaceEmbeddings()
+     new_client = chromadb.EphemeralClient()
+     vectordb = Chroma.from_documents(
+         documents=splits,
+         embedding=embedding,
+         client=new_client,
+         collection_name=collection_name,
+     )
+     return vectordb
+
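Note that chromadb.EphemeralClient() keeps the collection in memory only, so nothing is written to disk between runs (and the os.path.exists(collection_name) check further down will never find it). A persistent variant, an assumption rather than what this commit does, would swap in a PersistentClient:

    # hypothetical persistent variant -- not part of this commit
    client = chromadb.PersistentClient(path=f"./chroma/{collection_name}")
    vectordb = Chroma.from_documents(
        documents=splits,
        embedding=HuggingFaceEmbeddings(),
        client=client,
        collection_name=collection_name,
    )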
 
  def load_db():
+     embedding = HuggingFaceEmbeddings()
+     vectordb = Chroma(
+         embedding_function=embedding
+     )
+     return vectordb
+
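load_db() builds a Chroma wrapper with no collection attached, and LangChain's Chroma exposes no connect() method, although the loading branch further down calls one. The conventional way to reopen a named collection is through the constructor, sketched here under the assumption of a persisted store:

    # sketch of the usual pattern -- constructor kwargs are standard Chroma arguments
    vectordb = Chroma(
        embedding_function=HuggingFaceEmbeddings(),
        collection_name=collection_name,
        persist_directory=f"./chroma/{collection_name}",  # hypothetical path
    )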
 
  def create_collection_name(filepath):
+     collection_name = Path(filepath).stem
+     collection_name = collection_name.replace(" ", "-")
+     collection_name = unidecode(collection_name)
+     collection_name = re.sub('[^A-Za-z0-9]+', '-', collection_name)
+     collection_name = collection_name[:50]
+     if len(collection_name) < 3:
+         collection_name = collection_name + 'xyz'
+     if not collection_name[0].isalnum():
+         collection_name = 'A' + collection_name[1:]
+     if not collection_name[-1].isalnum():
+         collection_name = collection_name[:-1] + 'Z'
+     print('Filepath: ', filepath)
+     print('Collection name: ', collection_name)
+     return collection_name
+
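Tracing the sanitization rules on a concrete filename:

    # "My Report (v2).pdf" -> stem "My Report (v2)"
    # -> "My-Report-(v2)"   (spaces to hyphens)
    # -> "My-Report-v2-"    (runs of non-alphanumerics collapsed to "-")
    # -> "My-Report-v2Z"    (trailing non-alphanumeric replaced with "Z")
    name = create_collection_name("My Report (v2).pdf")  # "My-Report-v2Z"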
 
+ def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db):
+     if not torch.cuda.is_available():
+         print("CUDA is not available. This demo may not perform well on CPU.")
+         return None
+
+     def init_llm():
+         print("Initializing HF model and tokenizer...")
+         model = AutoModelForCausalLM.from_pretrained(llm_model, device_map="auto", load_in_4bit=True)
+         tokenizer = AutoTokenizer.from_pretrained(llm_model)
+         tokenizer.use_default_system_prompt = True
+
+         print("Initializing HF pipeline...")
+         hf_pipeline = pipeline(
+             "text-generation",
+             model=model,
+             tokenizer=tokenizer,
+             device_map='auto',
+             max_new_tokens=max_tokens,  # Define max_tokens here
+             do_sample=True,
+             top_k=top_k,
+             num_return_sequences=1,
+             eos_token_id=tokenizer.eos_token_id
+         )
+         llm = HuggingFacePipeline(pipeline=hf_pipeline, model_kwargs={'temperature': temperature})
+
+         print("Defining buffer memory...")
+         memory = ConversationBufferMemory(
              memory_key="chat_history",
              output_key='answer',
              return_messages=True
          )
+
+         retriever = vector_db.as_retriever()
+
+         print("Defining retrieval chain...")
+         qa_chain = ConversationalRetrievalChain.from_llm(
+             llm,
+             memory=memory,
+             retriever=retriever,
+             chain_type="stuff",
+             return_source_documents=True,
+             verbose=False,
+         )
+         return qa_chain
+
+     return init_llm()  # init_llm is synchronous, so call it directly; asyncio.run() expects a coroutine
+
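A minimal sketch of querying the resulting chain; the "question" input key and the "answer"/"source_documents" output keys follow ConversationalRetrievalChain's defaults, with chat history supplied by the attached memory:

    # assuming qa_chain was built as at the bottom of this file
    result = qa_chain({"question": "What is this document about?"})
    print(result["answer"])
    print(len(result["source_documents"]), "source chunks retrieved")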
+ # Asynchronous conversation function
  @spaces.GPU()
+ async def conversation(message, max_chunk_length=512):  # Define max_chunk_length as an argument
+     global qa_chain  # Assuming qa_chain is a global variable
+
+     # Model loading moved to `initialize_llmchain` to avoid duplication
+
+     max_new_tokens = 64  # Define max_tokens here (moved from `conversation`)
+
+     async def generate_chunks(message, max_new_tokens):
+         """
+         Split the message into chunks, generate a response for each chunk
+         using the LLM model, and return the generated response.
+         """
+         # Adjust max_chunk_length based on your model and memory constraints
+         # ... rest of the generate_chunks function ... (unchanged)
+
+         async with aiohttp.ClientSession() as session:  # Use session for HTTP requests (if needed)
+             # ... make HTTP requests using session here ...
+             pass
+
+         return response  # `response` is assembled by the elided generation code above
+
+     async def handle_response(response):
+         if response:
+             yield response
+         else:
+             yield "No response generated."  # Provide a fallback message
+
+     # Use asyncio to run the generation task without blocking the event loop
+     try:
+         task = asyncio.create_task(generate_chunks(message, max_new_tokens))
+         response = await task  # Wait for the generation task to complete
+     except Exception as e:
+         print(f"Error during generation: {e}")
+         response = None
+
+     # Yield the generated response
+     yield response
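The body of generate_chunks is elided above ("... rest of the generate_chunks function ..."). One hypothetical shape for it, assuming the global qa_chain answers each piece (the names below are illustrative, not from the commit):

    # hypothetical sketch of the elided chunking logic -- not part of this commit
    pieces = [message[i:i + max_chunk_length]
              for i in range(0, len(message), max_chunk_length)]
    answers = []
    for piece in pieces:
        # qa_chain is synchronous; offload it so the event loop is not blocked
        result = await asyncio.to_thread(qa_chain, {"question": piece})
        answers.append(result["answer"])
    response = " ".join(answers)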
 
  # Load or create the document database (adjust as needed)
  pdf_filepath = predefined_pdf
  collection_name = create_collection_name(pdf_filepath)
  if os.path.exists(collection_name):
+     vector_db = load_db()
+     vector_db.connect(collection_name)  # NOTE: LangChain's Chroma has no connect(); see the constructor sketch after load_db()
+     print("Loaded document database from:", collection_name)
  else:
+     print("Creating document database...")
+     doc_splits = load_doc([pdf_filepath], chunk_size=4096, chunk_overlap=512)
+     vector_db = create_db(doc_splits, collection_name)
+     print("Document database created:", collection_name)

  # Initialize the LLM conversation chain (model loaded within `initialize_llmchain`)
  qa_chain = initialize_llmchain(predefined_llm, temperature=0.7, max_tokens=64, top_k=50, vector_db=vector_db)  # Adjust parameters as needed
 
@@ -177,4 +190,4 @@ interface = gr.Interface(
      title="Conversational AI with Retrieval",
      description="Ask me anything about the uploaded PDF document!",
  )
+ interface.launch(share=True)  # Set share=True to create a public link
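The gr.Interface(...) call itself sits outside the diff context; a wiring consistent with the visible kwargs might look like the following (fn/inputs/outputs are assumptions -- only title and description appear in the diff):

    # assumed wiring; Gradio streams the values yielded by the async generator
    interface = gr.Interface(
        fn=conversation,
        inputs=gr.Textbox(label="Your question"),
        outputs=gr.Textbox(label="Answer"),
        title="Conversational AI with Retrieval",
        description="Ask me anything about the uploaded PDF document!",
    )
    interface.launch(share=True)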