import asyncio
import html
import json
import logging
import os
import pdb
import pickle
import random
import time
from typing import Dict, List, Optional, Tuple

import aiofiles
import chromadb
import logfire
import pandas as pd
from custom_retriever import CustomRetriever
from llama_index.agent.openai import OpenAIAgent
from llama_index.core import Document, SimpleKeywordTableIndex, VectorStoreIndex
from llama_index.core.bridge.pydantic import Field, SerializeAsAny
from llama_index.core.chat_engine.types import (
    AGENT_CHAT_RESPONSE_TYPE,
    AgentChatResponse,
    ChatResponseMode,
)
from llama_index.core.evaluation import (
    AnswerRelevancyEvaluator,
    BatchEvalRunner,
    EmbeddingQAFinetuneDataset,
    FaithfulnessEvaluator,
    RelevancyEvaluator,
)
from llama_index.core.evaluation.base import EvaluationResult
from llama_index.core.evaluation.retrieval.base import (
    BaseRetrievalEvaluator,
    RetrievalEvalMode,
    RetrievalEvalResult,
)
from llama_index.core.ingestion import IngestionPipeline
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.postprocessor.types import BaseNodePostprocessor
from llama_index.core.retrievers import (
    BaseRetriever,
    KeywordTableSimpleRetriever,
    VectorIndexRetriever,
)
from llama_index.core.schema import ImageNode, NodeWithScore, QueryBundle, TextNode
from llama_index.core.tools import RetrieverTool, ToolMetadata
from llama_index.core.vector_stores import (
    FilterOperator,
    MetadataFilter,
    MetadataFilters,
)
from llama_index.embeddings.cohere import CohereEmbedding
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.llms.gemini import Gemini
from llama_index.llms.openai import OpenAI
from llama_index.vector_stores.chroma import ChromaVectorStore
from prompts import system_message_openai_agent
from pydantic import BaseModel, Field
from tqdm.asyncio import tqdm_asyncio


class RotatingJSONLWriter:
    def __init__(
        self, base_filename: str, max_size: int = 10**6, backup_count: int = 5
    ):
        """
        Initialize the rotating JSONL writer.

        Args:
            base_filename (str): The base filename for the JSONL files.
            max_size (int): Maximum size in bytes before rotating.
            backup_count (int): Number of backup files to keep.
        """
        self.base_filename = base_filename
        self.max_size = max_size
        self.backup_count = backup_count
        self.current_file = base_filename

    async def write(self, data: dict):
        """Append one JSON record, rotating first if the current file is too large."""
        if (
            os.path.exists(self.current_file)
            and os.path.getsize(self.current_file) > self.max_size
        ):
            await self.rotate_files()

        async with aiofiles.open(self.current_file, "a", encoding="utf-8") as f:
            await f.write(json.dumps(data, ensure_ascii=False) + "\n")

    async def rotate_files(self):
        """Shift backups up by one (.1 -> .2, ...) and move the current file to .1."""
        # Drop the oldest backup if the backup count is already reached.
        oldest_backup = f"{self.base_filename}.{self.backup_count}"
        if os.path.exists(oldest_backup):
            os.remove(oldest_backup)

        # Shift the remaining backups: .{i} becomes .{i + 1}.
        for i in range(self.backup_count - 1, 0, -1):
            src = f"{self.base_filename}.{i}"
            dst = f"{self.base_filename}.{i + 1}"
            if os.path.exists(src):
                os.rename(src, dst)

        # The current file becomes backup .1; the next write recreates the base file.
        os.rename(self.current_file, f"{self.base_filename}.1")


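# Usage sketch (illustrative only; mirrors how the writer is configured in the
# __main__ block below): one JSON object per line, rotated at ~10 MB, five backups.
#
#   writer = RotatingJSONLWriter("response_data.jsonl", max_size=10**7, backup_count=5)
#   await writer.write({"source": "trl", "question": "...", "answer": "..."})

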
class AsyncKeywordTableSimpleRetriever(KeywordTableSimpleRetriever):
    """Keyword-table retriever with a non-blocking async entry point.

    The synchronous `_retrieve` call is offloaded to the default thread-pool
    executor so keyword retrieval does not block the event loop.
    """

    async def _aretrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        loop = asyncio.get_event_loop()
        return await loop.run_in_executor(None, self._retrieve, query_bundle)


class SampleableEmbeddingQADataset:
    def __init__(self, dataset: EmbeddingQAFinetuneDataset):
        self.dataset = dataset

    def sample(self, n: int) -> EmbeddingQAFinetuneDataset:
        """
        Sample n queries from the dataset.

        Args:
            n (int): Number of queries to sample.

        Returns:
            EmbeddingQAFinetuneDataset: A new dataset with the sampled queries.
        """
        if n > len(self.dataset.queries):
            raise ValueError(
                f"n ({n}) is greater than the number of queries ({len(self.dataset.queries)})"
            )

        sampled_query_ids = random.sample(list(self.dataset.queries.keys()), n)

        sampled_queries = {qid: self.dataset.queries[qid] for qid in sampled_query_ids}
        sampled_relevant_docs = {
            qid: self.dataset.relevant_docs[qid] for qid in sampled_query_ids
        }

        # Restrict the corpus to the documents referenced by the sampled queries.
        sampled_doc_ids = set()
        for doc_ids in sampled_relevant_docs.values():
            sampled_doc_ids.update(doc_ids)

        sampled_corpus = {
            doc_id: self.dataset.corpus[doc_id] for doc_id in sampled_doc_ids
        }

        return EmbeddingQAFinetuneDataset(
            queries=sampled_queries,
            corpus=sampled_corpus,
            relevant_docs=sampled_relevant_docs,
            mode=self.dataset.mode,
        )

    def __getattr__(self, name):
        # Delegate every other attribute lookup to the wrapped dataset.
        return getattr(self.dataset, name)


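# Usage sketch (illustrative; mirrors evaluate_retriever below):
#
#   dataset = EmbeddingQAFinetuneDataset.from_json("scripts/rag_eval_peft.json")
#   sampled = SampleableEmbeddingQADataset(dataset).sample(n=49)

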
class RetrieverEvaluator(BaseRetrievalEvaluator):
    """Retriever evaluator.

    This module will evaluate a retriever using a set of metrics.

    Args:
        metrics (List[BaseRetrievalMetric]): Sequence of metrics to evaluate.
        retriever: Retriever to evaluate.
        node_postprocessors (Optional[List[BaseNodePostprocessor]]): Post-processors to apply after retrieval.
    """

    retriever: BaseRetriever = Field(..., description="Retriever to evaluate")
    node_postprocessors: Optional[List[SerializeAsAny[BaseNodePostprocessor]]] = Field(
        default=None, description="Optional post-processor"
    )

    async def _aget_retrieved_ids_and_texts(
        self,
        query: str,
        mode: RetrievalEvalMode = RetrievalEvalMode.TEXT,
        source: str = "",
    ) -> Tuple[List[str], List[str]]:
        """Get retrieved ids and texts, potentially applying a post-processor."""
        try:
            retrieved_nodes: list[NodeWithScore] = await self.retriever.aretrieve(query)
            logfire.info(f"Retrieved {len(retrieved_nodes)} nodes for: '{query}'")
        except Exception as e:
            # Return a sentinel id so the metrics register a miss instead of crashing.
            return ["00000000-0000-0000-0000-000000000000"], [str(e)]

        if retrieved_nodes is None or len(retrieved_nodes) == 0:
            print(f"No nodes retrieved for {query}")
            return ["00000000-0000-0000-0000-000000000000"], ["No nodes retrieved"]

        if self.node_postprocessors:
            for node_postprocessor in self.node_postprocessors:
                retrieved_nodes = node_postprocessor.postprocess_nodes(
                    retrieved_nodes, query_str=query
                )

        return (
            [node.node.node_id for node in retrieved_nodes],
            [node.node.text for node in retrieved_nodes],
        )


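# Usage sketch (illustrative; this is how the evaluator is built in
# evaluate_retriever below):
#
#   retriever_evaluator = RetrieverEvaluator.from_metric_names(
#       ["mrr", "hit_rate"], retriever=custom_retriever_all_sources
#   )
#   result = await retriever_evaluator.aevaluate(query=..., expected_ids=[...])

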
class OpenAIAgentRetrieverEvaluator(BaseRetrievalEvaluator):
    agent: OpenAIAgent = Field(description="The OpenAI agent used for retrieval")

    async def _aget_retrieved_ids_and_texts(
        self,
        query: str,
        mode: RetrievalEvalMode = RetrievalEvalMode.TEXT,
        source: str = "",
    ) -> Tuple[List[str], List[str]]:
        # Start every query from a clean conversation.
        self.agent.memory.reset()

        try:
            logfire.info(f"Executing agent with query: {query}")
            response: AgentChatResponse = await self.agent.achat(query)
        except Exception as e:
            # Return a sentinel id so the metrics register a miss instead of crashing.
            return ["00000000-0000-0000-0000-000000000000"], [str(e)]

        retrieved_nodes: list[NodeWithScore] = get_nodes_with_score(response)
        logfire.info(f"Retrieved {len(retrieved_nodes)} nodes to answer: '{query}'")
        # Cap at the first six retrieved nodes.
        retrieved_nodes = retrieved_nodes[:6]

        if retrieved_nodes is None or len(retrieved_nodes) == 0:
            return ["00000000-0000-0000-0000-000000000000"], ["No nodes retrieved"]

        retrieved_ids = [node.node.node_id for node in retrieved_nodes]
        retrieved_texts = [node.node.text for node in retrieved_nodes]

        # Persist the agent's answer for the separate answer-relevancy evaluation pass.
        await self._save_response_data_async(
            source=source, query=query, context="", response=response.response
        )

        return retrieved_ids, retrieved_texts

    async def _save_response_data_async(self, source, query, context, response):
        data = {
            "source": source,
            "question": query,
            "answer": response,
        }
        # rotating_writer is the module-level RotatingJSONLWriter created in __main__.
        await rotating_writer.write(data)


def get_nodes_with_score(completion) -> list[NodeWithScore]:
    """Collect the nodes returned by every successful tool call in an agent response."""
    retrieved_nodes = []
    for source in completion.sources:
        if source.is_error:
            continue
        for node in source.raw_output:
            retrieved_nodes.append(node)
    return retrieved_nodes


def setup_basic_database(db_collection, dict_file_name, keyword_retriever):
    # Open the persisted Chroma collection holding the pre-computed embeddings.
    db = chromadb.PersistentClient(path=f"data/{db_collection}")
    chroma_collection = db.get_or_create_collection(db_collection)
    vector_store = ChromaVectorStore(chroma_collection=chroma_collection)

    # Queries are embedded with Cohere using the "search_query" input type.
    embed_model = CohereEmbedding(
        api_key=os.environ["COHERE_API_KEY"],
        model_name="embed-english-v3.0",
        input_type="search_query",
    )

    index = VectorStoreIndex.from_vector_store(
        vector_store=vector_store,
        show_progress=True,
    )
    vector_retriever = VectorIndexRetriever(
        index=index,
        similarity_top_k=15,
        embed_model=embed_model,
    )

    # Pickled document dictionary consumed by CustomRetriever.
    with open(f"data/{db_collection}/{dict_file_name}", "rb") as f:
        document_dict = pickle.load(f)

    return CustomRetriever(vector_retriever, document_dict, keyword_retriever, "OR")


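# Usage sketch (mirrors the call in evaluate_retriever below):
#
#   custom_retriever_all_sources = setup_basic_database(
#       "chroma-db-all_sources", "document_dict_all_sources.pkl", keyword_retriever
#   )

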
def update_query_engine_tools(selected_sources, custom_retriever_all_sources):
    tools = []
    source_mapping = {
        "All Sources": (
            custom_retriever_all_sources,
            "all_sources_info",
            """Useful for all questions, contains information about the field of AI.""",
        ),
    }

    for source in selected_sources:
        if source in source_mapping:
            retriever, name, description = source_mapping[source]
            tools.append(
                RetrieverTool(
                    retriever=retriever,
                    metadata=ToolMetadata(
                        name=name,
                        description=description,
                    ),
                )
            )

    return tools


def setup_agent(custom_retriever_all_sources) -> OpenAIAgent:
    llm = OpenAI(
        temperature=1,
        model="gpt-4o-mini",
        max_tokens=5000,
        max_retries=3,
    )
    # Instrument both the sync and async OpenAI clients so calls show up in Logfire.
    client = llm._get_client()
    logfire.instrument_openai(client)
    aclient = llm._get_aclient()
    logfire.instrument_openai(aclient)

    tools_available = [
        "All Sources",
    ]
    query_engine_tools = update_query_engine_tools(
        tools_available, custom_retriever_all_sources
    )

    agent = OpenAIAgent.from_tools(
        llm=llm,
        tools=query_engine_tools,
        system_prompt=system_message_openai_agent,
    )

    return agent


async def evaluate_answers():
    start_time = time.time()

    llm = OpenAI(model="gpt-4o-mini", temperature=1, max_tokens=1000)
    relevancy_evaluator = AnswerRelevancyEvaluator(llm=llm)

    # Load the (source, question, answer) triples written by
    # OpenAIAgentRetrieverEvaluator during the retrieval evaluation.
    query_response_pairs = []
    with open("response_data.jsonl", "r") as f:
        for line in f:
            data = json.loads(line)
            query_response_pairs.append(
                (data["source"], data["question"], data["answer"])
            )

    logfire.info(f"Number of queries and answers: {len(query_response_pairs)}")

    # Limit the number of concurrent evaluation calls.
    semaphore = asyncio.Semaphore(90)

    async def evaluate_query_response(source, query, response):
        async with semaphore:
            try:
                result: EvaluationResult = await relevancy_evaluator.aevaluate(
                    query=query, response=response
                )
                return source, result
            except Exception as e:
                logfire.error(f"Error evaluating query for {source}: {str(e)}")
                return source, None

    results = await tqdm_asyncio.gather(
        *[
            evaluate_query_response(source, query, response)
            for source, query, response in query_response_pairs
        ],
        desc="Evaluating answers",
        total=len(query_response_pairs),
    )

    # Group the evaluation results by source, skipping failed evaluations.
    eval_results = {}
    for item in results:
        if isinstance(item, tuple) and len(item) == 2:
            source, result = item
            if result is not None:
                if source not in eval_results:
                    eval_results[source] = []
                eval_results[source].append(result)
        else:
            logfire.error(f"Unexpected result: {item}")

    # Persist per-source results for later inspection with display_results_answers().
    for source, source_results in eval_results.items():
        with open(f"eval_answers_results_{source}.pkl", "wb") as f:
            pickle.dump(source_results, f)

    end_time = time.time()
    logfire.info(f"Total evaluation time: {round(end_time - start_time, 3)} seconds")

    return eval_results


def create_docs(input_file: str) -> List[Document]:
    """Build llama_index Documents from a JSONL file, one record per line."""
    with open(input_file, "r") as f:
        documents = []
        for line in f:
            data = json.loads(line)
            documents.append(
                Document(
                    doc_id=data["doc_id"],
                    text=data["content"],
                    metadata={
                        "url": data["url"],
                        "title": data["name"],
                        "tokens": data["tokens"],
                        "retrieve_doc": data["retrieve_doc"],
                        "source": data["source"],
                    },
                    # Keep bookkeeping metadata out of the LLM prompt; only the URL is exposed.
                    excluded_llm_metadata_keys=[
                        "title",
                        "tokens",
                        "retrieve_doc",
                        "source",
                    ],
                    # Embed the title along with the text; skip URL and bookkeeping fields.
                    excluded_embed_metadata_keys=[
                        "url",
                        "tokens",
                        "retrieve_doc",
                        "source",
                    ],
                )
            )
    return documents


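# Expected input format (illustrative record; the field values are made up):
#
#   {"doc_id": "abc123", "content": "...", "url": "https://example.com/doc",
#    "name": "Document title", "tokens": 512, "retrieve_doc": true, "source": "transformers"}

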
def get_sample_size(source: str, total_queries: int) -> int:
    """Determine the number of queries to sample based on the source."""
    small_datasets = {"peft": 49, "trl": 34, "openai_cookbooks": 170}
    large_datasets = {
        "transformers": 200,
        "llama_index": 200,
        "langchain": 200,
        "tai_blog": 200,
    }

    if source in small_datasets:
        return small_datasets[source]
    elif source in large_datasets:
        return large_datasets[source]
    else:
        return min(100, total_queries)


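# Worked examples (follow directly from the tables above):
#
#   get_sample_size("peft", 500)           -> 49   (fixed size for small datasets)
#   get_sample_size("transformers", 5000)  -> 200  (capped size for large datasets)
#   get_sample_size("unknown", 60)         -> 60   (min(100, total_queries) fallback)

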
async def evaluate_retriever():
    start_time = time.time()

    # Load the pre-built keyword retriever and assemble the hybrid retriever.
    with open("data/keyword_retriever_async.pkl", "rb") as f:
        keyword_retriever = pickle.load(f)

    custom_retriever_all_sources: CustomRetriever = setup_basic_database(
        "chroma-db-all_sources", "document_dict_all_sources.pkl", keyword_retriever
    )

    end_time = time.time()
    logfire.info(
        f"Time taken to set up the custom retriever: {round(end_time - start_time, 2)} seconds"
    )

    sources_to_evaluate = [
        "transformers",
        "peft",
        "trl",
        "llama_index",
        "langchain",
        "openai_cookbooks",
        "tai_blog",
    ]

    retriever_evaluator = RetrieverEvaluator.from_metric_names(
        ["mrr", "hit_rate"], retriever=custom_retriever_all_sources
    )

    # Build (source, query, expected_ids) triples from each source's eval dataset.
    all_query_pairs = []
    for source in sources_to_evaluate:
        rag_eval_dataset = EmbeddingQAFinetuneDataset.from_json(
            f"scripts/rag_eval_{source}.json"
        )
        sampleable_dataset = SampleableEmbeddingQADataset(rag_eval_dataset)
        sample_size = get_sample_size(source, len(sampleable_dataset.queries))
        sampled_dataset = sampleable_dataset.sample(n=sample_size)
        query_expected_ids_pairs = sampled_dataset.query_docid_pairs
        all_query_pairs.extend(
            [(source, pair[0], pair[1]) for pair in query_expected_ids_pairs]
        )

    # Limit the number of concurrent retrieval evaluations.
    semaphore = asyncio.Semaphore(220)

    async def evaluate_query(source, query, expected_ids):
        async with semaphore:
            try:
                result: RetrievalEvalResult = await retriever_evaluator.aevaluate(
                    query=query,
                    expected_ids=expected_ids,
                    mode=RetrievalEvalMode.TEXT,
                    source=source,
                )
                return source, result
            except Exception as e:
                logfire.error(f"Error evaluating query for {source}: {str(e)}")
                return source, None

    results = await tqdm_asyncio.gather(
        *[
            evaluate_query(source, query, expected_ids)
            for source, query, expected_ids in all_query_pairs
        ],
        desc="Evaluating queries",
        total=len(all_query_pairs),
    )

    # Group the evaluation results by source, skipping failed evaluations.
    eval_results = {source: [] for source in sources_to_evaluate}
    for item in results:
        if isinstance(item, tuple) and len(item) == 2:
            source, result = item
            if result is not None:
                eval_results[source].append(result)
        else:
            logfire.error(f"Unexpected result: {item}")

    # Persist per-source results for later inspection with display_results().
    for source, source_results in eval_results.items():
        with open(f"eval_results_{source}.pkl", "wb") as f:
            pickle.dump(source_results, f)

    end_time = time.time()
    logfire.info(f"Total evaluation time: {round(end_time - start_time, 3)} seconds")


def display_results_retriever(name, eval_results):
    """Aggregate hit rate and MRR over eval_results into a one-row DataFrame."""
    metric_dicts = []
    for eval_result in eval_results:
        metric_dict = eval_result.metric_vals_dict
        metric_dicts.append(metric_dict)

    full_df = pd.DataFrame(metric_dicts)

    hit_rate = full_df["hit_rate"].mean()
    mrr = full_df["mrr"].mean()

    metric_df = pd.DataFrame(
        {"Retriever Name": [name], "Hit Rate": [hit_rate], "MRR": [mrr]}
    )

    return metric_df


def display_results():
    sources = [
        "transformers",
        "peft",
        "trl",
        "llama_index",
        "langchain",
        "openai_cookbooks",
        "tai_blog",
    ]

    # Print the retriever metrics saved by evaluate_retriever() for each source.
    for source in sources:
        with open(f"eval_results_{source}.pkl", "rb") as f:
            eval_results = pickle.load(f)
        print(display_results_retriever(f"{source}", eval_results))


def display_results_answers():
    sources = [
        "transformers",
        "peft",
        "trl",
        "llama_index",
        "langchain",
        "openai_cookbooks",
        "tai_blog",
    ]

    # Print the mean answer-relevancy score saved by evaluate_answers() for each source.
    for source in sources:
        with open(f"eval_answers_results_{source}.pkl", "rb") as f:
            eval_results = pickle.load(f)
        print(
            f"Score for {source}:",
            sum(result.score for result in eval_results) / len(eval_results),
        )


async def main():
    await evaluate_retriever()
    display_results()

    return


if __name__ == "__main__":
    logfire.configure()

    # Module-level writer used by OpenAIAgentRetrieverEvaluator._save_response_data_async.
    rotating_writer = RotatingJSONLWriter(
        "response_data.jsonl", max_size=10**7, backup_count=5
    )

    start_time = time.time()
    asyncio.run(main())
    end_time = time.time()
    logfire.info(
        f"Time taken to run script: {round((end_time - start_time), 3)} seconds"
    )