|
import gradio as gr |
|
import plotly.express as px |
|
|
|
from backend.data import load_cot_data |
|
from backend.envs import API, REPO_ID, TOKEN |
|
|
|
# Logo images, fetched from the assets folder of the cot-eval repository.
logo1_url = "https://raw.githubusercontent.com/logikon-ai/cot-eval/main/assets/AI2_Logo_Square.png"
logo2_url = "https://raw.githubusercontent.com/logikon-ai/cot-eval/main/assets/logo_logikon_notext_withborder.png"

# Centered flex row holding both logos, each wrapped in a link.
# The pieces are adjacent literals, so they join into one string with no
# separators other than the explicit single space between the two anchors.
LOGOS = (
    '<div style="display: flex; justify-content: center;">'
    f'<a href="https://allenai.org/"><img src="{logo1_url}" alt="AI2" '
    'style="width: 30vw; min-width: 20px; max-width: 60px;"></a>'
    ' '
    f'<a href="https://logikon.ai"><img src="{logo2_url}" alt="Logikon AI" '
    'style="width: 30vw; min-width: 20px; max-width: 60px; margin-left: 10px;"></a>'
    '</div>'
)

# Page heading followed by the logo row.
TITLE = '<h1 align="center" id="space-title"> Open CoT Dashboard</h1> ' + LOGOS
|
|
|
# Markdown blurb shown on the page (passed to ``gr.Markdown``).
# The leaderboard link host is ``huggingface.co``; a stray dot before
# ``/spaces`` previously made the link point at ``huggingface.co.``.
INTRODUCTION_TEXT = """
Baseline accuracies and marginal accuracy gains for specific models and CoT regimes from the [Open CoT Leaderboard](https://huggingface.co/spaces/logikon/open_cot_leaderboard).
"""
|
|
|
def restart_space():
    """Call ``API.restart_space`` for ``REPO_ID``, passing ``TOKEN``."""
    restart_kwargs = {"repo_id": REPO_ID, "token": TOKEN}
    API.restart_space(**restart_kwargs)
|
|
|
# Load both evaluation frames once at import time; the callbacks and the UI
# definition below read them as module-level globals.
try:
    df_cot_err, df_cot_regimes = load_cot_data()
except Exception:
    # Request a restart (presumably so the load is retried on the next start),
    # then re-raise. Swallowing the error here would leave `df_cot_err` and
    # `df_cot_regimes` undefined, and the script would fail later with an
    # unrelated NameError that hides the real cause.
    restart_space()
    raise
|
|
|
|
|
def plot_evals(model_id, plotly_mode, request: gr.Request):
    """Build the per-task scatter plot with one model highlighted.

    Args:
        model_id: Model name chosen in the UI; rows whose ``model`` equals it
            are flagged ``"selected"``, all others ``"-"``.
        plotly_mode: ``"dark"`` picks the ``plotly_dark`` template; any other
            value picks ``plotly``.
        request: Incoming Gradio request. If its query string has a ``model``
            value that occurs in ``df_cot_err.model``, that value replaces
            ``model_id``.

    Returns:
        A ``(figure, model_id)`` tuple, where ``model_id`` is the model name
        that was actually highlighted.
    """
    frame = df_cot_err.copy()

    # A known model passed via ?model=... overrides the dropdown selection.
    if request and "model" in request.query_params:
        requested = request.query_params["model"]
        model_id = requested if requested in frame.model.to_list() else model_id

    def _flag(name):
        # Label used by the colour mapping below.
        if name == model_id:
            return "selected"
        return "-"

    frame["selected"] = df_cot_err.model.apply(_flag)

    if plotly_mode == "dark":
        template = "plotly_dark"
    else:
        template = "plotly"

    fig = px.scatter(
        frame,
        x="base accuracy",
        y="marginal acc. gain",
        error_y="acc_gain-err",
        color="selected",
        symbol="model",
        hover_data=['model', "cot accuracy"],
        # One panel per task, three panels per row.
        facet_col="task",
        facet_col_wrap=3,
        # "selected" is ordered first so it takes the first colour (Orange).
        category_orders={"selected": ["selected", "-"]},
        color_discrete_sequence=["Orange", "Gray"],
        template=template,
        width=1200,
        height=700,
    )
    fig.update_layout(title=dict(automargin=True))

    return fig, model_id
|
|
|
def get_model_table(model_id):
    """Return a styled table of the CoT regimes recorded for ``model_id``.

    Rows of ``df_cot_regimes`` whose ``model`` equals ``model_id`` are kept,
    reduced to the regime/accuracy columns, sorted by task and CoT chain, and
    wrapped in a pandas ``Styler`` with colour gradients on the accuracy and
    gain columns.

    Args:
        model_id: Model name to filter ``df_cot_regimes.model`` by.

    Returns:
        A pandas ``Styler`` over the filtered frame.
    """

    def make_pretty(styler):
        # Apply all display styling in place and hand the styler back to pipe().
        styler.hide(axis="index")
        styler.format(precision=1)
        styler.background_gradient(
            axis=None,
            subset=["acc_base", "acc_cot"],
            vmin=20, vmax=100, cmap="YlGnBu"
        )
        styler.background_gradient(
            axis=None,
            subset=["acc_gain"],
            vmin=-20, vmax=20, cmap="coolwarm"
        )
        # NOTE(review): no column named 'B' is selected below, so that entry
        # presumably matches nothing -- confirm before removing it.
        styler.set_table_styles({
            'task': [{'selector': '',
                      'props': [('font-weight', 'bold')]}],
            'B': [{'selector': 'td',
                   'props': 'color: blue;'}]
        }, overwrite=False)
        return styler

    # The gradient above targets "acc_gain", while the column selection used
    # to ask for "delta_abs" and never renamed it, so rendering the styler
    # could not find its subset. Accept whichever name the frame carries and
    # expose it as "acc_gain".
    gain_col = "acc_gain" if "acc_gain" in df_cot_regimes.columns else "delta_abs"
    columns = ['task', 'cot_chain', 'best_of',
               'temperature', 'top_k', 'top_p', 'acc_base', 'acc_cot', gain_col]

    df_cot_model = (
        df_cot_regimes[df_cot_regimes.model.eq(model_id)][columns]
        .rename(columns={"temperature": "temp", gain_col: "acc_gain"})
        # Shorten the chain label only within the cot_chain column.
        .replace({'cot_chain': 'ReflectBeforeRun'}, "Reflect")
        .sort_values(["task", "cot_chain"])
        .reset_index(drop=True)
    )

    return df_cot_model.style.pipe(make_pretty)
|
|
|
def styled_model_table(model_id, request: gr.Request):
    """Return the styled regime table, honouring a ``model`` query parameter.

    Args:
        model_id: Model name chosen in the UI.
        request: Incoming Gradio request. If its query string has a ``model``
            value that occurs in ``df_cot_regimes.model``, that value is used
            instead of ``model_id``.

    Returns:
        Whatever ``get_model_table`` returns for the resolved model name.
    """
    # No request, or no ?model=... override: use the UI selection as-is.
    if not request or "model" not in request.query_params:
        return get_model_table(model_id)

    candidate = request.query_params["model"]
    known_models = df_cot_regimes.model.to_list()
    # Unknown names fall back to the UI selection.
    return get_model_table(candidate if candidate in known_models else model_id)
|
|
|
|
|
# Page layout and event wiring for the dashboard.
with gr.Blocks() as demo:
    gr.HTML(TITLE)
    gr.Markdown(INTRODUCTION_TEXT)

    # Controls: model picker, plot theme switch and the update button.
    with gr.Row():
        model_list = gr.Dropdown(
            list(df_cot_err.model.unique()),
            value="allenai/tulu-2-70b",
            label="Model",
            scale=2,
        )
        plotly_mode = gr.Radio(["dark", "light"], value="dark", label="Plot theme", scale=1)
        submit = gr.Button("Update", scale=1)

    # Outputs: the regime table first, then the scatter plot.
    table = gr.DataFrame()
    plot = gr.Plot(label="evals")

    # plot_evals also writes the resolved model name back into the dropdown.
    plot_inputs = [model_list, plotly_mode]
    plot_outputs = [plot, model_list]

    # The same two callbacks run on button click and on page load.
    submit.click(plot_evals, plot_inputs, plot_outputs)
    submit.click(styled_model_table, model_list, table)
    demo.load(plot_evals, plot_inputs, plot_outputs)
    demo.load(styled_model_table, model_list, table)

demo.launch()