|
import time |
|
from typing import List |
|
|
|
import logfire |
|
from llama_index.core import QueryBundle |
|
from llama_index.core.retrievers import BaseRetriever, VectorIndexRetriever |
|
from llama_index.core.schema import NodeWithScore, TextNode |
|
from llama_index.postprocessor.cohere_rerank import CohereRerank |
|
|
|
|
|
class CustomRetriever(BaseRetriever):
    """Vector retriever with per-document dedup, full-document expansion and Cohere reranking.

    Pipeline for each query:
      1. Remove the ``"\\ninput is "`` marker and trailing whitespace from the
         query text (the caller's ``QueryBundle`` is modified in place).
      2. Retrieve candidates with the wrapped ``VectorIndexRetriever``.
      3. Keep only the first hit per ``ref_doc_id``; hits without one are dropped.
      4. Drop hits scoring below ``similarity_cutoff``; for hits whose metadata
         sets ``retrieve_doc``, substitute the full text of the source document
         looked up in ``document_dict``.
      5. Rerank with Cohere and drop results scoring below ``rerank_cutoff``.
    """

    def __init__(
        self,
        vector_retriever: VectorIndexRetriever,
        document_dict: dict,
        *,
        similarity_cutoff: float = 0.2,
        rerank_top_n: int = 5,
        rerank_model: str = "rerank-english-v3.0",
        rerank_cutoff: float = 0.15,
    ) -> None:
        """Init params.

        Args:
            vector_retriever: Retriever used for the initial semantic search.
            document_dict: Mapping from ``ref_doc_id`` to a document object with a
                ``.text`` attribute; used to expand flagged chunks to full documents.
            similarity_cutoff: Retrieved nodes scoring below this are discarded
                before reranking.
            rerank_top_n: Number of nodes the Cohere reranker keeps.
            rerank_model: Cohere rerank model name.
            rerank_cutoff: Reranked nodes scoring below this are discarded.
        """
        self._vector_retriever = vector_retriever
        self._document_dict = document_dict
        self._similarity_cutoff = similarity_cutoff
        self._rerank_top_n = rerank_top_n
        self._rerank_model = rerank_model
        self._rerank_cutoff = rerank_cutoff
        # CohereRerank instance, created on first use (see _get_reranker) so that
        # constructing this retriever does not itself build a Cohere client.
        self._reranker = None
        super().__init__()

    def _get_reranker(self) -> CohereRerank:
        """Return the Cohere reranker, creating it on first use and reusing it afterwards."""
        if self._reranker is None:
            self._reranker = CohereRerank(
                top_n=self._rerank_top_n, model=self._rerank_model
            )
        return self._reranker

    @staticmethod
    def _filter_nodes_by_unique_doc_id(
        nodes: List[NodeWithScore],
    ) -> List[NodeWithScore]:
        """Keep the first node seen for each ``ref_doc_id``, preserving order.

        Nodes whose ``ref_doc_id`` is ``None`` are dropped.
        """
        unique_nodes = {}
        for node in nodes:
            doc_id = node.node.ref_doc_id
            if doc_id is not None and doc_id not in unique_nodes:
                unique_nodes[doc_id] = node
        return list(unique_nodes.values())

    def _expand_to_document(self, node: NodeWithScore) -> NodeWithScore:
        """Replace a flagged chunk with the full text of its source document.

        If ``node.metadata["retrieve_doc"]`` equals ``True``, return a new node
        holding the source document's text, the chunk's metadata and the
        chunk's score. Otherwise return ``node`` unchanged. If the source
        document is missing from ``document_dict``, log a warning and keep the
        chunk instead of failing the whole query.
        """
        # Equality rather than truthiness: only a flag equal to True triggers
        # expansion (a missing key no longer raises KeyError).
        if node.metadata.get("retrieve_doc") != True:  # noqa: E712
            return node

        doc_id = node.node.ref_doc_id
        doc = self._document_dict.get(doc_id)
        if doc is None:
            logfire.warn(
                "Document {doc_id} not found in document_dict; keeping the retrieved chunk",
                doc_id=doc_id,
            )
            return node

        return NodeWithScore(
            node=TextNode(text=doc.text, metadata=node.metadata),
            score=node.score,
        )

    def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Retrieve, dedupe, expand and rerank nodes for ``query_bundle``.

        ``query_bundle.query_str`` is cleaned in place. The reranker below and
        any caller holding this bundle see the cleaned text.
        """
        query_bundle.query_str = (
            query_bundle.query_str.replace("\ninput is ", "").rstrip()
        )

        # NOTE(review): "10" is hard-coded in this message and may not match the
        # vector retriever's configured top-k.
        logfire.info(f"Retrieving 10 nodes with string: '{query_bundle}'")
        start = time.perf_counter()
        nodes = self._vector_retriever.retrieve(query_bundle)
        duration = time.perf_counter() - start
        logfire.info(f"Retrieving nodes took {duration:.2f}s")

        nodes = self._filter_nodes_by_unique_doc_id(nodes)
        logfire.info(
            f"Number of nodes after filtering the ones with same ref_doc_id: {len(nodes)}"
        )
        logfire.info(f"Nodes retrieved: {nodes}")

        # NodeWithScore.score may be None; such nodes cannot be compared with the
        # cutoff and are treated as falling below it.
        nodes_context = [
            self._expand_to_document(node)
            for node in nodes
            if node.score is not None and node.score >= self._similarity_cutoff
        ]

        # Skip the reranker (and any API call) when nothing survived filtering.
        if nodes_context:
            nodes_context = self._get_reranker().postprocess_nodes(
                nodes_context, query_bundle
            )

        nodes_filtered = [
            node
            for node in nodes_context
            if node.score is not None and node.score >= self._rerank_cutoff
        ]
        logfire.info(f"Cohere reranking to {len(nodes_filtered)} nodes")
        return nodes_filtered
|
|