devve1 committed on
Commit
41791ed
1 Parent(s): 22954c8

Update app.py

Files changed (1)
  1. app.py +57 -67
app.py CHANGED
@@ -6,7 +6,6 @@ import msgpack
 import numpy as np
 import streamlit as st
 from numpy import ndarray
-from transformers import AutoModelForMaskedLM, AutoTokenizer
 from scipy.sparse import csr_matrix, save_npz, load_npz, vstack
 from qdrant_client import QdrantClient, models
 from langchain_community.llms.llamacpp import LlamaCpp
@@ -15,7 +14,9 @@ from langchain_community.document_loaders.unstructured import UnstructuredFileLoader
 from langchain_core.prompts import PromptTemplate
 from langchain.chains.summarize import load_summarize_chain
 from langchain_experimental.text_splitter import SemanticChunker
-from langchain_huggingface import HuggingFaceEmbeddings
+from fastembed.sparse.splade_pp import supported_splade_models
+from fastembed import SparseTextEmbedding, SparseEmbedding
+from fastembed_ext import FastEmbedEmbeddingsLc
 from langchain_core.documents import Document
 from huggingface_hub import hf_hub_download
 from qdrant_client.models import (
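
For orientation: the commit swaps the transformers-based SPLADE encoder and langchain_huggingface embeddings for Qdrant's fastembed (ONNX) stack; FastEmbedEmbeddingsLc is imported from a separate fastembed_ext module and is assumed to be a LangChain-compatible wrapper. A minimal sketch of the fastembed sparse API the new code relies on, using the stock v1 model since the v2 entry is only registered later in this commit:

from fastembed import SparseTextEmbedding

# 'prithivida/Splade_PP_en_v1' ships in fastembed's default registry;
# the v2 model used below needs the supported_splade_models patch.
sparse_model = SparseTextEmbedding('prithivida/Splade_PP_en_v1')

# embed() yields one SparseEmbedding per input document
embedding = next(iter(sparse_model.embed(['an action role-playing game'])))
print(embedding.indices)  # numpy array of activated vocabulary token ids
print(embedding.values)   # numpy array of the matching SPLADE weights
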
@@ -46,10 +47,10 @@ VERBOSE SUMMARY:
 """


-def make_points(chunks: list[str], dense: list[ndarray], indices, values) -> list[PointStruct]:
+def make_points(chunks: list[str], dense: list[ndarray], sparse) -> list[PointStruct]:
     points = []
-    for idx, (indice, value, chunk, dense_vector) in enumerate(zip(indices, values, chunks, dense)):
-        sparse_vector = SparseVector(indices=indice, values=value)
+    for idx, (sparse_vector, chunk, dense_vector) in enumerate(zip(sparse, chunks, dense)):
+        sparse_vector = SparseVector(indices=sparse_vector.indices.tolist(), values=sparse_vector.values.tolist())
         point = PointStruct(
             id=idx,
             vector={
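
The body of the PointStruct is unchanged and therefore hidden by the diff. For reference, a point carrying both a named dense vector and a named sparse vector looks roughly like this in qdrant_client (the 'text-dense' name matches the search code below; the payload key is an assumption):

from qdrant_client.models import PointStruct, SparseVector

point = PointStruct(
    id=0,
    vector={
        'text-dense': [0.1, 0.2, 0.3],  # dense embedding as a plain list of floats
        'text-sparse': SparseVector(indices=[12, 407], values=[0.8, 0.3]),
    },
    payload={'text': 'chunk text'},  # hypothetical payload key
)
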
@@ -63,7 +64,7 @@ def make_points(chunks: list[str], dense: list[ndarray], indices, values) -> list[PointStruct]:
         points.append(point)
     return points

-def search(client: QdrantClient, collection_name: str, dense, indices, values):
+def search(client: QdrantClient, collection_name: str, dense, sparse):
     search_results = client.search_batch(
         collection_name,
         [
@@ -78,8 +79,8 @@ def search(client: QdrantClient, collection_name: str, dense, indices, values):
             vector=NamedSparseVector(
                 name="text-sparse",
                 vector=SparseVector(
-                    indices=indices,
-                    values=values,
+                    indices=sparse[0].indices.tolist(),
+                    values=sparse[0].values.tolist(),
                 ),
             ),
             limit=10
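
Both requests run in one batch, so search_results[0] holds the dense hits and search_results[1] the sparse hits, in request order. A condensed sketch of the whole call for context (names as in the diff):

from qdrant_client import QdrantClient, models

def hybrid_search(client: QdrantClient, collection_name: str, dense_query, sparse_query):
    # one dense request and one sparse request, answered in order
    return client.search_batch(
        collection_name,
        [
            models.SearchRequest(
                vector=models.NamedVector(name='text-dense', vector=dense_query),
                limit=10,
            ),
            models.SearchRequest(
                vector=models.NamedSparseVector(
                    name='text-sparse',
                    vector=models.SparseVector(
                        indices=sparse_query[0].indices.tolist(),
                        values=sparse_query[0].values.tolist(),
                    ),
                ),
                limit=10,
            ),
        ],
    )
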
@@ -116,7 +117,7 @@ def rrf(rank_lists, alpha=60, default_rank=1000):
     return sorted_items


-def main(query: str, client: QdrantClient, collection_name: str, llm, dense_model, sparse_model, sparse_tokenizer):
+def main(query: str, client: QdrantClient, collection_name: str, llm, dense_model, sparse_model):
     # name = 'Kia_EV6'
     # filepath = os.path.join(os.getcwd(), name + '.pdf')

@@ -130,17 +131,15 @@ def main(query: str, client: QdrantClient, collection_name: str, llm, dense_model, sparse_model, sparse_tokenizer):
     # )

     # docs = docs.load()
-

-    dense_query = compute_dense_query(query, dense_model)
-    sparse_query_indices, sparse_query_values = compute_sparse(query, sparse_model, sparse_tokenizer)
+    dense_query = list(dense_model.embed_query(query, 32))
+    sparse_query = list(sparse_model.embed(query, 32))

     search_results = search(
         client,
         collection_name,
         dense_query,
-        sparse_query_indices,
-        sparse_query_values
+        sparse_query
     )

     dense_rank_list, sparse_rank_list = rank_list(search_results[0]), rank_list(search_results[1])
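
rank_list() and rrf() are untouched by this commit; the signature rrf(rank_lists, alpha=60, default_rank=1000) suggests standard reciprocal rank fusion. One common formulation, sketched here only as an assumption to make the flow readable, not as the repo's implementation:

def rrf_fuse(rank_lists, alpha=60, default_rank=1000):
    # sum 1 / (alpha + rank) per document across all rank lists;
    # documents absent from a list count as ranked at default_rank
    all_ids = {doc_id for ranks in rank_lists for doc_id, _ in ranks}
    scores = {}
    for doc_id in all_ids:
        scores[doc_id] = sum(
            1.0 / (alpha + dict(ranks).get(doc_id, default_rank))
            for ranks in rank_lists
        )
    return sorted(scores.items(), key=lambda kv: kv[1], reverse=True)
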
@@ -183,33 +182,18 @@ def main(query: str, client: QdrantClient, collection_name: str, llm, dense_model, sparse_model, sparse_tokenizer):
     output = reduce_chain.invoke([summaries])
     return output['output_text']

-def compute_sparse(sentence, model, tokenizer):
-    inputs = tokenizer(sentence, return_tensors='pt')
-    inputs = {key: val.to(device) for key, val in inputs.items()}
-    input_ids = inputs['input_ids']
-
-    attention_mask = inputs['attention_mask']
-
-    outputs = model(**inputs)
-
-    logits, attention_mask = outputs.logits, attention_mask
-    relu_log = torch.log(1 + torch.relu(logits))
-    weighted_log = relu_log * attention_mask.unsqueeze(-1)
-    max_val, _ = torch.max(weighted_log, dim=1)
-    vector = max_val.squeeze()
-
-    cols = vector.nonzero().squeeze().tolist()
-    weights = vector[cols].tolist()
-
-    return cols, weights
-
-def compute_dense_query(sentence, model):
-    return model.embed_query(f'Represent this sentence for searching relevant passages: {sentence}')
-
-def compute_dense_docs(docs, model):
-    return model.embed_documents(docs)
-
 def load_models_and_documents():
+    supported_splade_models[0] = {
+        "model": "prithivida/Splade_PP_en_v2",
+        "vocab_size": 30522,
+        "description": "Implementation of SPLADE++ Model for English v2",
+        "size_in_GB": 0.532,
+        "sources": {
+            "hf": "devve1/Splade_PP_en_v2_onnx"
+        },
+        "model_file": "model.onnx"
+    }
+
     with st.spinner('Load models...'):
         model_path = hf_hub_download(repo_id='NousResearch/Hermes-2-Theta-Llama-3-8B-GGUF',
                                      filename='Hermes-2-Pro-Llama-3-Instruct-Merged-DPO-Q8_0.gguf'
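
supported_splade_models is a plain module-level list in fastembed, so assigning to index 0 swaps the stock entry for a custom ONNX export hosted at devve1/Splade_PP_en_v2_onnx. Appending instead would register the model without shadowing the built-in one; a sketch under that assumption:

from fastembed import SparseTextEmbedding
from fastembed.sparse.splade_pp import supported_splade_models

supported_splade_models.append({
    "model": "prithivida/Splade_PP_en_v2",
    "vocab_size": 30522,
    "description": "Implementation of SPLADE++ Model for English v2",
    "size_in_GB": 0.532,
    "sources": {"hf": "devve1/Splade_PP_en_v2_onnx"},
    "model_file": "model.onnx",
})
sparse_model = SparseTextEmbedding('prithivida/Splade_PP_en_v2')
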
@@ -223,15 +207,21 @@ def load_models_and_documents():
             n_batch=512,
             f16_kv=True
         )
-
-        sparse_tokenizer = AutoTokenizer.from_pretrained('prithivida/Splade_PP_en_v2')
-        reverse_voc = {v: k for k, v in sparse_tokenizer.vocab.items()}
-        sparse_model = AutoModelForMaskedLM.from_pretrained('prithivida/Splade_PP_en_v2')

-        dense_model = HuggingFaceEmbeddings(model_name='mixedbread-ai/mxbai-embed-large-v1',
-                                            cache_folder=os.getenv('HF_HOME'),
-                                            model_kwargs={'truncate_dim':512}
-        )
+        provider = ['CPUExecutionProvider']
+
+        dense_model = FastEmbedEmbeddingsLc(
+            model_name='mixedbread-ai/mxbai-embed-large-v1',
+            providers=provider,
+            cache_dir=os.getenv('HF_HOME'),
+            batch_size=32
+        )
+
+        sparse_model = SparseTextEmbedding(
+            'prithivida/Splade_PP_en_v2',
+            cache_dir=os.getenv('HF_HOME'),
+            providers=provider
+        )

     client = QdrantClient(path=os.getenv('HF_HOME'))
     collection_name = 'collection_demo'
@@ -283,21 +273,21 @@ def load_models_and_documents():
         os.mkdir(embeddings_path)

         docs = WikipediaLoader(query='Action-RPG').load()
-        chunks, dense_embeddings, indices, values = chunk_documents(docs, dense_model, sparse_model, sparse_tokenizer)
+        chunks, dense_embeddings, sparse_embeddings = chunk_documents(docs, dense_model, sparse_model)

         with open(chunks_path, "wb") as outfile:
             packed = msgpack.packb(chunks, use_bin_type=True)
             outfile.write(packed)

         np.savez_compressed(dense_path, *dense_embeddings)
-        max_index = max(np.max(indice) for indice in indices)
+        max_index = max(np.max(embedding.indices) for embedding in sparse_embeddings)

         sparse_matrices = []
-        for indice, value in zip(indices, values):
-            data = value
-            indices = indice
+        for embedding in sparse_embeddings:
+            data = embedding.values
+            indices = embedding.indices
             indptr = np.array([0, len(data)])
-            matrix = csr_matrix((data, indice, indptr), shape=(1, max_index + 1))
+            matrix = csr_matrix((data, indices, indptr), shape=(1, max_index + 1))
             sparse_matrices.append(matrix)

         combined_sparse_matrix = vstack(sparse_matrices)
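
Each sparse embedding becomes one row of a CSR matrix: the weights go into data, the token ids into indices, and indptr = [0, len(data)] delimits the single row. A self-contained round trip of that layout (file name hypothetical):

import numpy as np
from scipy.sparse import csr_matrix, vstack, save_npz, load_npz

data = np.array([0.8, 0.3])        # SPLADE weights
indices = np.array([12, 407])      # activated vocabulary ids
indptr = np.array([0, len(data)])  # single row spanning both entries

row = csr_matrix((data, indices, indptr), shape=(1, 30522))
save_npz('sparse_demo.npz', vstack([row], format='csr'))

restored = load_npz('sparse_demo.npz').getrow(0)
assert restored.indices.tolist() == [12, 407]
assert np.allclose(restored.data, [0.8, 0.3])
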
@@ -310,14 +300,15 @@ def load_models_and_documents():

         dense_embeddings = list(np.load(dense_path).values())

-        indices = []
-        values = []
+        sparse_embeddings = []
         loaded_sparse_matrix = load_npz(sparse_path)

         for i in range(loaded_sparse_matrix.shape[0]):
             row = loaded_sparse_matrix.getrow(i)
-            values.append(row.data)
-            indices.append(row.indices)
+            values = row.data
+            indices = row.indices
+            embedding = SparseEmbedding(values, indices)
+            sparse_embeddings.append(embedding)

     with st.spinner('Save documents...'):
         client.upsert(
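
fastembed's SparseEmbedding is a small dataclass whose fields are values first, then indices, which the positional SparseEmbedding(values, indices) call above relies on (field order checked against recent fastembed releases; keywords make it explicit):

import numpy as np
from fastembed import SparseEmbedding

embedding = SparseEmbedding(values=np.array([0.8, 0.3]),
                            indices=np.array([12, 407]))
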
@@ -325,8 +316,7 @@ def load_models_and_documents():
             make_points(
                 chunks,
                 dense_embeddings,
-                indices,
-                values,
+                sparse_embeddings
             )
         )
         client.update_collection(
@@ -334,9 +324,9 @@ def load_models_and_documents():
             optimizer_config=models.OptimizersConfigDiff(indexing_threshold=20000)
         )

-    return client, collection_name, llm, dense_model, sparse_model, sparse_tokenizer
+    return client, collection_name, llm, dense_model, sparse_model

-def chunk_documents(docs, dense_model, sparse_model, sparse_tokenizer):
+def chunk_documents(docs, dense_model, sparse_model):
     text_splitter = SemanticChunker(
         dense_model,
         breakpoint_threshold_type='standard_deviation'
@@ -344,10 +334,10 @@ def chunk_documents(docs, dense_model, sparse_model, sparse_tokenizer):

     documents = [doc.page_content for doc in text_splitter.transform_documents(list(docs))]

-    dense_embeddings = compute_dense_docs(documents, dense_model)
-    indices, values = compute_sparse(documents, sparse_model, sparse_tokenizer)
+    dense_embeddings = dense_model.embed_documents(documents, 32)
+    sparse_embeddings = list(sparse_model.embed(documents, 32))

-    return documents, dense_embeddings, indices, values
+    return documents, dense_embeddings, sparse_embeddings

 if __name__ == '__main__':
     st.set_page_config(page_title="Video Game Assistant",
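
SemanticChunker splits on embedding-distance breakpoints rather than fixed sizes; with breakpoint_threshold_type='standard_deviation', a new chunk starts where the distance between adjacent sentences exceeds the mean by a multiple of the standard deviation. Minimal usage, with dense_model and docs as in the surrounding code:

from langchain_experimental.text_splitter import SemanticChunker

splitter = SemanticChunker(dense_model, breakpoint_threshold_type='standard_deviation')
# transform_documents() returns Document objects; page_content holds each chunk's text
chunks = [doc.page_content for doc in splitter.transform_documents(list(docs))]
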
@@ -356,7 +346,7 @@ if __name__ == '__main__':
     st.title("Video Game Assistant :sunglasses:")

     if 'models_loaded' not in st.session_state:
-        st.session_state.client, st.session_state.collection_name, st.session_state.llm, st.session_state.dense_model, st.session_state.sparse_model, st.session_state.sparse_tokenizer = load_models_and_documents()
+        st.session_state.client, st.session_state.collection_name, st.session_state.llm, st.session_state.dense_model, st.session_state.sparse_model = load_models_and_documents()
         st.session_state.models_loaded = True

     if st.session_state.models_loaded:
@@ -371,7 +361,7 @@ if __name__ == '__main__':
         st.chat_message("user").markdown(prompt)
         st.session_state.messages.append({"role": "user", "content": prompt})

-        ai_response = main(prompt, st.session_state.client, st.session_state.collection_name, st.session_state.llm, st.session_state.dense_model, st.session_state.sparse_model, st.session_state.sparse_tokenizer)
+        ai_response = main(prompt, st.session_state.client, st.session_state.collection_name, st.session_state.llm, st.session_state.dense_model, st.session_state.sparse_model)
         response = f"Echo: {ai_response}"
         with st.chat_message("assistant"):
             message_placeholder = st.empty()
 