|
import ast |
|
import pandas as pd |
|
import gradio as gr |
|
import litellm |
|
import plotly.express as px |
|
from collections import defaultdict |
|
from datetime import datetime |
|
import os |
|
from datasets import load_dataset |
|
import sqlite3 |
|
|
|
def initialize_database(db_path='afrimmlu_results.db'):
    """Create the SQLite tables used to store evaluation results.

    Creates ``summary_results`` (one accuracy snapshot per subject per run)
    and ``detailed_results`` (one row per evaluated question) if they do
    not already exist.

    Args:
        db_path: Path to the SQLite database file. Defaults to
            ``afrimmlu_results.db`` so existing callers are unaffected.
    """
    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()

        cursor.execute('''
            CREATE TABLE IF NOT EXISTS summary_results (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                language TEXT,
                subject TEXT,
                accuracy REAL,
                timestamp TEXT
            )
        ''')

        cursor.execute('''
            CREATE TABLE IF NOT EXISTS detailed_results (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                language TEXT,
                timestamp TEXT,
                subject TEXT,
                question TEXT,
                model_answer TEXT,
                correct_answer TEXT,
                is_correct INTEGER,
                total_tokens INTEGER
            )
        ''')

        conn.commit()
    finally:
        # The original leaked the connection if table creation raised;
        # always close, even on error.
        conn.close()
|
|
|
def save_results_to_database(language, summary_results, detailed_results,
                             db_path='afrimmlu_results.db'):
    """Persist one evaluation run to the SQLite database.

    Args:
        language: Language code the run was evaluated on (stored verbatim).
        summary_results: Mapping of subject name -> accuracy percentage;
            each entry becomes one ``summary_results`` row stamped with a
            single shared timestamp.
        detailed_results: List of per-question dicts with keys
            'timestamp', 'subject', 'question', 'model_answer',
            'correct_answer', 'is_correct' and 'total_tokens'; each
            becomes one ``detailed_results`` row.
        db_path: Path to the SQLite database file. Defaults to
            ``afrimmlu_results.db`` so existing callers are unaffected.
    """
    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()
        # One timestamp for the whole run so its summary rows group together.
        timestamp = datetime.now().isoformat()

        cursor.executemany(
            '''
            INSERT INTO summary_results (language, subject, accuracy, timestamp)
            VALUES (?, ?, ?, ?)
            ''',
            [(language, subject, accuracy, timestamp)
             for subject, accuracy in summary_results.items()],
        )

        cursor.executemany(
            '''
            INSERT INTO detailed_results (
                language, timestamp, subject, question, model_answer,
                correct_answer, is_correct, total_tokens
            ) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
            ''',
            [
                (
                    language,
                    result['timestamp'],
                    result['subject'],
                    result['question'],
                    result['model_answer'],
                    result['correct_answer'],
                    int(result['is_correct']),  # SQLite stores booleans as 0/1
                    result['total_tokens'],
                )
                for result in detailed_results
            ],
        )

        conn.commit()
    finally:
        # The original leaked the connection if any insert raised;
        # always close, even on error.
        conn.close()
|
|
|
def load_afrimmlu_data(language_code="swa"):
    """Load the AfriMMLU test split for a specific language.

    Args:
        language_code: Masakhane AfriMMLU language config (e.g. "swa").

    Returns:
        The 'test' split as a list of example dicts, or None if the
        dataset could not be loaded (the error is printed, not raised).
    """
    try:
        dataset = load_dataset(
            'masakhane/afrimmlu',
            language_code,
            # .get() instead of [] so a missing HF_TOKEN doesn't raise
            # KeyError before the load is even attempted (the dataset may
            # be accessible without a token; if not, load_dataset itself
            # reports the real authentication error).
            token=os.environ.get('HF_TOKEN'),
        )
        test_data = dataset['test'].to_list()
        return test_data
    except Exception as e:
        # Best-effort loader: callers check for None rather than catching.
        print(f"Error loading dataset: {str(e)}")
        return None
|
|
|
def preprocess_dataset(test_data):
    """Normalize each example's 'choices' field into a real Python list.

    Examples whose 'choices' value is a string (e.g. "['a', 'b']") are
    parsed with ast.literal_eval after stripping one layer of matching
    outer quotes and unescaping embedded single quotes. Examples whose
    string cannot be parsed are dropped (with a message printed); all
    other examples are kept, mutated in place, in their original order.
    """
    kept = []
    for item in test_data:
        raw = item['choices']
        if isinstance(raw, str):
            # Strip one layer of matching outer quotes, if present.
            for quote in ("'", '"'):
                if raw.startswith(quote) and raw.endswith(quote):
                    raw = raw[1:-1]
                    break
            # Unescape single quotes that were backslash-escaped inside
            # the serialized list.
            raw = raw.replace("\\'", "'")
            try:
                item['choices'] = ast.literal_eval(raw)
            except (ValueError, SyntaxError):
                print(f"Error parsing choices: {raw}")
                continue
        kept.append(item)
    return kept
|
|
|
def evaluate_afrimmlu(test_data, model_name="deepseek/deepseek-chat", language="swa"):
    """Evaluate `model_name` on AfriMMLU examples and persist the results.

    Args:
        test_data: Iterable of example dicts with keys 'question',
            'choices' (list of option strings), 'answer' (letter of the
            correct option -- assumed to be a string; TODO confirm against
            the dataset schema) and 'subject'.
        model_name: litellm model identifier.
        language: Language code stored alongside the results.

    Returns:
        dict with overall 'accuracy' (percent), per-subject
        'subject_accuracy' (percent) and the list of 'detailed_results'.

    Side effects:
        Writes one summary row per subject (plus 'Overall') and one
        detailed row per question to the results database via
        save_results_to_database().
    """
    results = []
    correct = 0
    total = 0
    subject_results = defaultdict(lambda: {"correct": 0, "total": 0})

    for example in test_data:
        question = example['question']
        choices = example['choices']
        answer = example['answer']
        subject = example['subject']

        # Build the option lines dynamically so an example with fewer (or
        # more) than four choices no longer crashes the whole run with an
        # uncaught IndexError (the original hard-coded choices[0..3]
        # outside the try block). For four choices the prompt is
        # byte-identical to the previous version. The instruction line
        # keeps the fixed "(A, B, C, or D)" wording to preserve prompts
        # for the common 4-option case.
        letters = [chr(ord('A') + i) for i in range(len(choices))]
        options = "".join(
            f"{letter}. {choice}\n" for letter, choice in zip(letters, choices)
        )
        prompt = (
            f"Answer the following multiple-choice question. "
            f"Return only the letter corresponding to the correct answer (A, B, C, or D).\n"
            f"Question: {question}\n"
            f"Options:\n"
            f"{options}"
            f"Answer:"
        )

        try:
            response = litellm.completion(
                model=model_name,
                messages=[{"role": "user", "content": prompt}]
            )
            model_output = response.choices[0].message.content.strip().upper()

            # First valid option letter anywhere in the reply counts as
            # the model's answer; None if no letter was found.
            model_answer = next(
                (char for char in model_output if char in letters), None
            )

            is_correct = model_answer == answer.upper()
            if is_correct:
                correct += 1
                subject_results[subject]["correct"] += 1
            total += 1
            subject_results[subject]["total"] += 1

            results.append({
                'timestamp': datetime.now().isoformat(),
                'subject': subject,
                'question': question,
                'model_answer': model_answer,
                'correct_answer': answer.upper(),
                'is_correct': is_correct,
                'total_tokens': response.usage.total_tokens
            })

        except Exception as e:
            # Best-effort: a failed API call skips the question rather
            # than aborting the evaluation.
            print(f"Error processing question: {str(e)}")
            continue

    accuracy = (correct / total * 100) if total > 0 else 0
    subject_accuracy = {
        subject: (stats["correct"] / stats["total"] * 100) if stats["total"] > 0 else 0
        for subject, stats in subject_results.items()
    }

    save_results_to_database(language, {**subject_accuracy, 'Overall': accuracy}, results)

    return {
        "accuracy": accuracy,
        "subject_accuracy": subject_accuracy,
        "detailed_results": results
    }
|
|
|
def create_visualization(results_dict):
    """Build a summary table and bar chart from evaluation results.

    Args:
        results_dict: Output of evaluate_afrimmlu() -- must contain
            'subject_accuracy' (subject -> percent) and 'accuracy'
            (overall percent).

    Returns:
        (summary_df, fig): a pandas DataFrame with one row per subject
        plus an 'Overall' row, and the corresponding plotly bar chart.
    """
    # Column-wise construction: subjects in insertion order, 'Overall' last.
    subjects = list(results_dict['subject_accuracy'].keys()) + ['Overall']
    accuracies = list(results_dict['subject_accuracy'].values()) + [results_dict['accuracy']]
    summary_df = pd.DataFrame({'Subject': subjects, 'Accuracy (%)': accuracies})

    fig = px.bar(
        summary_df,
        x='Subject',
        y='Accuracy (%)',
        title='AfriMMLU Evaluation Results',
        labels={'Subject': 'Subject', 'Accuracy (%)': 'Accuracy (%)'}
    )
    # Slanted labels keep long subject names readable.
    fig.update_layout(xaxis_tickangle=-45, showlegend=False, height=600)

    return summary_df, fig
|
|
|
|
|
def query_database(query):
    """Run an arbitrary SQL query against afrimmlu_results.db.

    Returns the result set as a pandas DataFrame. Any failure (bad SQL,
    missing table, ...) is reported as a one-column 'Error' DataFrame
    holding the message, so the Gradio UI never sees an exception.

    NOTE(review): executes caller-supplied SQL verbatim -- fine for a
    local analysis dashboard, but do not expose this publicly.
    """
    connection = sqlite3.connect('afrimmlu_results.db')
    try:
        return pd.read_sql_query(query, connection)
    except Exception as exc:
        return pd.DataFrame({'Error': [str(exc)]})
    finally:
        connection.close()
|
|
|
def create_gradio_interface():
    """Build the two-tab Gradio dashboard (model evaluation + SQL analysis).

    Ensures the SQLite tables exist, wires the evaluation button to the
    full load -> preprocess -> evaluate -> visualize pipeline, and wires
    the analysis tab to query_database(). Returns the (unlaunched)
    gr.Blocks app.
    """
    # Dropdown language codes -> display names. NOTE(review): only the
    # keys are used below; the display names are never shown -- confirm
    # whether the dropdown was meant to display them.
    language_options = {
        "swa": "Swahili",
        "yor": "Yoruba",
        "wol": "Wolof",
        "lin": "Lingala",
        "ewe": "Ewe",
        "ibo": "Igbo"
    }

    # Create tables up front so the Database Analysis tab works even
    # before the first evaluation run.
    initialize_database()

    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("# AfriMMLU Evaluation Dashboard")

        with gr.Tabs():

            # --- Tab 1: run an evaluation and show its results ---
            with gr.Tab("Model Evaluation"):
                with gr.Row():
                    with gr.Column(scale=1):
                        language_input = gr.Dropdown(
                            choices=list(language_options.keys()),
                            label="Select Language",
                            value="swa"
                        )
                        model_input = gr.Dropdown(
                            choices=["deepseek/deepseek-chat"],
                            label="Select Model",
                            value="deepseek/deepseek-chat"
                        )
                        evaluate_btn = gr.Button("Evaluate", variant="primary")

                with gr.Row():
                    summary_table = gr.Dataframe(
                        headers=["Subject", "Accuracy (%)"],
                        label="Summary Results"
                    )

                with gr.Row():
                    summary_plot = gr.Plot(label="Performance by Subject")

                with gr.Row():
                    detailed_results = gr.Dataframe(
                        label="Detailed Results",
                        wrap=True
                    )

            # --- Tab 2: free-form SQL over the stored results ---
            with gr.Tab("Database Analysis"):
                with gr.Row():
                    with gr.Column():
                        # Picking an example copies it into the query box
                        # (see the .change handler below).
                        example_queries = gr.Dropdown(
                            choices=[
                                "SELECT language, AVG(accuracy) as avg_accuracy FROM summary_results WHERE subject='Overall' GROUP BY language",
                                "SELECT subject, AVG(accuracy) as avg_accuracy FROM summary_results GROUP BY subject",
                                "SELECT language, subject, accuracy, timestamp FROM summary_results ORDER BY timestamp DESC LIMIT 10",
                                "SELECT language, COUNT(*) as total_questions, SUM(is_correct) as correct_answers FROM detailed_results GROUP BY language",
                                "SELECT subject, COUNT(*) as total_evaluations FROM summary_results GROUP BY subject"
                            ],
                            label="Example Queries",
                            value="SELECT language, AVG(accuracy) as avg_accuracy FROM summary_results WHERE subject='Overall' GROUP BY language"
                        )

                        query_input = gr.Textbox(
                            label="SQL Query",
                            placeholder="Enter your SQL query here",
                            lines=3
                        )

                        query_button = gr.Button("Run Query", variant="primary")

                        gr.Markdown("""
                        ### Available Tables:
                        1. summary_results (id, language, subject, accuracy, timestamp)
                        2. detailed_results (id, language, timestamp, subject, question, model_answer, correct_answer, is_correct, total_tokens)
                        """)

                with gr.Row():
                    query_output = gr.Dataframe(
                        label="Query Results",
                        wrap=True
                    )

        def evaluate_language(language_code, model_name):
            """Full pipeline for one button click: load, preprocess,
            evaluate, visualize. Returns (summary_df, plot, detailed_df),
            or (None, None, None) if the dataset failed to load."""
            test_data = load_afrimmlu_data(language_code)
            if test_data is None:
                return None, None, None

            preprocessed_data = preprocess_dataset(test_data)
            results = evaluate_afrimmlu(preprocessed_data, model_name, language_code)
            summary_df, plot = create_visualization(results)
            detailed_df = pd.DataFrame(results["detailed_results"])

            return summary_df, plot, detailed_df

        # Wire the evaluation button to the pipeline above.
        evaluate_btn.click(
            fn=evaluate_language,
            inputs=[language_input, model_input],
            outputs=[summary_table, summary_plot, detailed_results]
        )

        # Copy the selected example query into the editable query box.
        example_queries.change(
            fn=lambda x: x,
            inputs=[example_queries],
            outputs=[query_input]
        )

        # Run whatever SQL is in the box and show it as a table.
        query_button.click(
            fn=query_database,
            inputs=[query_input],
            outputs=[query_output]
        )

    return demo
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Fail fast with a clear message if required credentials are missing.
    # (The original used bare `os.environ[...]` expression statements as
    # probes, which raised an unexplained KeyError.)
    missing = [key for key in ('DEEPSEEK_API_KEY', 'HF_TOKEN')
               if key not in os.environ]
    if missing:
        raise SystemExit(
            f"Missing required environment variable(s): {', '.join(missing)}"
        )

    demo = create_gradio_interface()
    demo.launch(share=True)