File size: 12,024 Bytes
8a58cf3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
"""File for core data structures."""

import random
import sys
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Set, Tuple

from dataclasses_json import DataClassJsonMixin

from gpt_index.data_structs.struct_type import IndexStructType
from gpt_index.schema import BaseDocument
from gpt_index.utils import get_new_int_id


@dataclass
class IndexStruct(BaseDocument, DataClassJsonMixin):
    """A base data struct for a LlamaIndex.

    Combines BaseDocument with DataClassJsonMixin and declares no fields
    of its own; the concrete index structs in this module subclass it.

    """

    # NOTE: the text field, inherited from BaseDocument,
    # represents a summary of the content of the index struct.
    # It is primarily used for composing indices with other indices.

    # NOTE: the doc_id field, inherited from BaseDocument,
    # represents a unique identifier for the index struct
    # that will be put in the docstore.
    # Not all index_structs need to have a doc_id. Only index_structs that
    # represent a complete data structure (e.g. IndexGraph, IndexList),
    # and are used to compose a higher level index, will have a doc_id.


@dataclass
class Node(IndexStruct):
    """A single, generic unit of indexed data.

    This is the base struct that most indices store their content in.
    Unlike other index structs, a Node must always carry text.

    """

    # used for GPTTreeIndex
    index: int = 0
    child_indices: Set[int] = field(default_factory=set)

    # embeddings
    embedding: Optional[List[float]] = None

    # reference document id
    ref_doc_id: Optional[str] = None

    # extra node info
    node_info: Optional[Dict[str, Any]] = None

    def __post_init__(self) -> None:
        """Run the parent post-init, then require that text is set.

        Raises:
            ValueError: if the text field is None.

        """
        super().__post_init__()
        # NOTE: for Node objects, the text field is required
        if self.text is not None:
            return
        raise ValueError("text field not set.")

    def get_text(self) -> str:
        """Get the node text, prefixed by the extra info string if present.

        When ``extra_info_str`` is not None it is placed before the text,
        separated from it by a blank line.

        """
        body = super().get_text()
        if self.extra_info_str is None:
            return body
        return f"{self.extra_info_str}\n\n{body}"

    @classmethod
    def get_type(cls) -> str:
        """Get the type identifier of this struct."""
        # TODO: consolidate with IndexStructType
        return "node"


@dataclass
class IndexGraph(IndexStruct):
    """A graph representing the tree-structured index.

    ``all_nodes`` holds every node keyed by its index; ``root_nodes`` holds
    only the nodes that were inserted without a parent.

    """

    all_nodes: Dict[int, Node] = field(default_factory=dict)
    root_nodes: Dict[int, Node] = field(default_factory=dict)

    @property
    def size(self) -> int:
        """Number of nodes in the graph."""
        return len(self.all_nodes)

    def get_children(self, parent_node: Optional[Node]) -> Dict[int, Node]:
        """Get the children of a node, keyed by node index.

        When ``parent_node`` is None the root nodes dict itself is returned.

        """
        if parent_node is not None:
            children: Dict[int, Node] = {}
            for child_idx in parent_node.child_indices:
                children[child_idx] = self.all_nodes[child_idx]
            return children
        return self.root_nodes

    def insert_under_parent(self, node: Node, parent_node: Optional[Node]) -> None:
        """Insert a node under a parent, or as a root when there is no parent.

        Raises:
            ValueError: if a node with the same index is already in the graph.

        """
        new_idx = node.index
        if new_idx in self.all_nodes:
            raise ValueError(
                "Cannot insert a new node with the same index as an existing node."
            )

        # link the node to its parent, or register it as a root
        if parent_node is not None:
            parent_node.child_indices.add(new_idx)
        else:
            self.root_nodes[new_idx] = node

        self.all_nodes[new_idx] = node

    @classmethod
    def get_type(cls) -> str:
        """Get the type identifier of this struct."""
        return "tree"


@dataclass
class KeywordTable(IndexStruct):
    """A table mapping keywords to text chunks.

    ``table`` maps each keyword to the set of chunk indices tagged with it,
    and ``text_chunks`` maps each chunk index to its node.

    """

    table: Dict[str, Set[int]] = field(default_factory=dict)
    text_chunks: Dict[int, Node] = field(default_factory=dict)

    def _get_index(self) -> int:
        """Pick a random index that is not yet used by any text chunk."""
        # keep drawing until the candidate does not clash with an existing chunk
        candidate = random.randint(0, sys.maxsize)
        while candidate in self.text_chunks:
            candidate = random.randint(0, sys.maxsize)
        return candidate

    def add_node(self, keywords: List[str], node: Node) -> int:
        """Store a node under a fresh index and tag it with the keywords.

        Returns:
            The index the node was stored under.

        """
        chunk_idx = self._get_index()
        for kw in keywords:
            self.table.setdefault(kw, set()).add(chunk_idx)
        self.text_chunks[chunk_idx] = node
        return chunk_idx

    def get_texts(self, keyword: str) -> List[str]:
        """Get the text of every chunk tagged with the keyword.

        Raises:
            ValueError: if the keyword is not in the table.

        """
        if keyword in self.table:
            chunk_ids = self.table[keyword]
            return [self.text_chunks[chunk_id].get_text() for chunk_id in chunk_ids]
        raise ValueError("Keyword not found in table.")

    @property
    def keywords(self) -> Set[str]:
        """All keywords in the table, as a new set."""
        return set(self.table)

    @property
    def size(self) -> int:
        """Number of keywords in the table."""
        return len(self.table)

    @classmethod
    def get_type(cls) -> str:
        """Get the type identifier of this struct."""
        return "keyword_table"


@dataclass
class IndexList(IndexStruct):
    """A list of documents.

    Nodes are kept in ``nodes`` in the order they were added.

    """

    nodes: List[Node] = field(default_factory=list)

    def add_node(self, node: Node) -> None:
        """Append a node to the end of the list.

        Nothing is returned; the node's position is simply the end of
        ``nodes`` at the time of the call.

        """
        # don't worry about child indices for now, nodes are all in order
        self.nodes.append(node)

    @classmethod
    def get_type(cls) -> str:
        """Get the type identifier of this struct."""
        return "list"


@dataclass
class IndexDict(IndexStruct):
    """A simple dictionary of documents.

    Nodes are stored in ``nodes_dict`` under an integer id, and ``id_map``
    translates the string text ids used by callers into those integer ids.

    """

    nodes_dict: Dict[int, Node] = field(default_factory=dict)
    id_map: Dict[str, int] = field(default_factory=dict)

    # TODO: temporary hack to store embeddings for simple vector index
    # this should be empty for all other indices
    embeddings_dict: Dict[str, List[float]] = field(default_factory=dict)

    def add_node(
        self,
        node: Node,
        text_id: Optional[str] = None,
    ) -> str:
        """Add a node and return the text id it was stored under.

        Args:
            node: the node to store.
            text_id: optional string id for the node. When omitted, the
                string form of the newly generated integer id is used.

        Returns:
            The text id that now maps to the node.

        Raises:
            ValueError: if text_id is not a string, or is already in the index.

        """
        # check the type first, so a non-string id is always reported as such
        # before it is ever used as a dictionary key
        if text_id is not None and not isinstance(text_id, str):
            raise ValueError("text_id must be a string.")
        if text_id in self.id_map:
            raise ValueError("text_id cannot already exist in index.")

        int_id = get_new_int_id(set(self.nodes_dict.keys()))
        if text_id is None:
            text_id = str(int_id)
        self.id_map[text_id] = int_id

        # don't worry about child indices for now, nodes are all in order
        self.nodes_dict[int_id] = node
        return text_id

    def get_nodes(self, text_ids: List[str]) -> List[Node]:
        """Get the nodes stored under the given text ids, in the same order.

        Raises:
            ValueError: if a text id is not a string, is not in ``id_map``,
                or maps to an integer id missing from ``nodes_dict``.

        """
        nodes = []
        for text_id in text_ids:
            # the type check must come before the lookup: a non-string can
            # never be a key of id_map, so checking it afterwards would
            # never trigger
            if not isinstance(text_id, str):
                raise ValueError("text_id must be a string.")
            if text_id not in self.id_map:
                raise ValueError("text_id not found in id_map")
            int_id = self.id_map[text_id]
            if int_id not in self.nodes_dict:
                raise ValueError("int_id not found in nodes_dict")
            nodes.append(self.nodes_dict[int_id])
        return nodes

    def get_node(self, text_id: str) -> Node:
        """Get the single node stored under the given text id.

        Raises:
            ValueError: under the same conditions as ``get_nodes``.

        """
        return self.get_nodes([text_id])[0]

    def delete(self, doc_id: str) -> None:
        """Delete every node whose ``ref_doc_id`` equals ``doc_id``.

        Both the ``id_map`` entry and the ``nodes_dict`` entry of each
        matching node are removed. Nothing happens if no node matches.

        """
        # collect the matches first: the dicts must not be mutated while
        # id_map is being iterated. Keeping each (text_id, int_id) pair
        # together avoids relying on the iteration order of two separate sets.
        matches = [
            (text_id, int_id)
            for text_id, int_id in self.id_map.items()
            if self.nodes_dict[int_id].ref_doc_id == doc_id
        ]

        for text_id, int_id in matches:
            del self.id_map[text_id]
            # pop() tolerates an int_id that was shared by several text ids
            # and has therefore already been removed
            self.nodes_dict.pop(int_id, None)

    @classmethod
    def get_type(cls) -> str:
        """Get the type identifier of this struct."""
        return IndexStructType.VECTOR_STORE


@dataclass
class KG(IndexStruct):
    """A knowledge graph index struct.

    ``table`` maps keywords to node ids, ``text_chunks`` maps node ids to
    nodes, ``rel_map`` maps a subject to its (object, relationship) pairs,
    and ``embedding_dict`` maps a triplet string to its embedding.

    """

    # Unidirectional

    table: Dict[str, Set[str]] = field(default_factory=dict)
    text_chunks: Dict[str, Node] = field(default_factory=dict)
    rel_map: Dict[str, List[Tuple[str, str]]] = field(default_factory=dict)
    embedding_dict: Dict[str, List[float]] = field(default_factory=dict)

    def add_to_embedding_dict(self, triplet_str: str, embedding: List[float]) -> None:
        """Store (or overwrite) the embedding for a triplet string."""
        self.embedding_dict[triplet_str] = embedding

    def upsert_triplet(self, triplet: Tuple[str, str, str], node: Node) -> None:
        """Record a (subject, relationship, object) triplet for a node.

        The node is registered under both the subject and the object keyword,
        and an (object, relationship) pair is appended to the subject's
        entry in ``rel_map``.

        """
        subj, relationship, obj = triplet
        self.add_node([subj, obj], node)
        self.rel_map.setdefault(subj, []).append((obj, relationship))

    def add_node(self, keywords: List[str], node: Node) -> None:
        """Store a node under its doc id and tag it with the keywords."""
        node_id = node.get_doc_id()
        for kw in keywords:
            self.table.setdefault(kw, set()).add(node_id)
        self.text_chunks[node_id] = node

    def get_rel_map_texts(self, keyword: str) -> List[str]:
        """Get the keyword's relations as stringified triplets.

        Each entry is ``str((keyword, relationship, object))``. An empty
        list is returned when the keyword has no relations.

        """
        relations = self.rel_map.get(keyword, [])
        return [str((keyword, rel, obj)) for obj, rel in relations]

    def get_rel_map_tuples(self, keyword: str) -> List[Tuple[str, str]]:
        """Get the keyword's (object, relationship) pairs.

        The list stored in ``rel_map`` is returned as-is; an empty list is
        returned when the keyword has no relations.

        """
        if keyword in self.rel_map:
            return self.rel_map[keyword]
        return []

    def get_node_ids(self, keyword: str, depth: int = 1) -> List[str]:
        """Get the ids of nodes tagged with the keyword or its direct objects.

        Ids are not de-duplicated: a node tagged with both the keyword and
        one of its objects appears once for each.

        Raises:
            ValueError: if depth is greater than 1.

        """
        if depth > 1:
            raise ValueError("Depth > 1 not supported yet.")
        if keyword not in self.table:
            return []

        # some keywords may correspond to a leaf node, may not be in rel_map
        related = [keyword]
        related.extend(child for child, _ in self.rel_map.get(keyword, []))

        node_ids: List[str] = []
        for kw in related:
            node_ids.extend(self.table.get(kw, set()))
            # TODO: Traverse (with depth > 1)
        return node_ids

    @classmethod
    def get_type(cls) -> str:
        """Get the type identifier of this struct."""
        return "kg"


# TODO: remove once we centralize UX around vector index


class SimpleIndexDict(IndexDict):
    """Index dict for simple vector index.

    Differs from IndexDict only in the struct type it reports.

    """

    @classmethod
    def get_type(cls) -> str:
        """Get type (``IndexStructType.SIMPLE_DICT``)."""
        return IndexStructType.SIMPLE_DICT


class FaissIndexDict(IndexDict):
    """Index dict for Faiss vector index.

    Differs from IndexDict only in the struct type it reports.

    """

    @classmethod
    def get_type(cls) -> str:
        """Get type (``IndexStructType.DICT``)."""
        # NOTE: the Faiss struct reports the generic DICT member
        return IndexStructType.DICT


class WeaviateIndexDict(IndexDict):
    """Index dict for Weaviate vector index.

    Differs from IndexDict only in the struct type it reports.

    """

    @classmethod
    def get_type(cls) -> str:
        """Get type (``IndexStructType.WEAVIATE``)."""
        return IndexStructType.WEAVIATE


class PineconeIndexDict(IndexDict):
    """Index dict for Pinecone vector index.

    Differs from IndexDict only in the struct type it reports.

    """

    @classmethod
    def get_type(cls) -> str:
        """Get type (``IndexStructType.PINECONE``)."""
        return IndexStructType.PINECONE


class QdrantIndexDict(IndexDict):
    """Index dict for Qdrant vector index.

    Differs from IndexDict only in the struct type it reports.

    """

    @classmethod
    def get_type(cls) -> str:
        """Get type (``IndexStructType.QDRANT``)."""
        return IndexStructType.QDRANT


class ChromaIndexDict(IndexDict):
    """Index dict for Chroma vector index.

    Differs from IndexDict only in the struct type it reports.

    """

    @classmethod
    def get_type(cls) -> str:
        """Get type (``IndexStructType.CHROMA``)."""
        return IndexStructType.CHROMA


class OpensearchIndexDict(IndexDict):
    """Index dict for Opensearch vector index.

    Differs from IndexDict only in the struct type it reports.

    """

    @classmethod
    def get_type(cls) -> str:
        """Get type (``IndexStructType.OPENSEARCH``)."""
        return IndexStructType.OPENSEARCH


class EmptyIndex(IndexStruct):
    """Empty index.

    Adds no fields to IndexStruct; it only reports its struct type.

    """

    @classmethod
    def get_type(cls) -> str:
        """Get type (``IndexStructType.EMPTY``)."""
        return IndexStructType.EMPTY