File size: 6,355 Bytes
be473e6
3573a39
 
 
5ad7125
2070be3
3573a39
58c39e0
 
be473e6
8c47a22
2861b85
2070be3
 
3573a39
be473e6
 
 
2861b85
be473e6
2861b85
be473e6
 
0607989
be473e6
 
 
2861b85
3573a39
 
5ad7125
be473e6
2861b85
be473e6
 
3573a39
be473e6
 
4a85196
be473e6
 
3573a39
be473e6
 
3573a39
be473e6
4a85196
be473e6
 
3573a39
be473e6
8f809e2
be473e6
3573a39
 
 
be473e6
 
3573a39
be473e6
 
 
 
3573a39
 
6565530
947816c
be473e6
3573a39
 
6565530
947816c
be473e6
3573a39
 
6565530
3573a39
be473e6
 
7487fdb
2070be3
 
2861b85
ed3fe33
8c47a22
be473e6
 
 
 
 
666860b
 
7055d8b
3573a39
be473e6
3573a39
be473e6
666860b
 
 
 
 
 
 
8f114e2
 
 
 
 
 
 
666860b
 
be473e6
3573a39
 
8f114e2
3573a39
 
 
 
4a85196
3573a39
 
58c39e0
 
 
 
3573a39
 
be473e6
 
3573a39
666860b
2070be3
 
 
 
5ad7125
 
666860b
5ad7125
 
 
666860b
5ad7125
7487fdb
3573a39
 
 
 
666860b
 
3573a39
 
666860b
3573a39
 
666860b
5ad7125
8c47a22
be473e6
3573a39
be473e6
4a85196
58c39e0
4a85196
58c39e0
be473e6
666860b
 
76d3665
be473e6
 
2861b85
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
import logging

import datasets
import gradio as gr
import pandas as pd
import datetime

from fetch_utils import (check_dataset_and_get_config,
                         check_dataset_and_get_split)

import leaderboard
logger = logging.getLogger(__name__)
# Timestamp of the last leaderboard records load. Initialised to the epoch so
# any "loaded recently?" check against it sees the records as stale. It is
# reassigned inside functions via ``global update_time``; a ``global``
# statement at module scope is a no-op, so none is used here.
update_time = datetime.datetime.fromtimestamp(0)

def get_records_from_dataset_repo(dataset_id):
    """Load the first config / first split of a Hub dataset as a DataFrame.

    Args:
        dataset_id: Hugging Face Hub dataset repository id.

    Returns:
        The dataset rows as a ``pd.DataFrame``. An empty ``pd.DataFrame()`` is
        returned when no config or split could be resolved, or when loading
        the dataset fails.
    """
    dataset_config = check_dataset_and_get_config(dataset_id)
    logger.info(f"Dataset {dataset_id} has configs {dataset_config}")
    # NOTE(review): the fetch_utils helpers appear to be able to return None or
    # an empty list on failure; indexing [0] on that would raise outside the
    # try below, so fall back to an empty frame instead.
    if not dataset_config:
        logger.warning(f"Dataset {dataset_id} has no usable config: {dataset_config}")
        return pd.DataFrame()
    config = dataset_config[0]

    dataset_split = check_dataset_and_get_split(dataset_id, config)
    logger.info(f"Dataset {dataset_id} has splits {dataset_split}")
    if not dataset_split:
        logger.warning(
            f"Dataset {dataset_id} with config {config} has no usable split: {dataset_split}"
        )
        return pd.DataFrame()
    split = dataset_split[0]

    try:
        ds = datasets.load_dataset(dataset_id, config, split=split)
        return ds.to_pandas()
    except Exception as e:
        # Best-effort load: the UI keeps running with an empty table.
        logger.warning(
            f"Failed to load dataset {dataset_id} with config {dataset_config}: {e}"
        )
        return pd.DataFrame()

    
def get_model_ids(ds):
    """Return the choices for the "Model id" dropdown.

    Args:
        ds: Leaderboard records; the ``model_id`` column is read.

    Returns:
        ``["Any", *model_ids]``: ``"Any"`` first, followed by the distinct,
        non-missing model ids sorted (by their string form) so the dropdown
        order is stable across reloads. Just ``["Any"]`` if the records have
        no ``model_id`` column (e.g. the records failed to load).
    """
    if "model_id" not in ds:
        logger.warning("Leaderboard records have no 'model_id' column")
        return ["Any"]
    model_ids = sorted(ds["model_id"].dropna().drop_duplicates().tolist(), key=str)
    logger.info(f"Found {len(model_ids)} unique model ids")
    return ["Any", *model_ids]


def get_dataset_ids(ds):
    """Return the choices for the "Dataset id" dropdown.

    Args:
        ds: Leaderboard records; the ``dataset_id`` column is read.

    Returns:
        ``["Any", *dataset_ids]``: ``"Any"`` first, followed by the distinct,
        non-missing dataset ids sorted (by their string form) so the dropdown
        order is stable across reloads. Just ``["Any"]`` if the records have
        no ``dataset_id`` column (e.g. the records failed to load).
    """
    if "dataset_id" not in ds:
        logger.warning("Leaderboard records have no 'dataset_id' column")
        return ["Any"]
    # Named ``dataset_ids`` (not ``datasets``) to avoid shadowing the module.
    dataset_ids = sorted(ds["dataset_id"].dropna().drop_duplicates().tolist(), key=str)
    logger.info(f"Found {len(dataset_ids)} unique dataset ids")
    return ["Any", *dataset_ids]


def get_types(ds):
    """Map each column's pandas dtype to a Gradio DataFrame ``datatype`` string.

    Args:
        ds: DataFrame whose columns will be displayed.

    Returns:
        One string per column, in column order:
        ``"bool"`` for boolean dtypes, ``"number"`` for any numeric dtype
        (signed/unsigned ints and floats of any width, nullable or not),
        ``"markdown"`` for object/string columns (so the HTML links built for
        display render), and ``str(dtype)`` unchanged for anything else.
    """
    api = pd.api.types
    types = []
    for dtype in ds.dtypes:
        # Test bool before numeric: pandas classifies bool as numeric.
        if api.is_bool_dtype(dtype):
            types.append("bool")
        elif api.is_numeric_dtype(dtype):
            types.append("number")
        elif api.is_object_dtype(dtype) or api.is_string_dtype(dtype):
            types.append("markdown")
        else:
            types.append(str(dtype))
    return types


def get_display_df(df):
    """Return a copy of *df* with id and link columns rendered as HTML anchors.

    ``model_id`` and ``dataset_id`` values become links to their Hugging Face
    Hub pages; ``report_link`` values become links to themselves. Each cell is
    converted with ``str`` and HTML-escaped before being embedded. Columns that
    are absent are skipped, and the input DataFrame is not modified.

    Args:
        df: Leaderboard records (or a column subset of them).

    Returns:
        A styled copy of *df*, intended for a column typed ``"markdown"``.
    """
    import html

    def _link(href, text):
        # The values come from the leaderboard dataset and end up in rendered
        # HTML, so escape both the attribute and the visible text.
        return (
            f'<a href="{html.escape(href)}" target="_blank" style="color:blue">'
            f"\N{LINK SYMBOL}{html.escape(text)}</a>"
        )

    link_builders = {
        "model_id": lambda x: _link(f"https://huggingface.co./{x}", str(x)),
        "dataset_id": lambda x: _link(f"https://huggingface.co./datasets/{x}", str(x)),
        "report_link": lambda x: _link(str(x), str(x)),
    }

    display_df = df.copy()
    for column, build in link_builders.items():
        if column in display_df.columns:
            display_df[column] = display_df[column].apply(build)
    return display_df

def get_demo(leaderboard_tab):
    """Build the leaderboard UI inside the current Gradio Blocks context.

    Loads the leaderboard records into ``leaderboard.records`` and creates:
    - column selectors for info and issue columns;
    - task, model and dataset filters;
    - the results table.

    It then wires the events that refresh and filter that table.

    Args:
        leaderboard_tab: Gradio component whose ``select`` event triggers a
            reload of the records (at most once every 10 minutes).
    """
    global update_time
    update_time = datetime.datetime.now()
    logger.info("Loading leaderboard records")
    leaderboard.records = get_records_from_dataset_repo(leaderboard.LEADERBOARD)
    records = leaderboard.records

    model_ids = get_model_ids(records)
    dataset_ids = get_dataset_ids(records)

    column_names = records.columns.tolist()
    # NOTE(review): positional split of the schema - the first 11 columns are
    # offered as issue columns and columns from index 15 on as info columns
    # (indices 11-14 are offered in neither). Confirm against the dataset.
    issue_columns = column_names[:11]
    info_columns = column_names[15:]
    default_columns = ["model_id", "dataset_id", "total_issues", "report_link"]
    default_df = records[default_columns]  # extract columns selected
    types = get_types(default_df)
    display_df = get_display_df(default_df)  # the styled dataframe to display

    with gr.Row():
        with gr.Column():
            info_columns_select = gr.CheckboxGroup(
                label="Info Columns",
                choices=info_columns,
                value=default_columns,
                interactive=True,
            )
        with gr.Column():
            issue_columns_select = gr.CheckboxGroup(
                label="Issue Columns",
                choices=issue_columns,
                value=[],
                interactive=True,
            )

    with gr.Row():
        task_select = gr.Dropdown(
            label="Task",
            choices=["text_classification"],
            value="text_classification",
            interactive=True,
        )
        model_select = gr.Dropdown(
            label="Model id", choices=model_ids, value=model_ids[0], interactive=True
        )
        dataset_select = gr.Dropdown(
            label="Dataset id",
            choices=dataset_ids,
            value=dataset_ids[0],
            interactive=True,
        )

    with gr.Row():
        leaderboard_df = gr.DataFrame(display_df, datatype=types, interactive=False)

    def update_leaderboard_records(model_id, dataset_id, issue_columns, info_columns, task):
        """Reload the records when the tab is selected, rate-limited to 10 minutes."""
        global update_time
        if datetime.datetime.now() - update_time < datetime.timedelta(minutes=10):
            # Loaded recently: leave the table as it is.
            return gr.update()
        update_time = datetime.datetime.now()
        logger.info("Updating leaderboard records")
        leaderboard.records = get_records_from_dataset_repo(leaderboard.LEADERBOARD)
        # filter_table is defined below; it is resolved when this handler runs.
        return filter_table(model_id, dataset_id, issue_columns, info_columns, task)

    leaderboard_tab.select(
        fn=update_leaderboard_records,
        inputs=[model_select, dataset_select, issue_columns_select, info_columns_select, task_select],
        outputs=[leaderboard_df])

    @gr.on(
        triggers=[
            model_select.change,
            dataset_select.change,
            issue_columns_select.change,
            info_columns_select.change,
            task_select.change,
        ],
        inputs=[model_select, dataset_select, issue_columns_select, info_columns_select, task_select],
        outputs=[leaderboard_df],
    )
    def filter_table(model_id, dataset_id, issue_columns, info_columns, task):
        """Filter ``leaderboard.records`` by task/model/dataset and select columns.

        ``"Any"`` (or an empty value) disables the model / dataset filter.
        Returns a ``gr.update`` carrying the styled table and its datatypes.
        """
        logger.info("Filtering leaderboard records")
        records = leaderboard.records
        # filter the table based on task
        df = records[(records["task"] == task)]
        # filter the table based on the model_id and dataset_id
        if model_id and model_id != "Any":
            df = df[(df["model_id"] == model_id)]
        if dataset_id and dataset_id != "Any":
            df = df[(df["dataset_id"] == dataset_id)]

        # Info columns first, then issue columns alphabetically. sorted()
        # leaves the caller's list untouched. dict.fromkeys drops a column
        # ticked in both selectors (keeping its first position); the
        # membership test skips columns missing from reloaded records instead
        # of raising KeyError.
        selected = list(dict.fromkeys(info_columns + sorted(issue_columns)))
        df = df[[column for column in selected if column in df.columns]]
        types = get_types(df)
        display_df = get_display_df(df)
        return gr.update(value=display_df, datatype=types, interactive=False)