|
import gradio as gr |
|
import pandas as pd |
|
|
|
|
|
# Columns shown in the leaderboard table, in display order.
# The strings must match the CSV header names exactly (emoji included).
UGI_COLS = [
    '#P',
    'Model',
    'UGI 🏆',
    'Willingness👍',
    'QuActivities',
    'Internet',
    'CrimeStats',
    'Stories/Jokes',
    'Pol Contro',
]
|
|
|
|
|
def load_leaderboard_data(csv_file_path):
    """Load the leaderboard CSV and turn linked model names into HTML anchors.

    Rows whose ``Link`` cell is present get their ``Model`` cell replaced by an
    ``<a>`` element pointing at that link; rows without a link keep the plain
    model name. The ``Link`` column is dropped afterwards.

    Args:
        csv_file_path: Path (or file-like object) handed to ``pd.read_csv``.

    Returns:
        The loaded frame without the ``Link`` column. If anything goes wrong
        while loading or transforming, the error is printed and an empty frame
        with the ``UGI_COLS`` columns plus ``Params`` is returned instead.
    """
    import html  # function-scope import; not part of the file's import block

    def _linked_name(row):
        link, name = row["Link"], row["Model"]
        if pd.isna(link):
            return name
        # Escape both values so characters such as & < " in the CSV cannot
        # break out of the attribute or the element text.
        href = html.escape(str(link), quote=True)
        label = html.escape(str(name))
        return (
            f'<a href="{href}" target="_blank" '
            f'style="color: blue; text-decoration: none;">{label}</a>'
        )

    try:
        df = pd.read_csv(csv_file_path)
        df['Model'] = df.apply(_linked_name, axis=1)
        df.drop(columns=['Link'], inplace=True)
        return df
    except Exception as e:
        # Best-effort: keep the app alive with an empty table.
        print(f"Error loading CSV file: {e}")
        fallback_cols = list(UGI_COLS)
        # update_table filters on 'Params', so the fallback must carry it too.
        if 'Params' not in fallback_cols:
            fallback_cols.append('Params')
        return pd.DataFrame(columns=fallback_cols)
|
|
|
|
|
def update_table(df: pd.DataFrame, query: str, param_ranges: dict) -> pd.DataFrame:
    """Filter the leaderboard by model-size buckets and a free-text query.

    Args:
        df: Leaderboard frame; must contain a ``Params`` column when any size
            box is checked, and every column named in ``UGI_COLS``.
        query: Case-insensitive search text. It is matched against the whole
            rendered row (``row.to_string()``), which includes column labels
            and any markup stored in the cells. Empty means "no text filter".
        param_ranges: Maps a size label (``'~1.5'`` ... ``'~70+'``) to whether
            its checkbox is ticked.

    Returns:
        The matching rows restricted to the ``UGI_COLS`` columns.
    """
    # Half-open [low, high) bounds on 'Params' for each label; None = unbounded.
    size_bounds = {
        '~1.5': (None, 2.5),
        '~3': (2.5, 6),
        '~7': (6, 9.5),
        '~13': (9.5, 16),
        '~20': (16, 28),
        '~34': (28, 40),
        '~50': (40, 60),
        '~70+': (60, None),
    }

    filtered_df = df

    if any(param_ranges.values()):
        params = filtered_df['Params']
        keep = pd.Series(False, index=filtered_df.index, dtype=bool)

        for label, checked in param_ranges.items():
            bounds = size_bounds.get(label)
            # Unknown labels contribute no condition instead of crashing.
            if not checked or bounds is None:
                continue
            low, high = bounds
            in_bucket = pd.Series(True, index=filtered_df.index, dtype=bool)
            if low is not None:
                in_bucket &= params >= low
            if high is not None:
                in_bucket &= params < high
            keep |= in_bucket

        # Rows with an unknown size are shown only when every box is ticked.
        if all(param_ranges.values()):
            keep |= params.isna()

        filtered_df = filtered_df[keep]

    # Skip the text filter on an empty frame: a row-wise apply there does not
    # yield a boolean mask, and indexing with its result would drop the
    # columns, making the UGI_COLS selection below fail.
    if query and not filtered_df.empty:
        needle = query.lower()
        matches = filtered_df.apply(
            lambda row: needle in row.to_string().lower(), axis=1
        )
        filtered_df = filtered_df[matches]

    return filtered_df[UGI_COLS]
|
|
|
|
|
demo = gr.Blocks()

# Size-bucket labels, in display order. They are used both as checkbox labels
# and as the keys of the param_ranges dict passed to update_table.
_SIZE_LABELS = ['~1.5', '~3', '~7', '~13', '~20', '~34', '~50', '~70+']


def _refresh_table(query, *size_flags):
    """Re-filter the leaderboard from the current search text and checkbox states.

    ``size_flags`` arrive in the same order as ``_SIZE_LABELS`` because the
    checkboxes are listed in that order in ``inputs``.
    """
    return update_table(leaderboard_df, query, dict(zip(_SIZE_LABELS, size_flags)))


with demo:
    gr.Markdown("## UGI Leaderboard", elem_classes="text-lg")
    with gr.Column():
        with gr.Row():
            search_bar = gr.Textbox(placeholder=" 🔍 Search for a model...", show_label=False)
        with gr.Row():
            gr.Markdown("Model sizes (in billions of parameters)", elem_classes="text-sm")
            size_checkboxes = [gr.Checkbox(label=size, value=False) for size in _SIZE_LABELS]
            # Keep the original per-checkbox module-level names available.
            (
                param_range_1,
                param_range_2,
                param_range_3,
                param_range_4,
                param_range_5,
                param_range_6,
                param_range_7,
                param_range_8,
            ) = size_checkboxes

    # Initial data shown before any filter is applied.
    leaderboard_df = load_leaderboard_data("ugi-leaderboard-data.csv")

    # The 'Model' cells may hold anchor markup, so that column is rendered as
    # HTML; every other column is rendered as plain text.
    datatypes = ['html' if col == 'Model' else 'str' for col in UGI_COLS]

    leaderboard_table = gr.Dataframe(
        value=leaderboard_df[UGI_COLS],
        datatype=datatypes,
        interactive=False,
        visible=True,
        elem_classes="text-sm"
    )

    # Every control feeds the same callback with the full set of inputs, so the
    # table always reflects the combined search text and checkbox state.
    inputs = [search_bar, *size_checkboxes]
    outputs = leaderboard_table

    for control in inputs:
        control.change(fn=_refresh_table, inputs=inputs, outputs=outputs)

demo.launch()