# caching.py — fetched from a Hugging Face repository file page (commit 5ea926c);
# the web-UI chrome ("raw / history / blame", file size) has been stripped.
import os
import msgpack
import numpy as np
from scipy.sparse import csr_matrix, vstack, save_npz, load_npz
from some_wikipedia_loader_module import WikipediaLoader
from some_chunking_module import chunk_documents, SparseEmbedding
class GameScribesCacheEmbeddings:
    """Disk-backed cache for Wikipedia-derived chunks and their embeddings.

    For a given ``namespace`` this class caches three artifacts under
    ``<hf_home>/embeddings``:

    * ``<namespace>_chunks.msgpack`` — the text chunks (msgpack-encoded),
    * ``<namespace>_dense.npz``     — dense embeddings (one array per chunk),
    * ``<namespace>_sparse.npz``    — sparse embeddings (stacked CSR matrix).

    Each ``load_or_create_*`` method returns the cached artifact if its file
    exists, otherwise recomputes it from scratch by loading documents and
    chunking them.  NOTE(review): when nothing is cached yet, calling all
    three methods loads and chunks the documents three times; callers that
    need all three on a cold cache may want a combined entry point — kept
    as-is here to preserve the existing per-artifact behavior.
    """

    def __init__(self, namespace, queries, dense_model, sparse_model, hf_home=None):
        """Set up cache paths and ensure the embeddings directory exists.

        Args:
            namespace: Prefix used to name the cache files.
            queries: Wikipedia search queries used to (re)build the corpus.
            dense_model: Model handed to ``chunk_documents`` for dense vectors.
            sparse_model: Model handed to ``chunk_documents`` for sparse vectors.
            hf_home: Cache root directory; falls back to the ``HF_HOME``
                environment variable.

        Raises:
            ValueError: If neither ``hf_home`` nor ``HF_HOME`` is set.
        """
        self.namespace = namespace
        self.queries = queries
        self.dense_model = dense_model
        self.sparse_model = sparse_model
        self.hf_home = hf_home or os.getenv('HF_HOME')
        # Fail fast with a clear message instead of a cryptic TypeError
        # from os.path.join(None, ...).
        if not self.hf_home:
            raise ValueError(
                "hf_home was not provided and the HF_HOME environment "
                "variable is not set"
            )
        self.embeddings_path = os.path.join(self.hf_home, 'embeddings')
        self.chunks_path = os.path.join(self.embeddings_path, f'{self.namespace}_chunks.msgpack')
        self.dense_path = os.path.join(self.embeddings_path, f'{self.namespace}_dense.npz')
        self.sparse_path = os.path.join(self.embeddings_path, f'{self.namespace}_sparse.npz')
        # makedirs(exist_ok=True) also creates missing parents and is safe
        # against concurrent creation, unlike the bare os.mkdir it replaces.
        os.makedirs(self.embeddings_path, exist_ok=True)

    def load_or_create_chunks(self):
        """Return cached chunks, rebuilding and caching them on a miss."""
        if not os.path.exists(self.chunks_path):
            docs = self._load_documents()
            chunks, _, _ = chunk_documents(docs, self.dense_model, self.sparse_model)
            self._save_chunks(chunks)
        else:
            chunks = self._load_chunks()
        return chunks

    def load_or_create_dense_embeddings(self):
        """Return cached dense embeddings, rebuilding and caching on a miss."""
        if not os.path.exists(self.dense_path):
            docs = self._load_documents()
            _, dense_embeddings, _ = chunk_documents(docs, self.dense_model, self.sparse_model)
            self._save_dense_embeddings(dense_embeddings)
        else:
            dense_embeddings = self._load_dense_embeddings()
        return dense_embeddings

    def load_or_create_sparse_embeddings(self):
        """Return cached sparse embeddings, rebuilding and caching on a miss."""
        if not os.path.exists(self.sparse_path):
            docs = self._load_documents()
            _, _, sparse_embeddings = chunk_documents(docs, self.dense_model, self.sparse_model)
            self._save_sparse_embeddings(sparse_embeddings)
        else:
            sparse_embeddings = self._load_sparse_embeddings()
        return sparse_embeddings

    def _load_documents(self):
        """Fetch documents from Wikipedia for every configured query."""
        docs = []
        for query in self.queries:
            docs.extend(WikipediaLoader(query=query).load())
        return docs

    def _save_chunks(self, chunks):
        """Serialize chunks to the msgpack cache file."""
        with open(self.chunks_path, "wb") as outfile:
            packed = msgpack.packb(chunks, use_bin_type=True)
            outfile.write(packed)

    def _save_dense_embeddings(self, dense_embeddings):
        """Save each dense embedding as a positional array in one .npz file."""
        # Arrays are stored as arr_0, arr_1, ... in argument order.
        np.savez_compressed(self.dense_path, *dense_embeddings)

    def _save_sparse_embeddings(self, sparse_embeddings):
        """Stack per-chunk sparse embeddings into one CSR matrix and save it.

        Each embedding must expose ``values`` and ``indices``; every row of
        the combined matrix is sized to the largest index seen in this batch.

        Raises:
            ValueError: If ``sparse_embeddings`` is empty.
        """
        # Explicit guard: max() over an empty sequence would raise a cryptic
        # ValueError anyway; keep the exception type, clarify the message.
        if not sparse_embeddings:
            raise ValueError("cannot save an empty list of sparse embeddings")
        # NOTE(review): assumes every embedding has at least one index;
        # np.max on an empty indices array would raise — TODO confirm upstream.
        max_index = max(np.max(embedding.indices) for embedding in sparse_embeddings)
        sparse_matrices = []
        for embedding in sparse_embeddings:
            data = embedding.values
            indices = embedding.indices
            # Single-row CSR: indptr [0, nnz] marks one row holding all values.
            indptr = np.array([0, len(data)])
            matrix = csr_matrix((data, indices, indptr), shape=(1, max_index + 1))
            sparse_matrices.append(matrix)
        combined_sparse_matrix = vstack(sparse_matrices)
        save_npz(self.sparse_path, combined_sparse_matrix)

    def _load_chunks(self):
        """Deserialize chunks from the msgpack cache file."""
        with open(self.chunks_path, "rb") as data_file:
            byte_data = data_file.read()
        return msgpack.unpackb(byte_data, raw=False)

    def _load_dense_embeddings(self):
        """Load dense embeddings in the order they were saved."""
        # Iterate npz.files explicitly (saved order) and close the file
        # handle deterministically via the context manager.
        with np.load(self.dense_path) as npz:
            return [npz[name] for name in npz.files]

    def _load_sparse_embeddings(self):
        """Rebuild one SparseEmbedding per row of the cached CSR matrix."""
        sparse_embeddings = []
        loaded_sparse_matrix = load_npz(self.sparse_path)
        for i in range(loaded_sparse_matrix.shape[0]):
            row = loaded_sparse_matrix.getrow(i)
            values = row.data
            indices = row.indices
            embedding = SparseEmbedding(values, indices)
            sparse_embeddings.append(embedding)
        return sparse_embeddings