Update caching.py
Browse files- caching.py +93 -0
caching.py
CHANGED
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import msgpack
|
3 |
+
import numpy as np
|
4 |
+
from scipy.sparse import csr_matrix, vstack, save_npz, load_npz
|
5 |
+
from some_wikipedia_loader_module import WikipediaLoader
|
6 |
+
from some_chunking_module import chunk_documents, SparseEmbedding
|
7 |
+
|
8 |
+
class GameScribesCacheEmbeddings:
|
9 |
+
def __init__(self, namespace, queries, dense_model, sparse_model, hf_home=None):
|
10 |
+
self.namespace = namespace
|
11 |
+
self.queries = queries
|
12 |
+
self.dense_model = dense_model
|
13 |
+
self.sparse_model = sparse_model
|
14 |
+
self.hf_home = hf_home or os.getenv('HF_HOME')
|
15 |
+
self.embeddings_path = os.path.join(self.hf_home, 'embeddings')
|
16 |
+
self.chunks_path = os.path.join(self.embeddings_path, f'{self.namespace}_chunks.msgpack')
|
17 |
+
self.dense_path = os.path.join(self.embeddings_path, f'{self.namespace}_dense.npz')
|
18 |
+
self.sparse_path = os.path.join(self.embeddings_path, f'{self.namespace}_sparse.npz')
|
19 |
+
|
20 |
+
if not os.path.exists(self.embeddings_path):
|
21 |
+
os.mkdir(self.embeddings_path)
|
22 |
+
|
23 |
+
def load_or_create_chunks(self):
|
24 |
+
if not os.path.exists(self.chunks_path):
|
25 |
+
docs = self._load_documents()
|
26 |
+
chunks, _, _ = chunk_documents(docs, self.dense_model, self.sparse_model)
|
27 |
+
self._save_chunks(chunks)
|
28 |
+
else:
|
29 |
+
chunks = self._load_chunks()
|
30 |
+
return chunks
|
31 |
+
|
32 |
+
def load_or_create_dense_embeddings(self):
|
33 |
+
if not os.path.exists(self.dense_path):
|
34 |
+
docs = self._load_documents()
|
35 |
+
_, dense_embeddings, _ = chunk_documents(docs, self.dense_model, self.sparse_model)
|
36 |
+
self._save_dense_embeddings(dense_embeddings)
|
37 |
+
else:
|
38 |
+
dense_embeddings = self._load_dense_embeddings()
|
39 |
+
return dense_embeddings
|
40 |
+
|
41 |
+
def load_or_create_sparse_embeddings(self):
|
42 |
+
if not os.path.exists(self.sparse_path):
|
43 |
+
docs = self._load_documents()
|
44 |
+
_, _, sparse_embeddings = chunk_documents(docs, self.dense_model, self.sparse_model)
|
45 |
+
self._save_sparse_embeddings(sparse_embeddings)
|
46 |
+
else:
|
47 |
+
sparse_embeddings = self._load_sparse_embeddings()
|
48 |
+
return sparse_embeddings
|
49 |
+
|
50 |
+
def _load_documents(self):
|
51 |
+
docs = []
|
52 |
+
for query in self.queries:
|
53 |
+
docs.extend(WikipediaLoader(query=query).load())
|
54 |
+
return docs
|
55 |
+
|
56 |
+
def _save_chunks(self, chunks):
|
57 |
+
with open(self.chunks_path, "wb") as outfile:
|
58 |
+
packed = msgpack.packb(chunks, use_bin_type=True)
|
59 |
+
outfile.write(packed)
|
60 |
+
|
61 |
+
def _save_dense_embeddings(self, dense_embeddings):
|
62 |
+
np.savez_compressed(self.dense_path, *dense_embeddings)
|
63 |
+
|
64 |
+
def _save_sparse_embeddings(self, sparse_embeddings):
|
65 |
+
max_index = max(np.max(embedding.indices) for embedding in sparse_embeddings)
|
66 |
+
sparse_matrices = []
|
67 |
+
for embedding in sparse_embeddings:
|
68 |
+
data = embedding.values
|
69 |
+
indices = embedding.indices
|
70 |
+
indptr = np.array([0, len(data)])
|
71 |
+
matrix = csr_matrix((data, indices, indptr), shape=(1, max_index + 1))
|
72 |
+
sparse_matrices.append(matrix)
|
73 |
+
combined_sparse_matrix = vstack(sparse_matrices)
|
74 |
+
save_npz(self.sparse_path, combined_sparse_matrix)
|
75 |
+
|
76 |
+
def _load_chunks(self):
|
77 |
+
with open(self.chunks_path, "rb") as data_file:
|
78 |
+
byte_data = data_file.read()
|
79 |
+
return msgpack.unpackb(byte_data, raw=False)
|
80 |
+
|
81 |
+
def _load_dense_embeddings(self):
|
82 |
+
return list(np.load(self.dense_path).values())
|
83 |
+
|
84 |
+
def _load_sparse_embeddings(self):
|
85 |
+
sparse_embeddings = []
|
86 |
+
loaded_sparse_matrix = load_npz(self.sparse_path)
|
87 |
+
for i in range(loaded_sparse_matrix.shape[0]):
|
88 |
+
row = loaded_sparse_matrix.getrow(i)
|
89 |
+
values = row.data
|
90 |
+
indices = row.indices
|
91 |
+
embedding = SparseEmbedding(values, indices)
|
92 |
+
sparse_embeddings.append(embedding)
|
93 |
+
return sparse_embeddings
|