File size: 2,494 Bytes
37cbdf5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
069addf
 
 
 
1281058
37cbdf5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import logging
from typing import List

from llama_index.core import QueryBundle
from llama_index.core.retrievers import BaseRetriever, VectorIndexRetriever
from llama_index.core.schema import NodeWithScore, TextNode

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# NOTE(review): this configures the root logger at import time. Per the stdlib
# documentation, basicConfig() does nothing if the root logger already has
# handlers, so an application that set up logging earlier is not overridden.
logging.basicConfig(level=logging.INFO)


class CustomRetriever(BaseRetriever):
    """Vector retriever that de-duplicates hits and can expand them to full documents.

    Retrieval is delegated to the wrapped ``VectorIndexRetriever``. The hits
    are then reduced to one node per ``ref_doc_id`` and, for every hit whose
    metadata has ``retrieve_doc`` equal to ``True``, the node's text is
    replaced by the text of the parent document looked up in
    ``document_dict``.
    """

    def __init__(
        self,
        vector_retriever: VectorIndexRetriever,
        document_dict: dict,
    ) -> None:
        """Store the collaborators used by ``_retrieve``.

        Args:
            vector_retriever: Retriever that performs the actual vector search.
            document_dict: Mapping from ``ref_doc_id`` to a document object
                exposing a ``text`` attribute (only ``.text`` is read here).
        """

        self._vector_retriever = vector_retriever
        self._document_dict = document_dict
        super().__init__()

    @staticmethod
    def _filter_nodes_by_unique_doc_id(
        nodes: List[NodeWithScore],
    ) -> List[NodeWithScore]:
        """Keep the first node seen for each ``ref_doc_id``, in input order.

        Nodes whose ``ref_doc_id`` is ``None`` are dropped, because they cannot
        be attributed to (or replaced by) a source document.
        """
        first_per_doc = {}
        for node in nodes:
            doc_id = node.node.ref_doc_id
            if doc_id is not None:
                # setdefault keeps the earliest hit; dicts preserve insertion order.
                first_per_doc.setdefault(doc_id, node)
        return list(first_per_doc.values())

    def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Retrieve nodes for the query, one per source document.

        Args:
            query_bundle: The query. Its ``query_str`` is cleaned in place
                (see below) before being passed to the vector retriever.

        Returns:
            One ``NodeWithScore`` per distinct ``ref_doc_id``. Hits flagged with
            ``retrieve_doc`` carry the full document text (with the hit's
            metadata and score); all other hits are returned unchanged.
        """

        # LlamaIndex adds "\ninput is " to the query string
        query_bundle.query_str = query_bundle.query_str.replace(
            "\ninput is ", ""
        ).rstrip()

        logger.info("Retrieving nodes for query: %s", query_bundle)

        nodes = self._filter_nodes_by_unique_doc_id(
            self._vector_retriever.retrieve(query_bundle)
        )
        logger.info("number of nodes after filtering: %d", len(nodes))

        nodes_context = []
        for node in nodes:
            # Equality with True (not mere truthiness) is kept on purpose so
            # that non-boolean truthy values do not trigger the replacement.
            # A missing flag is treated as "do not replace".
            wants_full_doc = node.metadata.get("retrieve_doc", False) == True  # noqa: E712
            if not wants_full_doc:
                nodes_context.append(node)
                continue

            doc_id = node.node.ref_doc_id
            doc = self._document_dict.get(doc_id)
            if doc is None:
                # Best effort: keep the retrieved chunk rather than failing the
                # whole query because one parent document is unavailable.
                logger.warning(
                    "Document %s not found in document_dict; keeping the "
                    "retrieved node instead",
                    doc_id,
                )
                nodes_context.append(node)
                continue

            nodes_context.append(
                NodeWithScore(
                    node=TextNode(text=doc.text, metadata=node.metadata),
                    score=node.score,
                )
            )

        return nodes_context