Spaces:
Running
Running
import asyncio | |
import weave | |
from medrag_multi_modal.assistant import LLMClient, MedQAAssistant | |
from medrag_multi_modal.metrics import MMLUOptionAccuracy | |
from medrag_multi_modal.retrieval.text_retrieval import ( | |
BM25sRetriever, | |
ContrieverRetriever, | |
MedCPTRetriever, | |
NVEmbed2Retriever, | |
) | |
def test_mmlu_correctness_anatomy_bm25s(model_name: str):
    """Check that the BM25s-backed MedQA assistant gets most anatomy items right.

    Runs a weave evaluation of `MedQAAssistant` (with a `BM25sRetriever`) over
    the `mmlu-anatomy-test:v2` dataset, scored by `MMLUOptionAccuracy`, and
    passes only when the count of correct answers exceeds the count of
    incorrect ones.

    Args:
        model_name: Name of the LLM handed to `LLMClient`.
    """
    weave.init("ml-colabs/medrag-multi-modal")

    bm25s_retriever = BM25sRetriever().from_index(
        index_repo_id="ashwiniai/medrag-text-corpus-chunks-bm25s"
    )
    client = LLMClient(model_name=model_name)
    assistant = MedQAAssistant(
        llm_client=client,
        retriever=bm25s_retriever,
        top_k_chunks_for_query=5,
        top_k_chunks_for_options=3,
    )
    anatomy_dataset = weave.ref("mmlu-anatomy-test:v2").get()

    # Retriever class name and LLM name are passed as weave attributes
    # around the evaluation run.
    run_attributes = {
        "retriever": bm25s_retriever.__class__.__name__,
        "llm": client.model_name,
    }
    with weave.attributes(run_attributes):
        evaluation = weave.Evaluation(
            dataset=anatomy_dataset,
            scorers=[MMLUOptionAccuracy()],
            name="MMLU-Anatomy-BM25s",
        )
        # Display name has the form "<evaluation name>:<model name>".
        display_name = ":".join((evaluation.name, client.model_name))
        summary = asyncio.run(
            evaluation.evaluate(assistant, __weave={"display_name": display_name})
        )

    correct = summary["MMLUOptionAccuracy"]["correct"]
    assert correct["true_count"] > correct["false_count"]
def test_mmlu_correctness_anatomy_contriever(model_name: str):
    """Check that the Contriever-backed MedQA assistant gets most anatomy items right.

    Runs a weave evaluation of `MedQAAssistant` (with a `ContrieverRetriever`)
    over the `mmlu-anatomy-test:v2` dataset, scored by `MMLUOptionAccuracy`,
    and passes only when the count of correct answers exceeds the count of
    incorrect ones.

    Args:
        model_name: Name of the LLM handed to `LLMClient`.
    """
    weave.init("ml-colabs/medrag-multi-modal")

    contriever_retriever = ContrieverRetriever().from_index(
        index_repo_id="ashwiniai/medrag-text-corpus-chunks-contriever",
        chunk_dataset="ashwiniai/medrag-text-corpus-chunks",
    )
    client = LLMClient(model_name=model_name)
    assistant = MedQAAssistant(
        llm_client=client,
        retriever=contriever_retriever,
        top_k_chunks_for_query=5,
        top_k_chunks_for_options=3,
    )
    anatomy_dataset = weave.ref("mmlu-anatomy-test:v2").get()

    # Retriever class name and LLM name are passed as weave attributes
    # around the evaluation run.
    run_attributes = {
        "retriever": contriever_retriever.__class__.__name__,
        "llm": client.model_name,
    }
    with weave.attributes(run_attributes):
        evaluation = weave.Evaluation(
            dataset=anatomy_dataset,
            scorers=[MMLUOptionAccuracy()],
            name="MMLU-Anatomy-Contriever",
        )
        # Display name has the form "<evaluation name>:<model name>".
        display_name = ":".join((evaluation.name, client.model_name))
        summary = asyncio.run(
            evaluation.evaluate(assistant, __weave={"display_name": display_name})
        )

    correct = summary["MMLUOptionAccuracy"]["correct"]
    assert correct["true_count"] > correct["false_count"]
def test_mmlu_correctness_anatomy_medcpt(model_name: str):
    """Check that the MedCPT-backed MedQA assistant gets most anatomy items right.

    Runs a weave evaluation of `MedQAAssistant` (with a `MedCPTRetriever`)
    over the `mmlu-anatomy-test:v2` dataset, scored by `MMLUOptionAccuracy`,
    and passes only when the count of correct answers exceeds the count of
    incorrect ones.

    Args:
        model_name: Name of the LLM handed to `LLMClient`.
    """
    weave.init("ml-colabs/medrag-multi-modal")

    medcpt_retriever = MedCPTRetriever().from_index(
        index_repo_id="ashwiniai/medrag-text-corpus-chunks-medcpt",
        chunk_dataset="ashwiniai/medrag-text-corpus-chunks",
    )
    client = LLMClient(model_name=model_name)
    assistant = MedQAAssistant(
        llm_client=client,
        retriever=medcpt_retriever,
        top_k_chunks_for_query=5,
        top_k_chunks_for_options=3,
    )
    anatomy_dataset = weave.ref("mmlu-anatomy-test:v2").get()

    # Retriever class name and LLM name are passed as weave attributes
    # around the evaluation run.
    run_attributes = {
        "retriever": medcpt_retriever.__class__.__name__,
        "llm": client.model_name,
    }
    with weave.attributes(run_attributes):
        evaluation = weave.Evaluation(
            dataset=anatomy_dataset,
            scorers=[MMLUOptionAccuracy()],
            name="MMLU-Anatomy-MedCPT",
        )
        # Display name has the form "<evaluation name>:<model name>".
        display_name = ":".join((evaluation.name, client.model_name))
        summary = asyncio.run(
            evaluation.evaluate(assistant, __weave={"display_name": display_name})
        )

    correct = summary["MMLUOptionAccuracy"]["correct"]
    assert correct["true_count"] > correct["false_count"]
def test_mmlu_correctness_anatomy_nvembed2(model_name: str):
    """Check that the NV-Embed-2-backed MedQA assistant gets most anatomy items right.

    Runs a weave evaluation of `MedQAAssistant` (with an `NVEmbed2Retriever`)
    over the `mmlu-anatomy-test:v2` dataset, scored by `MMLUOptionAccuracy`,
    and passes only when the count of correct answers exceeds the count of
    incorrect ones.

    Args:
        model_name: Name of the LLM handed to `LLMClient`.
    """
    weave.init("ml-colabs/medrag-multi-modal")

    nvembed2_retriever = NVEmbed2Retriever().from_index(
        index_repo_id="ashwiniai/medrag-text-corpus-chunks-nv-embed-2",
        chunk_dataset="ashwiniai/medrag-text-corpus-chunks",
    )
    client = LLMClient(model_name=model_name)
    assistant = MedQAAssistant(
        llm_client=client,
        retriever=nvembed2_retriever,
        top_k_chunks_for_query=5,
        top_k_chunks_for_options=3,
    )
    anatomy_dataset = weave.ref("mmlu-anatomy-test:v2").get()

    # Retriever class name and LLM name are passed as weave attributes
    # around the evaluation run.
    run_attributes = {
        "retriever": nvembed2_retriever.__class__.__name__,
        "llm": client.model_name,
    }
    with weave.attributes(run_attributes):
        evaluation = weave.Evaluation(
            dataset=anatomy_dataset,
            scorers=[MMLUOptionAccuracy()],
            name="MMLU-Anatomy-NVEmbed2",
        )
        # Display name has the form "<evaluation name>:<model name>".
        display_name = ":".join((evaluation.name, client.model_name))
        summary = asyncio.run(
            evaluation.evaluate(assistant, __weave={"display_name": display_name})
        )

    correct = summary["MMLUOptionAccuracy"]["correct"]
    assert correct["true_count"] > correct["false_count"]