import os
import msgpack
import numpy as np
from scipy.sparse import csr_matrix, vstack, save_npz, load_npz
from some_wikipedia_loader_module import WikipediaLoader
from some_chunking_module import chunk_documents, SparseEmbedding


class GameScribesCacheEmbeddings:
    """Cache Wikipedia-derived chunks plus their dense and sparse embeddings on disk."""

    def __init__(self, namespace, queries, dense_model, sparse_model, hf_home=None):
        self.namespace = namespace
        self.queries = queries
        self.dense_model = dense_model
        self.sparse_model = sparse_model
        # Fall back to the HF_HOME environment variable when no cache root is given.
        self.hf_home = hf_home or os.getenv('HF_HOME')
        self.embeddings_path = os.path.join(self.hf_home, 'embeddings')
        self.chunks_path = os.path.join(self.embeddings_path, f'{self.namespace}_chunks.msgpack')
        self.dense_path = os.path.join(self.embeddings_path, f'{self.namespace}_dense.npz')
        self.sparse_path = os.path.join(self.embeddings_path, f'{self.namespace}_sparse.npz')
        os.makedirs(self.embeddings_path, exist_ok=True)

    def load_or_create_chunks(self):
        # Reuse cached chunks when present; otherwise chunk the source documents and persist them.
        if not os.path.exists(self.chunks_path):
            docs = self._load_documents()
            chunks, _, _ = chunk_documents(docs, self.dense_model, self.sparse_model)
            self._save_chunks(chunks)
        else:
            chunks = self._load_chunks()
        return chunks

    def load_or_create_dense_embeddings(self):
        # Same pattern for the dense vectors: compute once, then read back from the .npz cache.
        # Note that on a cold cache each of the three load_or_create_* methods runs
        # chunk_documents independently.
        if not os.path.exists(self.dense_path):
            docs = self._load_documents()
            _, dense_embeddings, _ = chunk_documents(docs, self.dense_model, self.sparse_model)
            self._save_dense_embeddings(dense_embeddings)
        else:
            dense_embeddings = self._load_dense_embeddings()
        return dense_embeddings

    def load_or_create_sparse_embeddings(self):
        # And for the sparse vectors, cached on disk as a single stacked CSR matrix.
        if not os.path.exists(self.sparse_path):
            docs = self._load_documents()
            _, _, sparse_embeddings = chunk_documents(docs, self.dense_model, self.sparse_model)
            self._save_sparse_embeddings(sparse_embeddings)
        else:
            sparse_embeddings = self._load_sparse_embeddings()
        return sparse_embeddings

    def _load_documents(self):
        # Fetch one set of Wikipedia pages per query and concatenate them.
        docs = []
        for query in self.queries:
            docs.extend(WikipediaLoader(query=query).load())
        return docs

    def _save_chunks(self, chunks):
        with open(self.chunks_path, "wb") as outfile:
            packed = msgpack.packb(chunks, use_bin_type=True)
            outfile.write(packed)

    def _save_dense_embeddings(self, dense_embeddings):
        # Each dense vector becomes a separate array (arr_0, arr_1, ...) in the compressed archive.
        np.savez_compressed(self.dense_path, *dense_embeddings)

    def _save_sparse_embeddings(self, sparse_embeddings):
        # Give every row the same width so the per-chunk vectors can be stacked into one CSR matrix.
        max_index = max(np.max(embedding.indices) for embedding in sparse_embeddings)
        sparse_matrices = []
        for embedding in sparse_embeddings:
            data = embedding.values
            indices = embedding.indices
            indptr = np.array([0, len(data)])  # one row per embedding
            matrix = csr_matrix((data, indices, indptr), shape=(1, max_index + 1))
            sparse_matrices.append(matrix)
        combined_sparse_matrix = vstack(sparse_matrices)
        save_npz(self.sparse_path, combined_sparse_matrix)

    def _load_chunks(self):
        with open(self.chunks_path, "rb") as data_file:
            byte_data = data_file.read()
            return msgpack.unpackb(byte_data, raw=False)

    def _load_dense_embeddings(self):
        # Arrays come back from the .npz archive in the order they were written.
        return list(np.load(self.dense_path).values())

    def _load_sparse_embeddings(self):
        # Rebuild one SparseEmbedding per row of the stacked CSR matrix.
        sparse_embeddings = []
        loaded_sparse_matrix = load_npz(self.sparse_path)
        for i in range(loaded_sparse_matrix.shape[0]):
            row = loaded_sparse_matrix.getrow(i)
            values = row.data
            indices = row.indices
            embedding = SparseEmbedding(values, indices)
            sparse_embeddings.append(embedding)
        return sparse_embeddings
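
For reference, here is a minimal usage sketch. It is a hypothetical example rather than part of the class above: the namespace, queries, and the dense_model / sparse_model objects are placeholders for whatever you construct in your own setup, and it assumes HF_HOME points at a writable directory.

# Hypothetical usage sketch; dense_model and sparse_model are assumed to be
# already-instantiated embedding models compatible with chunk_documents.
cache = GameScribesCacheEmbeddings(
    namespace="game_scribes",
    queries=["Board game", "Role-playing game"],
    dense_model=dense_model,
    sparse_model=sparse_model,
)

chunks = cache.load_or_create_chunks()                      # chunked on the first run, read from msgpack afterwards
dense_embeddings = cache.load_or_create_dense_embeddings()  # list of NumPy arrays from the .npz archive
sparse_embeddings = cache.load_or_create_sparse_embeddings()

# The sparse vectors round-trip through the stacked CSR matrix, so the indices
# and values of the first embedding should match what was originally computed.
print(sparse_embeddings[0].indices[:5], sparse_embeddings[0].values[:5])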