|
import os |
|
import msgpack |
|
import numpy as np |
|
from scipy.sparse import csr_matrix, vstack, save_npz, load_npz |
|
from some_wikipedia_loader_module import WikipediaLoader |
|
from some_chunking_module import chunk_documents, SparseEmbedding |
|
|
|
class GameScribesCacheEmbeddings:
    """Disk-backed cache for document chunks and their dense/sparse embeddings.

    For a given ``namespace``, Wikipedia articles matching ``queries`` are
    chunked and embedded once via ``chunk_documents``; the three resulting
    artifacts (chunks, dense embeddings, sparse embeddings) are persisted
    under ``<hf_home>/embeddings`` so later runs load them from disk instead
    of recomputing.
    """

    def __init__(self, namespace, queries, dense_model, sparse_model, hf_home=None):
        """Set up cache file paths and ensure the cache directory exists.

        Args:
            namespace: Prefix used to name this cache's files on disk.
            queries: Wikipedia search queries whose results form the corpus.
            dense_model: Dense embedding model passed through to ``chunk_documents``.
            sparse_model: Sparse embedding model passed through to ``chunk_documents``.
            hf_home: Cache root directory; defaults to the ``HF_HOME``
                environment variable.

        Raises:
            ValueError: If neither ``hf_home`` nor ``HF_HOME`` is set.
        """
        self.namespace = namespace
        self.queries = queries
        self.dense_model = dense_model
        self.sparse_model = sparse_model
        self.hf_home = hf_home or os.getenv('HF_HOME')
        if not self.hf_home:
            # Fail fast: os.path.join(None, ...) below would raise a
            # confusing TypeError far from the actual cause.
            raise ValueError("hf_home was not provided and HF_HOME is not set")
        self.embeddings_path = os.path.join(self.hf_home, 'embeddings')
        self.chunks_path = os.path.join(self.embeddings_path, f'{self.namespace}_chunks.msgpack')
        self.dense_path = os.path.join(self.embeddings_path, f'{self.namespace}_dense.npz')
        self.sparse_path = os.path.join(self.embeddings_path, f'{self.namespace}_sparse.npz')

        # makedirs creates missing parents and is race-free, unlike the
        # exists()/mkdir() pair (which also fails if hf_home itself is new).
        os.makedirs(self.embeddings_path, exist_ok=True)

    def load_or_create_chunks(self):
        """Return the document chunks, computing and caching them on a miss."""
        if os.path.exists(self.chunks_path):
            return self._load_chunks()
        chunks, _, _ = self._compute_and_cache()
        return chunks

    def load_or_create_dense_embeddings(self):
        """Return the dense embeddings, computing and caching them on a miss."""
        if os.path.exists(self.dense_path):
            return self._load_dense_embeddings()
        _, dense_embeddings, _ = self._compute_and_cache()
        return dense_embeddings

    def load_or_create_sparse_embeddings(self):
        """Return the sparse embeddings, computing and caching them on a miss."""
        if os.path.exists(self.sparse_path):
            return self._load_sparse_embeddings()
        _, _, sparse_embeddings = self._compute_and_cache()
        return sparse_embeddings

    def _compute_and_cache(self):
        """Run the chunk/embed pipeline once and persist all three artifacts.

        ``chunk_documents`` produces chunks, dense and sparse embeddings in a
        single expensive pass, so a miss on *any* artifact saves all three.
        (Previously a fully cold cache re-ran the whole pipeline — including
        re-downloading the Wikipedia corpus — once per artifact.)

        Returns:
            Tuple of (chunks, dense_embeddings, sparse_embeddings).
        """
        docs = self._load_documents()
        chunks, dense_embeddings, sparse_embeddings = chunk_documents(
            docs, self.dense_model, self.sparse_model
        )
        self._save_chunks(chunks)
        self._save_dense_embeddings(dense_embeddings)
        self._save_sparse_embeddings(sparse_embeddings)
        return chunks, dense_embeddings, sparse_embeddings

    def _load_documents(self):
        """Fetch and concatenate the Wikipedia documents for every query."""
        docs = []
        for query in self.queries:
            docs.extend(WikipediaLoader(query=query).load())
        return docs

    def _save_chunks(self, chunks):
        """Serialize chunks to ``chunks_path`` with msgpack."""
        with open(self.chunks_path, "wb") as outfile:
            outfile.write(msgpack.packb(chunks, use_bin_type=True))

    def _save_dense_embeddings(self, dense_embeddings):
        """Store each dense embedding as a separate array in one .npz file."""
        np.savez_compressed(self.dense_path, *dense_embeddings)

    def _save_sparse_embeddings(self, sparse_embeddings):
        """Stack the sparse embeddings into one CSR matrix and save it.

        Each embedding (with ``.values`` and ``.indices`` attributes) becomes
        one row; the matrix width is sized to the largest feature index seen.
        """
        # default=-1 keeps max() from raising when every row is empty
        # (np.max on an empty array raises ValueError).
        max_index = max(
            (int(np.max(embedding.indices)) for embedding in sparse_embeddings
             if len(embedding.indices)),
            default=-1,
        )
        sparse_matrices = []
        for embedding in sparse_embeddings:
            data = embedding.values
            indices = embedding.indices
            # One row per embedding: indptr [0, nnz] describes a single row.
            indptr = np.array([0, len(data)])
            matrix = csr_matrix((data, indices, indptr), shape=(1, max_index + 1))
            sparse_matrices.append(matrix)
        combined_sparse_matrix = vstack(sparse_matrices)
        save_npz(self.sparse_path, combined_sparse_matrix)

    def _load_chunks(self):
        """Deserialize and return the cached chunks."""
        with open(self.chunks_path, "rb") as data_file:
            return msgpack.unpackb(data_file.read(), raw=False)

    def _load_dense_embeddings(self):
        """Return the cached dense embeddings as a list of arrays.

        NpzFile preserves the write order of np.savez_compressed(*args),
        so the list order matches what was saved.
        """
        return list(np.load(self.dense_path).values())

    def _load_sparse_embeddings(self):
        """Rebuild SparseEmbedding objects, one per row of the saved matrix."""
        sparse_embeddings = []
        loaded_sparse_matrix = load_npz(self.sparse_path)
        for i in range(loaded_sparse_matrix.shape[0]):
            row = loaded_sparse_matrix.getrow(i)
            # Reconstruct from the row's nonzero values and column indices.
            sparse_embeddings.append(SparseEmbedding(row.data, row.indices))
        return sparse_embeddings
|
|