import asyncio
import time
import traceback
from typing import List, Optional

import logfire
import tiktoken
from cohere import AsyncClient
from dotenv import load_dotenv
from llama_index.core import QueryBundle
from llama_index.core.callbacks import CBEventType, EventPayload
from llama_index.core.retrievers import (
    BaseRetriever,
    KeywordTableSimpleRetriever,
    VectorIndexRetriever,
)
from llama_index.core.schema import MetadataMode, NodeWithScore

# MetadataFilter / FilterOperator are only referenced by the disabled
# keyword-filter block inside _process_retrieval below.
from llama_index.core.vector_stores import (
    FilterOperator,
    MetadataFilter,
)
from llama_index.postprocessor.cohere_rerank import CohereRerank

load_dotenv()


class AsyncCohereRerank(CohereRerank):
    """CohereRerank variant that adds an async postprocessing path."""

    def __init__(
        self,
        top_n: int = 5,
        model: str = "rerank-english-v3.0",
        api_key: Optional[str] = None,
    ) -> None:
        super().__init__(top_n=top_n, model=model, api_key=api_key)
        self._api_key = api_key
        self._model = model
        self._top_n = top_n

    async def apostprocess_nodes(
        self,
        nodes: List[NodeWithScore],
        query_bundle: Optional[QueryBundle] = None,
    ) -> List[NodeWithScore]:
        if query_bundle is None:
            raise ValueError("Query bundle must be provided.")

        if len(nodes) == 0:
            return []

        # AsyncClient falls back to the COHERE_API_KEY env var when api_key is None.
        async_client = AsyncClient(api_key=self._api_key)

        with self.callback_manager.event(
            CBEventType.RERANKING,
            payload={
                EventPayload.NODES: nodes,
                EventPayload.MODEL_NAME: self._model,
                EventPayload.QUERY_STR: query_bundle.query_str,
                EventPayload.TOP_K: self._top_n,
            },
        ) as event:
            texts = [
                node.node.get_content(metadata_mode=MetadataMode.EMBED)
                for node in nodes
            ]

            results = await async_client.rerank(
                model=self._model,
                top_n=self._top_n,
                query=query_bundle.query_str,
                documents=texts,
            )

            new_nodes = []
            for result in results.results:
                new_node_with_score = NodeWithScore(
                    node=nodes[result.index].node, score=result.relevance_score
                )
                new_nodes.append(new_node_with_score)
            event.on_end(payload={EventPayload.NODES: new_nodes})

        return new_nodes


class CustomRetriever(BaseRetriever):
    """Custom retriever that combines vector (semantic) search with optional
    keyword search, then reranks the merged results with Cohere."""

    def __init__(
        self,
        vector_retriever: VectorIndexRetriever,
        document_dict: dict,
        keyword_retriever: Optional[KeywordTableSimpleRetriever] = None,
        mode: str = "AND",
    ) -> None:
        """Init params."""
        self._vector_retriever = vector_retriever
        self._document_dict = document_dict
        self._keyword_retriever = keyword_retriever
        if mode not in ("AND", "OR"):
            raise ValueError("Invalid mode.")
        self._mode = mode
        super().__init__()

    async def _process_retrieval(
        self, query_bundle: QueryBundle, is_async: bool = True
    ) -> List[NodeWithScore]:
        """Common processing logic for both sync and async retrieval."""
        # Clean query string
        query_bundle.query_str = query_bundle.query_str.replace(
            "\ninput is ", ""
        ).rstrip()

        logfire.info(f"Retrieving nodes with string: '{query_bundle.query_str}'")

        start = time.time()

        # Get nodes from both retrievers
        if is_async:
            nodes = await self._vector_retriever.aretrieve(query_bundle)
        else:
            nodes = self._vector_retriever.retrieve(query_bundle)

        keyword_nodes = []
        if self._keyword_retriever:
            if is_async:
                keyword_nodes = await self._keyword_retriever.aretrieve(query_bundle)
            else:
                keyword_nodes = self._keyword_retriever.retrieve(query_bundle)

        logfire.info(f"Number of vector nodes: {len(nodes)}")
        logfire.info(f"Number of keyword nodes: {len(keyword_nodes)}")

        # # Filter keyword nodes based on metadata filters from vector retriever
        # if (
        #     hasattr(self._vector_retriever, "_filters")
        #     and self._vector_retriever._filters
        # ):
        #     filtered_keyword_nodes = []
        #     for node in keyword_nodes:
        #         node_source = node.node.metadata.get("source")
        #         # Check if node's source matches any of the filter conditions
        #         for filter in self._vector_retriever._filters.filters:
        #             if (
        #                 isinstance(filter, MetadataFilter)
        #                 and filter.key == "source"
        #                 and filter.operator == FilterOperator.EQ
        #                 and filter.value == node_source
        #             ):
        #                 filtered_keyword_nodes.append(node)
        #                 break
        #     keyword_nodes = filtered_keyword_nodes
        #     logfire.info(
        #         f"Number of keyword nodes after filtering: {len(keyword_nodes)}"
        #     )

        # Combine results based on mode
        vector_ids = {n.node.node_id for n in nodes}
        keyword_ids = {n.node.node_id for n in keyword_nodes}
        combined_dict = {n.node.node_id: n for n in nodes}
        combined_dict.update({n.node.node_id: n for n in keyword_nodes})

        # If no keyword retriever or no keyword nodes, just use vector nodes
        if not self._keyword_retriever or not keyword_nodes:
            retrieve_ids = vector_ids
        else:
            retrieve_ids = (
                vector_ids.intersection(keyword_ids)
                if self._mode == "AND"
                else vector_ids.union(keyword_ids)
            )

        nodes = [combined_dict[rid] for rid in retrieve_ids]
        logfire.info(f"Number of combined nodes: {len(nodes)}")

        # Filter unique doc IDs
        nodes = self._filter_nodes_by_unique_doc_id(nodes)
        logfire.info(f"Number of nodes without duplicate doc IDs: {len(nodes)}")

        # Swap chunk text for the full source document when the chunk is
        # flagged with retrieve_doc=True in its metadata.
        for node in nodes:
            doc_id = node.node.source_node.node_id
            if node.metadata.get("retrieve_doc"):
                doc = self._document_dict[doc_id]
                node.node.text = doc.text
                node.node.node_id = doc_id

        # Rerank results; on failure, fall through with the un-reranked nodes
        try:
            reranker = (
                AsyncCohereRerank(top_n=5, model="rerank-english-v3.0")
                if is_async
                else CohereRerank(top_n=5, model="rerank-english-v3.0")
            )
            nodes = (
                await reranker.apostprocess_nodes(nodes, query_bundle)
                if is_async
                else reranker.postprocess_nodes(nodes, query_bundle)
            )
        except Exception as e:
            error_msg = f"Error during reranking: {type(e).__name__}: {str(e)}\n"
            error_msg += "Traceback:\n"
            error_msg += traceback.format_exc()
            logfire.error(error_msg)

        # Filter by score and token count
        nodes_filtered = self._filter_by_score_and_tokens(nodes)

        duration = time.time() - start
        logfire.info(f"Retrieving nodes took {duration:.2f}s")
        logfire.info(f"Nodes sent to LLM: {nodes_filtered[:5]}")

        return nodes_filtered[:5]

    def _filter_nodes_by_unique_doc_id(
        self, nodes: List[NodeWithScore]
    ) -> List[NodeWithScore]:
        """Filter nodes to keep only unique doc IDs."""
        unique_nodes = {}
        for node in nodes:
            doc_id = node.node.source_node.node_id
            if doc_id is not None and doc_id not in unique_nodes:
                unique_nodes[doc_id] = node
        return list(unique_nodes.values())

    def _filter_by_score_and_tokens(
        self, nodes: List[NodeWithScore]
    ) -> List[NodeWithScore]:
        """Filter nodes by score and token count."""
        nodes_filtered = []
        total_tokens = 0
        enc = tiktoken.encoding_for_model("gpt-4o-mini")
        for node in nodes:
            # Keyword-only nodes can arrive without a score if reranking failed.
            if node.score is None or node.score < 0.10:
                logfire.info(f"Skipping node with score {node.score}")
                continue
            node_tokens = len(enc.encode(node.node.text))
            if total_tokens + node_tokens > 100_000:
                logfire.info("Skipping node due to token count exceeding 100k")
                break
            total_tokens += node_tokens
            nodes_filtered.append(node)
        return nodes_filtered

    async def _aretrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Async retrieve nodes given query."""
        return await self._process_retrieval(query_bundle, is_async=True)

    def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Sync retrieve nodes given query."""
        return asyncio.run(self._process_retrieval(query_bundle, is_async=False))
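

# ---------------------------------------------------------------------------
# Minimal usage sketch (an illustration, not part of the original module): it
# wires CustomRetriever to a tiny in-memory VectorStoreIndex. The sample ids,
# texts, and the "retrieve_doc"/"source" metadata keys are placeholder
# assumptions that mirror what _process_retrieval expects. Embedding uses the
# library default (OpenAI), so OPENAI_API_KEY and COHERE_API_KEY must be set.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from llama_index.core import Document, VectorStoreIndex
    from llama_index.core.schema import NodeRelationship, RelatedNodeInfo, TextNode

    # Full documents keyed by doc id, consulted when a chunk has
    # retrieve_doc=True in its metadata.
    document_dict = {
        "doc-1": Document(id_="doc-1", text="Full text of document one."),
    }

    # One chunk whose SOURCE relationship points back at its full document.
    chunks = [
        TextNode(
            text="A chunk of document one.",
            metadata={"retrieve_doc": True, "source": "docs"},
            relationships={
                NodeRelationship.SOURCE: RelatedNodeInfo(node_id="doc-1")
            },
        )
    ]

    index = VectorStoreIndex(chunks)
    retriever = CustomRetriever(
        vector_retriever=index.as_retriever(similarity_top_k=10),
        document_dict=document_dict,
    )

    # BaseRetriever.retrieve wraps the plain string into a QueryBundle and
    # dispatches to the sync _retrieve path defined above.
    for result in retriever.retrieve("What is document one about?"):
        print(result.node.node_id, result.score)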