File size: 3,247 Bytes
88faaa4
885c364
af674e3
 
66b916b
885c364
 
66b916b
 
 
88faaa4
66b916b
 
 
 
 
885c364
66b916b
 
 
 
 
 
 
 
 
 
97db185
66b916b
885c364
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88faaa4
885c364
 
 
 
 
66b916b
88faaa4
 
885c364
 
 
 
 
66b916b
 
88faaa4
66b916b
 
97db185
885c364
88faaa4
 
66b916b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import altair as alt
import gradio as gr
import pandas as pd

from functools import partial
from datasets import load_dataset

def get_data(dataset_id="ybelkada/model_cards_correct_tag"):
    """Load the model-card tag statistics and build the two frames used for plotting.

    Args:
        dataset_id: Hugging Face Hub dataset to load (``train`` split). Defaults to
            the dataset this app was built for.

    Returns:
        A ``(ratio_df, melted_df)`` tuple:

        * ``ratio_df`` has columns ``commit_dates`` and ``ratio``, where ``ratio`` is
          ``(1 - missing_library_name / total_transformers_model) * 100``, in
          ``commit_dates`` order. Rows whose ``total_transformers_model`` is 0 get a
          NaN ratio instead of ``inf``/``-inf``.
        * ``melted_df`` is the long-format frame with columns ``commit_dates``,
          ``type`` (``"total_transformers_model"`` or ``"missing_library_name"``)
          and ``value``. Within each ``type``, rows are in ``commit_dates`` order.
    """
    # `.to_pandas()` already returns a DataFrame; no extra wrapping is needed.
    df = load_dataset(dataset_id, split="train").to_pandas()
    df["commit_dates"] = pd.to_datetime(df["commit_dates"])
    df = df.sort_values(by="commit_dates")

    melted_df = pd.melt(
        df,
        id_vars=["commit_dates"],
        value_vars=["total_transformers_model", "missing_library_name"],
        var_name="type",
    )

    # Mask zero denominators to NaN. Otherwise the division produces inf/-inf,
    # which breaks the y-axis domain that make_plot derives from min()/max().
    # pandas' min()/max() skip NaN, so masked rows simply drop out of the range.
    total = df["total_transformers_model"].where(df["total_transformers_model"] != 0)
    df["ratio"] = (1 - df["missing_library_name"] / total) * 100
    ratio_df = df[["commit_dates", "ratio"]].copy()
    return ratio_df, melted_df


# Fetch once at import time so the initial page load can render immediately;
# make_plot(..., refresh=True) replaces these module-level frames later.
ratio_df, melted_df = get_data()

def _highlighted_line_chart(data, y_field, y_title, highlight_field, color=None):
    """Build a layered circle + line Altair chart with a mouseover selection.

    Both plot types in ``make_plot`` share this layout. They differ only in the
    data frame, the y column and title, the selection field, and whether the
    series are coloured.

    Args:
        data: DataFrame with a ``commit_dates`` column and a ``y_field`` column.
        y_field: Name of the quantitative column plotted on the y axis.
        y_title: Title for the y axis.
        highlight_field: Field the mouseover selection is bound to.
        color: Optional Altair color shorthand (e.g. ``'type:N'``). When None,
            no color encoding is added.

    Returns:
        The layered chart ``points + lines``.
    """
    highlight = alt.selection(type='single', on='mouseover',
                              fields=[highlight_field], nearest=True)

    encodings = {
        'x': alt.X('commit_dates:T', title='Date'),
        # Pin the y domain to the observed min/max so the chart zooms in on the
        # data rather than using Altair's default domain.
        'y': alt.Y(
            f'{y_field}:Q',
            scale=alt.Scale(domain=(data[y_field].min(), data[y_field].max())),
            title=y_title,
        ),
    }
    if color is not None:
        encodings['color'] = color
    base = alt.Chart(data).encode(**encodings)

    points = base.mark_circle().encode(
        opacity=alt.value(1),
    ).add_selection(
        highlight
    ).properties(
        width=1200,
        height=800,
    )

    # Size 1 for data outside the selection, 3 for data inside it.
    lines = base.mark_line().encode(
        size=alt.condition(~highlight, alt.value(1), alt.value(3))
    )

    return points + lines


def make_plot(plot_type, refresh=False):
    """Return the Altair chart for the selected plot type.

    Args:
        plot_type: Value of the "Plot type" radio.
            ``"Total models with missing 'transformers' tag"`` plots the raw
            counts from ``melted_df``. Any other value plots the ratio from
            ``ratio_df``.
        refresh: When True, call ``get_data()`` again and replace the
            module-level ``ratio_df`` / ``melted_df`` before plotting.

    Returns:
        A layered Altair chart (circle markers + lines).
    """
    global ratio_df, melted_df

    if refresh:
        ratio_df, melted_df = get_data()

    if plot_type == "Total models with missing 'transformers' tag":
        return _highlighted_line_chart(
            melted_df,
            'value',
            "Count",
            highlight_field='type',
            color='type:N',
        )

    return _highlighted_line_chart(
        ratio_df,
        'ratio',
        "(1 - missing_library_name / total_transformers_model) * 100 - Higher is better",
        highlight_field='ratio',
    )
    

with gr.Blocks() as demo:
    # The chart variants make_plot can draw; the first one is shown by default.
    plot_choices = [
        "Total models with missing 'transformers' tag",
        "Proportion of models correctly tagged with 'transformers' tag",
    ]

    # Components are rendered in creation order: selector, refresh button, chart.
    button = gr.Radio(
        choices=plot_choices,
        value=plot_choices[0],
        label="Plot type",
    )
    refresh_button = gr.Button(value="Fetch latest data")
    plot = gr.Plot(label="Plot")

    # Redraw when the selection changes. Re-fetch the dataset and redraw when
    # the refresh button is clicked. Draw once when the page loads.
    button.change(make_plot, inputs=[button], outputs=[plot])
    refresh_button.click(
        partial(make_plot, refresh=True),
        inputs=[button],
        outputs=[plot],
    )
    demo.load(make_plot, inputs=[button], outputs=[plot])

if __name__ == "__main__":
    demo.launch()