File size: 3,913 Bytes
5ea926c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import os
import msgpack
import numpy as np
from scipy.sparse import csr_matrix, vstack, save_npz, load_npz
from some_wikipedia_loader_module import WikipediaLoader
from some_chunking_module import chunk_documents, SparseEmbedding

class GameScribesCacheEmbeddings:
    """Disk cache for Wikipedia-derived chunks and their embeddings.

    Artifacts live under ``<hf_home>/embeddings`` in namespaced files:
    chunks as msgpack, dense embeddings as a compressed ``.npz`` archive,
    and sparse embeddings as a single stacked CSR matrix (``save_npz``).
    Each ``load_or_create_*`` method returns its artifact from cache, or
    rebuilds and persists all three in one pass on a cache miss.
    """

    def __init__(self, namespace, queries, dense_model, sparse_model, hf_home=None):
        """
        Args:
            namespace: Prefix used to name the per-corpus cache files.
            queries: Wikipedia search queries whose results form the corpus.
            dense_model: Model passed through to ``chunk_documents``.
            sparse_model: Model passed through to ``chunk_documents``.
            hf_home: Cache root directory; defaults to the ``HF_HOME``
                environment variable.

        Raises:
            ValueError: If ``hf_home`` is not given and ``HF_HOME`` is unset
                (otherwise ``os.path.join(None, ...)`` would raise a
                confusing ``TypeError`` below).
        """
        self.namespace = namespace
        self.queries = queries
        self.dense_model = dense_model
        self.sparse_model = sparse_model
        self.hf_home = hf_home or os.getenv('HF_HOME')
        if not self.hf_home:
            raise ValueError(
                "No cache root: pass hf_home or set the HF_HOME environment variable"
            )
        self.embeddings_path = os.path.join(self.hf_home, 'embeddings')
        self.chunks_path = os.path.join(self.embeddings_path, f'{self.namespace}_chunks.msgpack')
        self.dense_path = os.path.join(self.embeddings_path, f'{self.namespace}_dense.npz')
        self.sparse_path = os.path.join(self.embeddings_path, f'{self.namespace}_sparse.npz')

        # makedirs(exist_ok=True) creates missing parents too, and avoids the
        # check-then-create race the previous exists()+mkdir() pattern had.
        os.makedirs(self.embeddings_path, exist_ok=True)

    def load_or_create_chunks(self):
        """Return the document chunks, building every artifact on a miss."""
        if not os.path.exists(self.chunks_path):
            chunks, _, _ = self._build_and_cache_all()
            return chunks
        return self._load_chunks()

    def load_or_create_dense_embeddings(self):
        """Return the dense embeddings, building every artifact on a miss."""
        if not os.path.exists(self.dense_path):
            _, dense_embeddings, _ = self._build_and_cache_all()
            return dense_embeddings
        return self._load_dense_embeddings()

    def load_or_create_sparse_embeddings(self):
        """Return the sparse embeddings, building every artifact on a miss."""
        if not os.path.exists(self.sparse_path):
            _, _, sparse_embeddings = self._build_and_cache_all()
            return sparse_embeddings
        return self._load_sparse_embeddings()

    def _build_and_cache_all(self):
        """Run the chunking pipeline once and persist all three artifacts.

        ``chunk_documents`` produces chunks, dense and sparse embeddings in a
        single pass; caching everything together avoids re-running the
        expensive document-load + chunking pipeline up to three times (once
        per ``load_or_create_*`` call) and keeps the three caches in sync.

        Returns:
            Tuple of ``(chunks, dense_embeddings, sparse_embeddings)``.
        """
        docs = self._load_documents()
        chunks, dense_embeddings, sparse_embeddings = chunk_documents(
            docs, self.dense_model, self.sparse_model
        )
        self._save_chunks(chunks)
        self._save_dense_embeddings(dense_embeddings)
        self._save_sparse_embeddings(sparse_embeddings)
        return chunks, dense_embeddings, sparse_embeddings

    def _load_documents(self):
        """Fetch and concatenate Wikipedia documents for every query."""
        docs = []
        for query in self.queries:
            docs.extend(WikipediaLoader(query=query).load())
        return docs

    def _save_chunks(self, chunks):
        """Serialize chunks to ``chunks_path`` with msgpack."""
        with open(self.chunks_path, "wb") as outfile:
            outfile.write(msgpack.packb(chunks, use_bin_type=True))

    def _save_dense_embeddings(self, dense_embeddings):
        """Save dense embeddings as positional arrays in a compressed .npz."""
        np.savez_compressed(self.dense_path, *dense_embeddings)

    def _save_sparse_embeddings(self, sparse_embeddings):
        """Stack per-chunk sparse vectors into one CSR matrix and save it.

        Each embedding contributes one row; the matrix width is one past the
        largest index observed across all embeddings.

        Raises:
            ValueError: If ``sparse_embeddings`` is empty (the previous code
                failed here too, but with an opaque ``max()`` error).
        """
        if not sparse_embeddings:
            raise ValueError("cannot save an empty list of sparse embeddings")
        max_index = max(int(np.max(embedding.indices)) for embedding in sparse_embeddings)
        rows = []
        for embedding in sparse_embeddings:
            data = embedding.values
            indices = embedding.indices
            # One row per embedding: indptr [0, nnz] describes a single row.
            indptr = np.array([0, len(data)])
            rows.append(csr_matrix((data, indices, indptr), shape=(1, max_index + 1)))
        save_npz(self.sparse_path, vstack(rows))

    def _load_chunks(self):
        """Deserialize chunks from ``chunks_path``."""
        with open(self.chunks_path, "rb") as data_file:
            byte_data = data_file.read()
        return msgpack.unpackb(byte_data, raw=False)

    def _load_dense_embeddings(self):
        """Load dense embeddings in their original save order.

        Uses the NpzFile as a context manager so the underlying file handle
        is closed (the previous version leaked it).
        """
        with np.load(self.dense_path) as npz:
            # savez_compressed names positional arrays arr_0..arr_N;
            # ``files`` preserves that order.
            return [npz[name] for name in npz.files]

    def _load_sparse_embeddings(self):
        """Rebuild one SparseEmbedding per row of the saved CSR matrix."""
        loaded_sparse_matrix = load_npz(self.sparse_path)
        sparse_embeddings = []
        for i in range(loaded_sparse_matrix.shape[0]):
            row = loaded_sparse_matrix.getrow(i)
            sparse_embeddings.append(SparseEmbedding(row.data, row.indices))
        return sparse_embeddings