import os

import msgpack
import numpy as np
from scipy.sparse import csr_matrix, vstack, save_npz, load_npz

from some_wikipedia_loader_module import WikipediaLoader
from some_chunking_module import chunk_documents, SparseEmbedding


class GameScribesCacheEmbeddings:
    """Disk-backed cache for Wikipedia-derived chunks and their embeddings.

    Documents are fetched per query via ``WikipediaLoader``, split and embedded
    by ``chunk_documents``, and the three resulting artifacts are persisted
    under ``<hf_home>/embeddings`` keyed by ``namespace``:

    - chunks  -> ``<namespace>_chunks.msgpack``  (msgpack)
    - dense   -> ``<namespace>_dense.npz``       (numpy savez_compressed)
    - sparse  -> ``<namespace>_sparse.npz``      (scipy CSR, one row per chunk)

    Because ``chunk_documents`` produces all three artifacts in one pass, a
    cache miss on any one of them computes and persists all three, so the
    expensive chunk/embed step runs at most once per namespace.
    """

    def __init__(self, namespace, queries, dense_model, sparse_model, hf_home=None):
        """Set up cache paths and ensure the cache directory exists.

        :param namespace: prefix used to key the cache files on disk
        :param queries: iterable of Wikipedia search queries
        :param dense_model: dense embedding model passed to ``chunk_documents``
        :param sparse_model: sparse embedding model passed to ``chunk_documents``
        :param hf_home: cache root; falls back to the ``HF_HOME`` env var
        :raises ValueError: if no cache root can be determined
        """
        self.namespace = namespace
        self.queries = queries
        self.dense_model = dense_model
        self.sparse_model = sparse_model
        self.hf_home = hf_home or os.getenv('HF_HOME')
        # Fail fast with a clear message instead of the opaque TypeError that
        # os.path.join(None, ...) would raise below.
        if not self.hf_home:
            raise ValueError(
                "No cache root: pass hf_home or set the HF_HOME environment variable"
            )
        self.embeddings_path = os.path.join(self.hf_home, 'embeddings')
        self.chunks_path = os.path.join(self.embeddings_path, f'{self.namespace}_chunks.msgpack')
        self.dense_path = os.path.join(self.embeddings_path, f'{self.namespace}_dense.npz')
        self.sparse_path = os.path.join(self.embeddings_path, f'{self.namespace}_sparse.npz')
        # makedirs + exist_ok handles a missing parent directory and avoids the
        # check-then-create race of the original exists()/mkdir() pair.
        os.makedirs(self.embeddings_path, exist_ok=True)

    def load_or_create_chunks(self):
        """Return the cached chunks, computing and caching them if absent."""
        if not os.path.exists(self.chunks_path):
            chunks, _, _ = self._compute_and_cache_all()
            return chunks
        return self._load_chunks()

    def load_or_create_dense_embeddings(self):
        """Return the cached dense embeddings, computing and caching if absent."""
        if not os.path.exists(self.dense_path):
            _, dense_embeddings, _ = self._compute_and_cache_all()
            return dense_embeddings
        return self._load_dense_embeddings()

    def load_or_create_sparse_embeddings(self):
        """Return the cached sparse embeddings, computing and caching if absent."""
        if not os.path.exists(self.sparse_path):
            _, _, sparse_embeddings = self._compute_and_cache_all()
            return sparse_embeddings
        return self._load_sparse_embeddings()

    def _compute_and_cache_all(self):
        """Chunk/embed the corpus once and persist all three artifacts.

        Chunking and embedding is the expensive step; computing everything in
        one pass means a cold cache pays that cost once instead of once per
        missing artifact.

        :returns: ``(chunks, dense_embeddings, sparse_embeddings)``
        """
        docs = self._load_documents()
        chunks, dense_embeddings, sparse_embeddings = chunk_documents(
            docs, self.dense_model, self.sparse_model
        )
        self._save_chunks(chunks)
        self._save_dense_embeddings(dense_embeddings)
        self._save_sparse_embeddings(sparse_embeddings)
        return chunks, dense_embeddings, sparse_embeddings

    def _load_documents(self):
        """Fetch and concatenate the documents for every configured query."""
        docs = []
        for query in self.queries:
            docs.extend(WikipediaLoader(query=query).load())
        return docs

    def _save_chunks(self, chunks):
        """Serialize chunks to the msgpack cache file."""
        with open(self.chunks_path, "wb") as outfile:
            outfile.write(msgpack.packb(chunks, use_bin_type=True))

    def _save_dense_embeddings(self, dense_embeddings):
        """Persist dense embeddings as a compressed npz archive (arr_0, arr_1, ...)."""
        np.savez_compressed(self.dense_path, *dense_embeddings)

    def _save_sparse_embeddings(self, sparse_embeddings):
        """Stack the sparse embeddings into one CSR matrix and persist it.

        Each embedding becomes one matrix row; the column count is sized to the
        largest index seen across all embeddings.
        """
        # default/length guards keep empty inputs from raising ValueError in max().
        max_index = max(
            (int(np.max(embedding.indices)) for embedding in sparse_embeddings
             if len(embedding.indices)),
            default=0,
        )
        sparse_matrices = []
        for embedding in sparse_embeddings:
            data = embedding.values
            indices = embedding.indices
            # Single-row CSR: indptr [0, nnz] places all entries in row 0.
            indptr = np.array([0, len(data)])
            matrix = csr_matrix((data, indices, indptr), shape=(1, max_index + 1))
            sparse_matrices.append(matrix)
        save_npz(self.sparse_path, vstack(sparse_matrices))

    def _load_chunks(self):
        """Deserialize chunks from the msgpack cache file."""
        with open(self.chunks_path, "rb") as data_file:
            byte_data = data_file.read()
        return msgpack.unpackb(byte_data, raw=False)

    def _load_dense_embeddings(self):
        """Load dense embeddings from the npz archive, preserving saved order."""
        # Context manager closes the underlying file (np.load keeps it open);
        # .files preserves the archive's insertion order (arr_0, arr_1, ...).
        with np.load(self.dense_path) as archive:
            return [archive[name] for name in archive.files]

    def _load_sparse_embeddings(self):
        """Rebuild one SparseEmbedding per row of the cached CSR matrix."""
        loaded_sparse_matrix = load_npz(self.sparse_path)
        sparse_embeddings = []
        for i in range(loaded_sparse_matrix.shape[0]):
            row = loaded_sparse_matrix.getrow(i)
            sparse_embeddings.append(SparseEmbedding(row.data, row.indices))
        return sparse_embeddings