devve1 committed
Commit 5ea926c
1 Parent(s): 68ec80e

Update caching.py

Files changed (1)
  1. caching.py +93 -0
caching.py CHANGED
@@ -0,0 +1,93 @@
import os
import msgpack
import numpy as np
from scipy.sparse import csr_matrix, vstack, save_npz, load_npz
from some_wikipedia_loader_module import WikipediaLoader
from some_chunking_module import chunk_documents, SparseEmbedding


class GameScribesCacheEmbeddings:
    """Caches Wikipedia-derived chunks and their dense/sparse embeddings on disk."""

    def __init__(self, namespace, queries, dense_model, sparse_model, hf_home=None):
        self.namespace = namespace
        self.queries = queries
        self.dense_model = dense_model
        self.sparse_model = sparse_model
        self.hf_home = hf_home or os.getenv('HF_HOME')
        self.embeddings_path = os.path.join(self.hf_home, 'embeddings')
        self.chunks_path = os.path.join(self.embeddings_path, f'{self.namespace}_chunks.msgpack')
        self.dense_path = os.path.join(self.embeddings_path, f'{self.namespace}_dense.npz')
        self.sparse_path = os.path.join(self.embeddings_path, f'{self.namespace}_sparse.npz')

        # Ensure the cache directory (and any missing parents) exists.
        os.makedirs(self.embeddings_path, exist_ok=True)

    def load_or_create_chunks(self):
        if not os.path.exists(self.chunks_path):
            docs = self._load_documents()
            chunks, _, _ = chunk_documents(docs, self.dense_model, self.sparse_model)
            self._save_chunks(chunks)
        else:
            chunks = self._load_chunks()
        return chunks

    def load_or_create_dense_embeddings(self):
        if not os.path.exists(self.dense_path):
            docs = self._load_documents()
            _, dense_embeddings, _ = chunk_documents(docs, self.dense_model, self.sparse_model)
            self._save_dense_embeddings(dense_embeddings)
        else:
            dense_embeddings = self._load_dense_embeddings()
        return dense_embeddings

    def load_or_create_sparse_embeddings(self):
        if not os.path.exists(self.sparse_path):
            docs = self._load_documents()
            _, _, sparse_embeddings = chunk_documents(docs, self.dense_model, self.sparse_model)
            self._save_sparse_embeddings(sparse_embeddings)
        else:
            sparse_embeddings = self._load_sparse_embeddings()
        return sparse_embeddings

    def _load_documents(self):
        docs = []
        for query in self.queries:
            docs.extend(WikipediaLoader(query=query).load())
        return docs

    def _save_chunks(self, chunks):
        with open(self.chunks_path, "wb") as outfile:
            packed = msgpack.packb(chunks, use_bin_type=True)
            outfile.write(packed)

    def _save_dense_embeddings(self, dense_embeddings):
        np.savez_compressed(self.dense_path, *dense_embeddings)

    def _save_sparse_embeddings(self, sparse_embeddings):
        # Stack each one-row sparse vector into a single CSR matrix so all
        # embeddings can be stored in one .npz file.
        max_index = max(np.max(embedding.indices) for embedding in sparse_embeddings)
        sparse_matrices = []
        for embedding in sparse_embeddings:
            data = embedding.values
            indices = embedding.indices
            indptr = np.array([0, len(data)])
            matrix = csr_matrix((data, indices, indptr), shape=(1, max_index + 1))
            sparse_matrices.append(matrix)
        combined_sparse_matrix = vstack(sparse_matrices)
        save_npz(self.sparse_path, combined_sparse_matrix)

    def _load_chunks(self):
        with open(self.chunks_path, "rb") as data_file:
            byte_data = data_file.read()
        return msgpack.unpackb(byte_data, raw=False)

    def _load_dense_embeddings(self):
        return list(np.load(self.dense_path).values())

    def _load_sparse_embeddings(self):
        # Rebuild one SparseEmbedding per row of the stored CSR matrix.
        sparse_embeddings = []
        loaded_sparse_matrix = load_npz(self.sparse_path)
        for i in range(loaded_sparse_matrix.shape[0]):
            row = loaded_sparse_matrix.getrow(i)
            values = row.data
            indices = row.indices
            embedding = SparseEmbedding(values, indices)
            sparse_embeddings.append(embedding)
        return sparse_embeddings
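
A minimal usage sketch of the class above, assuming `HF_HOME` points at a writable directory and that `dense_model` / `sparse_model` are whatever embedding objects `chunk_documents` expects; the namespace and queries shown are illustrative only:

import os

os.environ.setdefault('HF_HOME', '/tmp/hf_home')  # assumed cache root for this sketch

dense_model = ...   # hypothetical dense embedding model
sparse_model = ...  # hypothetical sparse embedding model

cache = GameScribesCacheEmbeddings(
    namespace='zelda',                                   # illustrative namespace
    queries=['The Legend of Zelda', 'Hyrule Warriors'],  # illustrative Wikipedia queries
    dense_model=dense_model,
    sparse_model=sparse_model,
)

# First run: fetches the pages, chunks and embeds them, and writes the cache files.
# Later runs: reads the msgpack/.npz files from disk instead.
chunks = cache.load_or_create_chunks()
dense_embeddings = cache.load_or_create_dense_embeddings()
sparse_embeddings = cache.load_or_create_sparse_embeddings()

On a cold cache each load_or_create_* method re-loads and re-chunks the documents independently; once the chunk, dense, and sparse files all exist, subsequent calls only read from disk.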