Update app.py
app.py CHANGED
@@ -6,7 +6,6 @@ import msgpack
 import numpy as np
 import streamlit as st
 from numpy import ndarray
-from transformers import AutoModelForMaskedLM, AutoTokenizer
 from scipy.sparse import csr_matrix, save_npz, load_npz, vstack
 from qdrant_client import QdrantClient, models
 from langchain_community.llms.llamacpp import LlamaCpp
@@ -15,7 +14,9 @@ from langchain_community.document_loaders.unstructured import UnstructuredFileLo…
 from langchain_core.prompts import PromptTemplate
 from langchain.chains.summarize import load_summarize_chain
 from langchain_experimental.text_splitter import SemanticChunker
-from …
+from fastembed.sparse.splade_pp import supported_splade_models
+from fastembed import SparseTextEmbedding, SparseEmbedding
+from fastembed_ext import FastEmbedEmbeddingsLc
 from langchain_core.documents import Document
 from huggingface_hub import hf_hub_download
 from qdrant_client.models import (
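This hunk swaps the hand-rolled transformers/torch SPLADE inference (removed further down) for fastembed's ONNX-backed models; `fastembed_ext` looks like a project-local LangChain adapter rather than a published package. A minimal sketch of what the new sparse model yields, assuming the custom model registration added later in this commit:

```python
# Sketch only, not part of the commit: fastembed returns SparseEmbedding
# objects whose .indices/.values numpy arrays feed Qdrant's SparseVector.
from fastembed import SparseTextEmbedding

sparse_model = SparseTextEmbedding('prithivida/Splade_PP_en_v2')
embedding = next(iter(sparse_model.embed(['hybrid search with SPLADE'])))
print(embedding.indices)  # vocabulary ids with nonzero weight
print(embedding.values)   # the matching SPLADE weights
```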
@@ -46,10 +47,10 @@ VERBOSE SUMMARY:
 """
 
 
-def make_points(chunks: list[str], dense: list[ndarray], indices, values)-> list[PointStruct]:
+def make_points(chunks: list[str], dense: list[ndarray], sparse)-> list[PointStruct]:
     points = []
-    for idx, (…
-    …
+    for idx, (sparse_vector, chunk, dense_vector) in enumerate(zip(sparse, chunks, dense)):
+        sparse_vec = SparseVector(indices=sparse_vector.indices.tolist(), values=sparse_vector.values.tolist())
         point = PointStruct(
             id=idx,
             vector={
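For context, a sketch of the hybrid point make_points() now assembles; the 'text-dense' name and the payload are assumptions, since only the 'text-sparse' name is visible in this diff:

```python
# Hypothetical shape of one upserted point (dense vector name and payload assumed).
from qdrant_client.models import PointStruct, SparseVector

point = PointStruct(
    id=0,
    vector={
        'text-dense': [0.1, 0.3, 0.2],  # dense embedding of the chunk
        'text-sparse': SparseVector(indices=[5, 42], values=[0.7, 1.2]),
    },
    payload={'text': 'chunk text'},
)
```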
@@ -63,7 +64,7 @@ def make_points(chunks: list[str], dense: list[ndarray], indices, values)-> list…
         points.append(point)
     return points
 
-def search(client: QdrantClient, collection_name: str, dense, indices, values):
+def search(client: QdrantClient, collection_name: str, dense, sparse):
     search_results = client.search_batch(
         collection_name,
         [
@@ -78,8 +79,8 @@ def search(client: QdrantClient, collection_name: str, dense, indices, values):
             vector=NamedSparseVector(
                 name="text-sparse",
                 vector=SparseVector(
-                    indices=indices,
-                    values=values,
+                    indices=sparse[0].indices.tolist(),
+                    values=sparse[0].values.tolist(),
                 ),
             ),
             limit=10
@@ -116,7 +117,7 @@ def rrf(rank_lists, alpha=60, default_rank=1000):
     return sorted_items
 
 
-def main(query: str, client: QdrantClient, collection_name: str, llm, dense_model, sparse_model…
+def main(query: str, client: QdrantClient, collection_name: str, llm, dense_model, sparse_model):
     # name = 'Kia_EV6'
     # filepath = os.path.join(os.getcwd(), name + '.pdf')
 
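The rrf() body is untouched by this commit, so only its signature is visible. For readers, a sketch of what reciprocal rank fusion with these defaults conventionally computes (an assumption, not the file's code): an item ranked 1st in one list and absent from the other scores 1/(60+1) + 1/(60+1000) ≈ 0.0173.

```python
# Conventional RRF over per-retriever rank lists of (id, rank) pairs.
def rrf_sketch(rank_lists, alpha=60, default_rank=1000):
    ids = {item for ranks in rank_lists for item, _ in ranks}
    scores = {}
    for item in ids:
        scores[item] = sum(
            1.0 / (alpha + dict(ranks).get(item, default_rank))
            for ranks in rank_lists
        )
    return sorted(scores.items(), key=lambda kv: kv[1], reverse=True)
```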
@@ -130,17 +131,15 @@ def main(query: str, client: QdrantClient, collection_name: str, llm, dense_mode…
     # )
 
     # docs = docs.load()
-
 
-    dense_query = …
-    …
+    dense_query = list(dense_model.embed_query(query, 32))
+    sparse_query = list(sparse_model.embed(query, 32))
 
     search_results = search(
         client,
         collection_name,
         dense_query,
-        …
-        sparse_query_values
+        sparse_query
    )
 
    dense_rank_list, sparse_rank_list = rank_list(search_results[0]), rank_list(search_results[1])
@@ -183,33 +182,18 @@ def main(query: str, client: QdrantClient, collection_name: str, llm, dense_mode…
     output = reduce_chain.invoke([summaries])
     return output['output_text']
 
-def compute_sparse(sentence, model, tokenizer):
-    inputs = tokenizer(sentence, return_tensors='pt')
-    inputs = {key: val.to(device) for key, val in inputs.items()}
-    input_ids = inputs['input_ids']
-
-    attention_mask = inputs['attention_mask']
-
-    outputs = model(**inputs)
-
-    logits, attention_mask = outputs.logits, attention_mask
-    relu_log = torch.log(1 + torch.relu(logits))
-    weighted_log = relu_log * attention_mask.unsqueeze(-1)
-    max_val, _ = torch.max(weighted_log, dim=1)
-    vector = max_val.squeeze()
-
-    cols = vector.nonzero().squeeze().tolist()
-    weights = vector[cols].tolist()
-
-    return cols, weights
-
-def compute_dense_query(sentence, model):
-    return model.embed_query(f'Represent this sentence for searching relevant passages: {sentence}')
-
-def compute_dense_docs(docs, model):
-    return model.embed_documents(docs)
-
 def load_models_and_documents():
+    supported_splade_models[0] = {
+        "model": "prithivida/Splade_PP_en_v2",
+        "vocab_size": 30522,
+        "description": "Implementation of SPLADE++ Model for English v2",
+        "size_in_GB": 0.532,
+        "sources": {
+            "hf": "devve1/Splade_PP_en_v2_onnx"
+        },
+        "model_file": "model.onnx"
+    }
+
     with st.spinner('Load models...'):
         model_path = hf_hub_download(repo_id='NousResearch/Hermes-2-Theta-Llama-3-8B-GGUF',
                                      filename='Hermes-2-Pro-Llama-3-Instruct-Merged-DPO-Q8_0.gguf'
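The deleted compute_sparse() was standard SPLADE pooling: a log-saturated ReLU over the masked-LM logits, zeroed at padding, then max-pooled over the sequence (it also referenced a `device` that is not defined anywhere in this diff). The ONNX model registered above should produce equivalent vectors; the pooling, for reference:

```python
# Reference sketch of the pooling the removed code performed.
import torch

def splade_pool(logits: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    relu_log = torch.log(1 + torch.relu(logits))        # (batch, seq_len, vocab)
    weighted = relu_log * attention_mask.unsqueeze(-1)  # zero padded positions
    return torch.max(weighted, dim=1).values            # per-vocab max over tokens
```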
@@ -223,15 +207,21 @@ def load_models_and_documents():
             n_batch=512,
             f16_kv=True
         )
-
-        sparse_tokenizer = AutoTokenizer.from_pretrained('prithivida/Splade_PP_en_v2')
-        reverse_voc = {v: k for k, v in sparse_tokenizer.vocab.items()}
-        sparse_model = AutoModelForMaskedLM.from_pretrained('prithivida/Splade_PP_en_v2')
 
-
-
-
-
+        provider = ['CPUExecutionProvider']
+
+        dense_model = FastEmbedEmbeddingsLc(
+            model_name='mixedbread-ai/mxbai-embed-large-v1',
+            providers=provider,
+            cache_dir=os.getenv('HF_HOME'),
+            batch_size=32
+        )
+
+        sparse_model = SparseTextEmbedding(
+            'prithivida/Splade_PP_en_v2',
+            cache_dir=os.getenv('HF_HOME'),
+            providers=provider
+        )
 
     client = QdrantClient(path=os.getenv('HF_HOME'))
     collection_name = 'collection_demo'
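Both fastembed models share one onnxruntime provider list. As an assumption (not shown in the commit), GPU inference would only need onnxruntime-gpu installed and the list reordered:

```python
# Hypothetical GPU setup; onnxruntime falls back through the list in order.
provider = ['CUDAExecutionProvider', 'CPUExecutionProvider']
```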
@@ -283,21 +273,21 @@ def load_models_and_documents():
             os.mkdir(embeddings_path)
 
             docs = WikipediaLoader(query='Action-RPG').load()
-            chunks, dense_embeddings, …
+            chunks, dense_embeddings, sparse_embeddings = chunk_documents(docs, dense_model, sparse_model)
 
             with open(chunks_path, "wb") as outfile:
                 packed = msgpack.packb(chunks, use_bin_type=True)
                 outfile.write(packed)
 
             np.savez_compressed(dense_path, *dense_embeddings)
-            max_index = max(np.max(…
+            max_index = max(np.max(embedding.indices) for embedding in sparse_embeddings)
 
             sparse_matrices = []
-            for …
-                data = …
-                indices = …
+            for embedding in sparse_embeddings:
+                data = embedding.values
+                indices = embedding.indices
                 indptr = np.array([0, len(data)])
-                matrix = csr_matrix((data, …
+                matrix = csr_matrix((data, indices, indptr), shape=(1, max_index + 1))
                 sparse_matrices.append(matrix)
 
             combined_sparse_matrix = vstack(sparse_matrices)
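Each sparse embedding is cached as a one-row CSR matrix so all rows can be vstack-ed and written with a single save_npz. A self-contained example of why (data, indices, indptr) with indptr=[0, len(data)] builds exactly one row:

```python
import numpy as np
from scipy.sparse import csr_matrix

data = np.array([0.7, 1.2])    # SPLADE weights
indices = np.array([5, 42])    # vocabulary ids holding those weights
row = csr_matrix((data, indices, np.array([0, len(data)])), shape=(1, 43))
assert row[0, 42] == 1.2       # round-trips the (index, value) pairs
```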
@@ -310,14 +300,15 @@ def load_models_and_documents():
 
             dense_embeddings = list(np.load(dense_path).values())
 
-
-            values = []
+            sparse_embeddings = []
             loaded_sparse_matrix = load_npz(sparse_path)
 
             for i in range(loaded_sparse_matrix.shape[0]):
                 row = loaded_sparse_matrix.getrow(i)
-                values…
-                indices…
+                values = row.data
+                indices = row.indices
+                embedding = SparseEmbedding(values, indices)
+                sparse_embeddings.append(embedding)
 
         with st.spinner('Save documents...'):
             client.upsert(
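On a warm start the CSR rows are rehydrated into fastembed SparseEmbedding objects. The positional call relies on fastembed declaring values before indices; keyword arguments would pin that down (assuming the fields are named as the imports suggest):

```python
embedding = SparseEmbedding(values=row.data, indices=row.indices)
```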
@@ -325,8 +316,7 @@ def load_models_and_documents():
                 make_points(
                     chunks,
                     dense_embeddings,
-                    …
-                    values,
+                    sparse_embeddings
                 )
             )
             client.update_collection(
@@ -334,9 +324,9 @@ def load_models_and_documents():
         optimizer_config=models.OptimizersConfigDiff(indexing_threshold=20000)
     )
 
-    return client, collection_name, llm, dense_model, sparse_model…
+    return client, collection_name, llm, dense_model, sparse_model
 
-def chunk_documents(docs, dense_model, sparse_model, sparse_tokenizer):
+def chunk_documents(docs, dense_model, sparse_model):
     text_splitter = SemanticChunker(
         dense_model,
         breakpoint_threshold_type='standard_deviation'
@@ -344,10 +334,10 @@ def chunk_documents(docs, dense_model, sparse_model, sparse_tokenizer):
 
     documents = [doc.page_content for doc in text_splitter.transform_documents(list(docs))]
 
-    dense_embeddings = …
-    …
+    dense_embeddings = dense_model.embed_documents(documents, 32)
+    sparse_embeddings = list(sparse_model.embed(documents, 32))
 
-    return documents, dense_embeddings, …
+    return documents, dense_embeddings, sparse_embeddings
 
 if __name__ == '__main__':
     st.set_page_config(page_title="Video Game Assistant",
@@ -356,7 +346,7 @@ if __name__ == '__main__':
     st.title("Video Game Assistant :sunglasses:")
 
     if 'models_loaded' not in st.session_state:
-        st.session_state.client, st.session_state.collection_name, st.session_state.llm, st.session_state.dense_model, st.session_state.sparse_model…
+        st.session_state.client, st.session_state.collection_name, st.session_state.llm, st.session_state.dense_model, st.session_state.sparse_model = load_models_and_documents()
         st.session_state.models_loaded = True
 
     if st.session_state.models_loaded:
@@ -371,7 +361,7 @@ if __name__ == '__main__':
         st.chat_message("user").markdown(prompt)
         st.session_state.messages.append({"role": "user", "content": prompt})
 
-        ai_response = main(prompt, st.session_state.client, st.session_state.collection_name, st.session_state.llm, st.session_state.dense_model, st.session_state.sparse_model…
+        ai_response = main(prompt, st.session_state.client, st.session_state.collection_name, st.session_state.llm, st.session_state.dense_model, st.session_state.sparse_model)
         response = f"Echo: {ai_response}"
         with st.chat_message("assistant"):
             message_placeholder = st.empty()
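With these hunks applied the app starts the usual Streamlit way, `streamlit run app.py`, assuming the imported packages plus llama-cpp-python are installed and HF_HOME points at a disk with room for the GGUF and ONNX downloads.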