""" Main module for the WhisperKit Evaluation Dashboard. This module sets up and runs the Gradio interface for the WhisperKit Evaluation Dashboard, allowing users to explore and compare speech recognition model performance across different devices, operating systems, and datasets. """ import os from math import ceil, floor import re import gradio as gr import pandas as pd from argmax_gradio_components import RangeSlider from dotenv import load_dotenv from huggingface_hub import login # Import custom constants and utility functions from constants import ( BANNER_TEXT, CITATION_BUTTON_LABEL, CITATION_BUTTON_TEXT, COL_NAMES, HEADER, METHODOLOGY_TEXT, PERFORMANCE_TEXT, QUALITY_TEXT, ) from utils import ( add_datasets_to_performance_columns, add_datasets_to_quality_columns, create_initial_performance_column_dict, create_initial_quality_column_dict, css, fields, get_os_name_and_version, make_dataset_wer_clickable_link, make_model_name_clickable_link, plot_metric, read_json_line_by_line, ) # Load environment variables load_dotenv() # Get the Hugging Face token from the environment variable HF_TOKEN = os.getenv("HF_TOKEN") # Use the token for login login(token=HF_TOKEN, add_to_git_credential=True) # Define repository and directory information repo_id = "argmaxinc/whisperkit-evals-dataset" directory = "xcresults/benchmark_results" local_dir = "" # Load benchmark data from JSON files PERFORMANCE_DATA = read_json_line_by_line("dashboard_data/performance_data.json") QUALITY_DATA = read_json_line_by_line("dashboard_data/quality_data.json") # Convert JSON data to pandas DataFrames quality_df = pd.json_normalize(QUALITY_DATA) benchmark_df = pd.json_normalize(PERFORMANCE_DATA) # Process timestamp data def safe_parse_datetime(x): try: return pd.to_datetime(x, format="%Y-%m-%d-%H-%M-%S-%p") except (ValueError, TypeError): return pd.NaT # Return Not-a-Time for invalid dates benchmark_df["timestamp"] = benchmark_df["timestamp"].apply(safe_parse_datetime).dt.tz_localize(None) quality_df["timestamp"] = quality_df["timestamp"].apply(safe_parse_datetime).dt.tz_localize(None) # First create a temporary column for model length sorted_quality_df = ( quality_df.assign(model_len=quality_df["model"].str.len()) .sort_values( by=["model_len", "model", "timestamp"], ascending=[True, True, False], ) .drop(columns=["model_len"]) .drop_duplicates(subset=["model"], keep="first") .reset_index(drop=True) ) sorted_performance_df = ( benchmark_df.assign(model_len=benchmark_df["model"].str.len()) .sort_values( by=["model_len", "model", "device", "os", "timestamp"], ascending=[True, True, True, True, False], ) .drop(columns=["model_len"]) .drop_duplicates(subset=["model", "device", "os"], keep="first") .reset_index(drop=True) ) # Identify dataset-specific columns dataset_wer_columns = [ col for col in sorted_quality_df.columns if col.startswith("dataset_wer.") ] dataset_speed_columns = [ col for col in sorted_performance_df.columns if col.startswith("dataset_speed.") ] dataset_toks_columns = [ col for col in sorted_performance_df.columns if col.startswith("dataset_tokens_per_second.") ] # Extract dataset names QUALITY_DATASETS = [col.split(".")[-1] for col in dataset_wer_columns] PERFORMANCE_DATASETS = [col.split(".")[-1] for col in dataset_speed_columns] # Prepare DataFrames for display model_df = sorted_quality_df[ ["model", "average_wer", "qoi", "timestamp"] + dataset_wer_columns ] performance_df = sorted_performance_df[ [ "model", "device", "os", "average_wer", "qoi", "speed", "tokens_per_second", "timestamp", ] + dataset_speed_columns + dataset_toks_columns ].copy() # Rename columns for clarity performance_df = performance_df.rename( lambda x: COL_NAMES[x] if x in COL_NAMES else x, axis="columns" ) model_df = model_df.rename( lambda x: COL_NAMES[x] if x in COL_NAMES else x, axis="columns" ) # Process dataset-specific columns for col in dataset_wer_columns: dataset_name = col.split(".")[-1] model_df = model_df.rename(columns={col: dataset_name}) # model_df[dataset_name] = model_df.apply( # lambda x: make_dataset_wer_clickable_link(x, dataset_name), axis=1 # ) for col in dataset_speed_columns: dataset_name = col.split(".")[-1] performance_df = performance_df.rename( columns={ col: f"{'Short-Form' if dataset_name == 'librispeech-10mins' else 'Long-Form'} Speed" } ) for col in dataset_toks_columns: dataset_name = col.split(".")[-1] performance_df = performance_df.rename( columns={ col: f"{'Short-Form' if dataset_name == 'librispeech-10mins' else 'Long-Form'} Tok/s" } ) # Process model names for display model_df["model_raw"] = model_df["Model"].copy() performance_df["model_raw"] = performance_df["Model"].copy() model_df["Model"] = model_df["Model"].apply(lambda x: make_model_name_clickable_link(x)) performance_df["Model"] = performance_df["Model"].apply( lambda x: make_model_name_clickable_link(x) ) performance_df["Average WER"] = performance_df["Average WER"].apply( lambda x: x if x < 90 else f"""

{x}

""" ) # Extract unique devices and OS versions PERFORMANCE_DEVICES = performance_df["Device"].unique().tolist() PERFORMANCE_OS = performance_df["OS"].apply(get_os_name_and_version).unique().tolist() PERFORMANCE_OS.sort() # Create initial column dictionaries and update with dataset information initial_performance_column_dict = create_initial_performance_column_dict() initial_quality_column_dict = create_initial_quality_column_dict() performance_column_info = add_datasets_to_performance_columns( initial_performance_column_dict, PERFORMANCE_DATASETS ) quality_column_info = add_datasets_to_quality_columns( initial_quality_column_dict, QUALITY_DATASETS ) # Unpack the returned dictionaries updated_performance_column_dict = performance_column_info["column_dict"] updated_quality_column_dict = quality_column_info["column_dict"] PerformanceAutoEvalColumn = performance_column_info["AutoEvalColumn"] QualityAutoEvalColumn = quality_column_info["AutoEvalColumn"] # Define column sets for different views PERFORMANCE_COLS = performance_column_info["COLS"] QUALITY_COLS = quality_column_info["COLS"] PERFORMANCE_TYPES = performance_column_info["TYPES"] QUALITY_TYPES = quality_column_info["TYPES"] PERFORMANCE_ALWAYS_HERE_COLS = performance_column_info["ALWAYS_HERE_COLS"] QUALITY_ALWAYS_HERE_COLS = quality_column_info["ALWAYS_HERE_COLS"] PERFORMANCE_TOGGLE_COLS = performance_column_info["TOGGLE_COLS"] QUALITY_TOGGLE_COLS = quality_column_info["TOGGLE_COLS"] PERFORMANCE_SELECTED_COLS = performance_column_info["SELECTED_COLS"] QUALITY_SELECTED_COLS = quality_column_info["SELECTED_COLS"] def performance_filter( df, columns, model_query, exclude_models, devices, os, short_speed_slider, long_speed_slider, short_toks_slider, long_toks_slider, ): """ Filters the performance DataFrame based on specified criteria. :param df: The DataFrame to be filtered. :param columns: The columns to be included in the filtered DataFrame. :param model_query: The query string to filter the 'Model' column. :param exclude_models: Models to exclude from the results. :param devices: The devices to filter the 'Device' column. :param os: The list of operating systems to filter the 'OS' column. :param short_speed_slider: The range of values to filter the 'Short-Form Speed' column. :param long_speed_slider: The range of values to filter the 'Long-Form Speed' column. :param short_toks_slider: The range of values to filter the 'Short-Form Tok/s' column. :param long_toks_slider: The range of values to filter the 'Long-Form Tok/s' column. :return: The filtered DataFrame. """ # Select columns based on input and always-present columns filtered_df = df[ PERFORMANCE_ALWAYS_HERE_COLS + [c for c in PERFORMANCE_COLS if c in df.columns and c in columns] ] # Filter models based on query if model_query: filtered_df = filtered_df[ filtered_df["Model"].str.contains( "|".join(q.strip() for q in model_query.split(";")), case=False ) ] # Exclude specified models if exclude_models: exclude_list = [m.strip() for m in exclude_models.split(";")] filtered_df = filtered_df[ ~filtered_df["Model"].str.contains("|".join(exclude_list), case=False) ] # Filter by devices if devices: filtered_df = filtered_df[filtered_df["Device"].isin(devices)] else: filtered_df = pd.DataFrame(columns=filtered_df.columns) # Filter by operating systems filtered_df = ( filtered_df[ ( filtered_df["OS"].str.contains( "|".join(q.strip() for q in os), case=False ) ) ] if os else pd.DataFrame(columns=filtered_df.columns) ) # Apply short-form and long-form speed and tokens per second filters min_short_speed, max_short_speed = short_speed_slider min_long_speed, max_long_speed = long_speed_slider min_short_toks, max_short_toks = short_toks_slider min_long_toks, max_long_toks = long_toks_slider if "Short-Form Speed" in filtered_df.columns: filtered_df = filtered_df[ ((filtered_df["Short-Form Speed"] >= min_short_speed) & (filtered_df["Short-Form Speed"] <= max_short_speed)) | filtered_df["Short-Form Speed"].isna() ] if "Long-Form Speed" in filtered_df.columns: filtered_df = filtered_df[ ((filtered_df["Long-Form Speed"] >= min_long_speed) & (filtered_df["Long-Form Speed"] <= max_long_speed)) | filtered_df["Long-Form Speed"].isna() ] if "Short-Form Tok/s" in filtered_df.columns: filtered_df = filtered_df[ ((filtered_df["Short-Form Tok/s"] >= min_short_toks) & (filtered_df["Short-Form Tok/s"] <= max_short_toks)) | filtered_df["Short-Form Tok/s"].isna() ] if "Long-Form Tok/s" in filtered_df.columns: filtered_df = filtered_df[ ((filtered_df["Long-Form Tok/s"] >= min_long_toks) & (filtered_df["Long-Form Tok/s"] <= max_long_toks)) | filtered_df["Long-Form Tok/s"].isna() ] return filtered_df def quality_filter(df, columns, model_query, wer_slider, qoi_slider, exclude_models): """ Filters the quality DataFrame based on specified criteria. :param df: The DataFrame to be filtered. :param columns: The columns to be included in the filtered DataFrame. :param model_query: The query string to filter the 'Model' column. :param wer_slider: The range of values to filter the 'Average WER' column. :param qoi_slider: The range of values to filter the 'QoI' column. :param exclude_models: Models to exclude from the results. :return: The filtered DataFrame. """ # Select columns based on input and always-present columns filtered_df = df[ QUALITY_ALWAYS_HERE_COLS + [c for c in QUALITY_COLS if c in df.columns and c in columns] ] # Filter models based on query if model_query: filtered_df = filtered_df[ filtered_df["Model"].str.contains( "|".join(q.strip() for q in model_query.split(";")), case=False ) ] # Exclude specified models if exclude_models: exclude_list = [m.strip() for m in exclude_models.split(";")] filtered_df = filtered_df[ ~filtered_df["Model"].str.contains("|".join(exclude_list), case=False) ] # Apply WER and QoI filters min_wer_slider, max_wer_slider = wer_slider min_qoi_slider, max_qoi_slider = qoi_slider if "Average WER" in filtered_df.columns: filtered_df = filtered_df[ (filtered_df["Average WER"] >= min_wer_slider) & (filtered_df["Average WER"] <= max_wer_slider) ] if "QoI" in filtered_df.columns: filtered_df = filtered_df[ (filtered_df["QoI"] >= min_qoi_slider) & (filtered_df["QoI"] <= max_qoi_slider) ] return filtered_df diff_tab = gr.TabItem("Difference Checker", elem_id="diff_checker", id=2) text_diff_elems = [] tabs = gr.Tabs(elem_id="tab-elems") font = [ "Zwizz Regular", # Local font "IBM Plex Mono", # Monospace font "ui-sans-serif", "system-ui", "sans-serif", ] # Define the Gradio interface with gr.Blocks(css=css, theme=gr.themes.Base(font=font)) as demo: # Add header and banner to the interface gr.HTML(HEADER) gr.HTML(BANNER_TEXT, elem_classes="markdown-text") # Create tabs for different sections of the dashboard with tabs.render(): # Performance Tab with gr.TabItem("Performance", elem_id="benchmark", id=0): with gr.Row(): with gr.Column(scale=1): with gr.Row(): with gr.Column(scale=6, elem_classes="filter_models_column"): filter_performance_models = gr.Textbox( placeholder="🔍 Filter Model (separate multiple queries with ';')", label="Filter Models", ) with gr.Column(scale=4, elem_classes="exclude_models_column"): exclude_performance_models = gr.Textbox( placeholder="🔍 Exclude Model", label="Exclude Model", ) with gr.Row(): with gr.Accordion("See All Columns", open=False): with gr.Row(): with gr.Column(scale=9, elem_id="performance_columns"): performance_shown_columns = gr.CheckboxGroup( choices=PERFORMANCE_TOGGLE_COLS, value=PERFORMANCE_SELECTED_COLS, label="Toggle Columns", elem_id="column-select", interactive=True, ) with gr.Column( scale=1, min_width=200, elem_id="performance_select_columns", ): with gr.Row(): select_all_button = gr.Button( "Select All", elem_id="select-all-button", interactive=True, ) deselect_all_button = gr.Button( "Deselect All", elem_id="deselect-all-button", interactive=True, ) def select_all_columns(): return PERFORMANCE_TOGGLE_COLS def deselect_all_columns(): return [] select_all_button.click( select_all_columns, inputs=[], outputs=performance_shown_columns, ) deselect_all_button.click( deselect_all_columns, inputs=[], outputs=performance_shown_columns, ) with gr.Row(): with gr.Accordion("Filter Devices", open=False): with gr.Row(): with gr.Column( scale=9, elem_id="filter_devices_column" ): performance_shown_devices = gr.CheckboxGroup( choices=PERFORMANCE_DEVICES, value=PERFORMANCE_DEVICES, label="Filter Devices", interactive=True, ) with gr.Column( scale=1, min_width=200, elem_id="filter_select_devices", ): with gr.Row(): select_all_devices_button = gr.Button( "Select All", elem_id="select-all-devices-button", interactive=True, ) deselect_all_devices_button = gr.Button( "Deselect All", elem_id="deselect-all-devices-button", interactive=True, ) def select_all_devices(): return PERFORMANCE_DEVICES def deselect_all_devices(): return [] select_all_devices_button.click( select_all_devices, inputs=[], outputs=performance_shown_devices, ) deselect_all_devices_button.click( deselect_all_devices, inputs=[], outputs=performance_shown_devices, ) with gr.Row(): performance_shown_os = gr.CheckboxGroup( choices=PERFORMANCE_OS, value=PERFORMANCE_OS, label="Filter OS", interactive=True, ) with gr.Column(scale=1): with gr.Accordion("See Performance Filters"): with gr.Row(): with gr.Row(): min_short_speed, max_short_speed = floor( min(performance_df["Short-Form Speed"]) ), ceil(max(performance_df["Short-Form Speed"])) short_speed_slider = RangeSlider( value=[min_short_speed, max_short_speed], minimum=min_short_speed, maximum=max_short_speed, step=0.001, label="Short-Form Speed", ) with gr.Row(): min_long_speed, max_long_speed = floor( performance_df["Long-Form Speed"].dropna().min() ), ceil(performance_df["Long-Form Speed"].dropna().max()) long_speed_slider = RangeSlider( value=[min_long_speed, max_long_speed], minimum=min_long_speed, maximum=max_long_speed, step=0.001, label="Long-Form Speed", ) with gr.Row(): with gr.Row(): min_short_toks, max_short_toks = floor( min(performance_df["Short-Form Tok/s"]) ), ceil(max(performance_df["Short-Form Tok/s"])) short_toks_slider = RangeSlider( value=[min_short_toks, max_short_toks], minimum=min_short_toks, maximum=max_short_toks, step=0.001, label="Short-Form Tok/s", ) with gr.Row(): min_long_toks, max_long_toks = floor( performance_df["Long-Form Tok/s"].dropna().min() ), ceil(performance_df["Long-Form Tok/s"].dropna().max()) long_toks_slider = RangeSlider( value=[min_long_toks, max_long_toks], minimum=min_long_toks, maximum=max_long_toks, step=0.001, label="Long-Form Tok/s", ) with gr.Row(): gr.Markdown(PERFORMANCE_TEXT, elem_classes="markdown-text") with gr.Row(): leaderboard_df = gr.components.Dataframe( value=performance_df[ PERFORMANCE_ALWAYS_HERE_COLS + performance_shown_columns.value ], headers=[ PERFORMANCE_ALWAYS_HERE_COLS + performance_shown_columns.value ], datatype=[ c.type for c in fields(PerformanceAutoEvalColumn) if c.name in PERFORMANCE_COLS ], elem_id="leaderboard-table", elem_classes="large-table", interactive=False, ) # Copy of the leaderboard dataframe to apply filters to hidden_leaderboard_df = gr.components.Dataframe( value=performance_df, headers=PERFORMANCE_COLS, datatype=[ c.type for c in fields(PerformanceAutoEvalColumn) if c.name in PERFORMANCE_COLS ], visible=False, ) # Inputs for the dataframe filter function performance_filter_inputs = [ hidden_leaderboard_df, performance_shown_columns, filter_performance_models, exclude_performance_models, performance_shown_devices, performance_shown_os, short_speed_slider, long_speed_slider, short_toks_slider, long_toks_slider, ] filter_output = leaderboard_df filter_performance_models.change( performance_filter, performance_filter_inputs, filter_output ) exclude_performance_models.change( performance_filter, performance_filter_inputs, filter_output ) performance_shown_columns.change( performance_filter, performance_filter_inputs, filter_output ) performance_shown_devices.change( performance_filter, performance_filter_inputs, filter_output ) performance_shown_os.change( performance_filter, performance_filter_inputs, filter_output ) short_speed_slider.change( performance_filter, performance_filter_inputs, filter_output ) long_speed_slider.change( performance_filter, performance_filter_inputs, filter_output ) short_toks_slider.change( performance_filter, performance_filter_inputs, filter_output ) long_toks_slider.change( performance_filter, performance_filter_inputs, filter_output ) with gr.TabItem("English Quality", elem_id="timeline", id=1): with gr.Row(): with gr.Column(scale=1): with gr.Row(): with gr.Column(scale=6, elem_classes="filter_models_column"): filter_quality_models = gr.Textbox( placeholder="🔍 Filter Model (separate multiple queries with ';')", label="Filter Models", ) with gr.Column(scale=4, elem_classes="exclude_models_column"): exclude_quality_models = gr.Textbox( placeholder="🔍 Exclude Model", label="Exclude Model", ) with gr.Row(): with gr.Accordion("See All Columns", open=False): quality_shown_columns = gr.CheckboxGroup( choices=QUALITY_TOGGLE_COLS, value=QUALITY_SELECTED_COLS, label="Toggle Columns", elem_id="column-select", interactive=True, ) with gr.Column(scale=1): with gr.Accordion("See Quality Filters"): with gr.Row(): with gr.Row(): quality_min_avg_wer, quality_max_avg_wer = ( floor(min(model_df["Average WER"])), ceil(max(model_df["Average WER"])) + 1, ) wer_slider = RangeSlider( value=[quality_min_avg_wer, quality_max_avg_wer], minimum=quality_min_avg_wer, maximum=quality_max_avg_wer, label="Average WER", ) with gr.Row(): quality_min_qoi, quality_max_qoi = floor( min(model_df["QoI"]) ), ceil(max(model_df["QoI"] + 1)) qoi_slider = RangeSlider( value=[quality_min_qoi, quality_max_qoi], minimum=quality_min_qoi, maximum=quality_max_qoi, label="QoI", ) with gr.Row(): gr.Markdown(QUALITY_TEXT) with gr.Row(): quality_leaderboard_df = gr.components.Dataframe( value=model_df[ QUALITY_ALWAYS_HERE_COLS + quality_shown_columns.value ], headers=[QUALITY_ALWAYS_HERE_COLS + quality_shown_columns.value], datatype=[ c.type for c in fields(QualityAutoEvalColumn) if c.name in QUALITY_COLS ], elem_id="leaderboard-table", elem_classes="large-table", interactive=False, ) # Copy of the leaderboard dataframe to apply filters to hidden_quality_leaderboard_df = gr.components.Dataframe( value=model_df, headers=QUALITY_COLS, datatype=[ c.type for c in fields(QualityAutoEvalColumn) if c.name in QUALITY_COLS ], visible=False, ) # Inputs for the dataframe filter function filter_inputs = [ hidden_quality_leaderboard_df, quality_shown_columns, filter_quality_models, wer_slider, qoi_slider, exclude_quality_models, ] filter_output = quality_leaderboard_df filter_quality_models.change( quality_filter, filter_inputs, filter_output ) exclude_quality_models.change( quality_filter, filter_inputs, filter_output ) quality_shown_columns.change( quality_filter, filter_inputs, filter_output ) wer_slider.change(quality_filter, filter_inputs, filter_output) qoi_slider.change(quality_filter, filter_inputs, filter_output) # Timeline Tab with gr.TabItem("Timeline", elem_id="timeline", id=4): # Create subtabs for different metrics with gr.Tabs(): with gr.TabItem("QoI", id=0): with gr.Row(): with gr.Column(scale=6): filter_qoi = gr.Textbox( placeholder="🔍 Filter Model-Device-OS (separate multiple queries with ';')", label="Filter", ) with gr.Column(scale=4): exclude_qoi = gr.Textbox( placeholder="🔍 Exclude Model-Device-OS", label="Exclude", ) with gr.Row(): with gr.Column(): qoi_plot = gr.Plot(container=True) demo.load( lambda x, y, z: plot_metric( x, "qoi", "QoI", "QoI Over Time for Model-Device-OS Combinations", y, z, ), [ gr.Dataframe(benchmark_df, visible=False), filter_qoi, exclude_qoi, ], qoi_plot, ) filter_qoi.change( lambda x, y, z: plot_metric( x, "qoi", "QoI", "QoI Over Time for Model-Device-OS Combinations", y, z, ), [ gr.Dataframe(benchmark_df, visible=False), filter_qoi, exclude_qoi, ], qoi_plot, ) exclude_qoi.change( lambda x, y, z: plot_metric( x, "qoi", "QoI", "QoI Over Time for Model-Device-OS Combinations", y, z, ), [ gr.Dataframe(benchmark_df, visible=False), filter_qoi, exclude_qoi, ], qoi_plot, ) with gr.TabItem("Average WER", id=1): with gr.Row(): with gr.Column(scale=6): filter_average_wer = gr.Textbox( placeholder="🔍 Filter Model-Device-OS (separate multiple queries with ';')", label="Filter", ) with gr.Column(scale=4): exclude_average_wer = gr.Textbox( placeholder="🔍 Exclude Model-Device-OS", label="Exclude", ) with gr.Row(): with gr.Column(): average_wer_plot = gr.Plot(container=True) demo.load( lambda x, y, z: plot_metric( x, "average_wer", "Average WER", "Average WER Over Time for Model-Device-OS Combinations", y, z, ), [ gr.Dataframe(benchmark_df, visible=False), filter_average_wer, exclude_average_wer, ], average_wer_plot, ) filter_average_wer.change( lambda x, y, z: plot_metric( x, "average_wer", "Average WER", "Average WER Over Time for Model-Device-OS Combinations", y, z, ), [ gr.Dataframe(benchmark_df, visible=False), filter_average_wer, exclude_average_wer, ], average_wer_plot, ) exclude_average_wer.change( lambda x, y, z: plot_metric( x, "average_wer", "Average WER", "Average WER Over Time for Model-Device-OS Combinations", y, z, ), [ gr.Dataframe(benchmark_df, visible=False), filter_average_wer, exclude_average_wer, ], average_wer_plot, ) with gr.TabItem("Speed", id=2): with gr.Row(): with gr.Column(scale=6): filter_speed = gr.Textbox( placeholder="🔍 Filter Model-Device-OS (separate multiple queries with ';')", label="Filter", ) with gr.Column(scale=4): exclude_speed = gr.Textbox( placeholder="🔍 Exclude Model-Device-OS", label="Exclude", ) with gr.Row(): with gr.Column(): speed_plot = gr.Plot(container=True) demo.load( lambda x, y, z: plot_metric( x, "speed", "Speed", "Speed Over Time for Model-Device-OS Combinations", y, z, ), [ gr.Dataframe(benchmark_df, visible=False), filter_speed, exclude_speed, ], speed_plot, ) filter_speed.change( lambda x, y, z: plot_metric( x, "speed", "Speed", "Speed Over Time for Model-Device-OS Combinations", y, z, ), [ gr.Dataframe(benchmark_df, visible=False), filter_speed, exclude_speed, ], speed_plot, ) exclude_speed.change( lambda x, y, z: plot_metric( x, "speed", "Speed", "Speed Over Time for Model-Device-OS Combinations", y, z, ), [ gr.Dataframe(benchmark_df, visible=False), filter_speed, exclude_speed, ], speed_plot, ) with gr.TabItem("Tok/s", id=3): with gr.Row(): with gr.Column(scale=6): filter_toks = gr.Textbox( placeholder="🔍 Filter Model-Device-OS (separate multiple queries with ';')", label="Filter", ) with gr.Column(scale=4): exclude_toks = gr.Textbox( placeholder="🔍 Exclude Model-Device-OS", label="Exclude", ) with gr.Row(): with gr.Column(): toks_plot = gr.Plot(container=True) demo.load( lambda x, y, z: plot_metric( x, "tokens_per_second", "Tok/s", "Tok/s Over Time for Model-Device-OS Combinations", y, z, ), [ gr.Dataframe(benchmark_df, visible=False), filter_toks, exclude_toks, ], toks_plot, ) filter_toks.change( lambda x, y, z: plot_metric( x, "tokens_per_second", "Tok/s", "Tok/s Over Time for Model-Device-OS Combinations", y, z, ), [ gr.Dataframe(benchmark_df, visible=False), filter_toks, exclude_toks, ], toks_plot, ) exclude_toks.change( lambda x, y, z: plot_metric( x, "tokens_per_second", "Tok/s", "Tok/s Over Time for Model-Device-OS Combinations", y, z, ), [ gr.Dataframe(benchmark_df, visible=False), filter_toks, exclude_toks, ], toks_plot, ) # Device Support Tab with gr.TabItem("Device Support", elem_id="device_support", id=6): # Load device support data from CSV support_data = pd.read_csv("dashboard_data/support_data.csv") support_data.set_index(support_data.columns[0], inplace=True) support_data["Model"] = support_data["Model"].apply( lambda x: x.replace("_", "/") ) support_data["Model"] = support_data["Model"].apply( lambda x: make_model_name_clickable_link(x) ) support_data = ( support_data.assign(model_len=support_data["Model"].str.len()) .sort_values( by=["model_len"], ascending=[True], ) .drop(columns=["model_len"]) ) with gr.Row(): with gr.Column(scale=1): with gr.Row(): with gr.Column(scale=6, elem_id="filter_models_column"): filter_support_models = gr.Textbox( placeholder="🔍 Filter Model (separate multiple queries with ';')", label="Filter Models", ) with gr.Column(scale=4, elem_classes="exclude_models_column"): exclude_support_models = gr.Textbox( placeholder="🔍 Exclude Model", label="Exclude Model", ) with gr.Row(): with gr.Accordion("See All Columns", open=False): with gr.Row(): with gr.Column(scale=9): support_shown_columns = gr.CheckboxGroup( choices=support_data.columns.tolist()[ 1: ], # Exclude 'Model' column value=support_data.columns.tolist()[1:], label="Toggle Columns", elem_id="support-column-select", interactive=True, ) with gr.Column(scale=1, min_width=200): with gr.Row(): select_all_support_button = gr.Button( "Select All", elem_id="select-all-support-button", interactive=True, ) deselect_all_support_button = gr.Button( "Deselect All", elem_id="deselect-all-support-button", interactive=True, ) with gr.Column(): gr.Markdown( """ ### Legend - ✅ Supported: The model is supported and tested on this device. - ⚠️ Failed: Some tests failed on this device. """ ) # Display device support data in a table device_support_table = gr.Dataframe( value=support_data, headers=support_data.columns.tolist(), datatype=["html" for _ in support_data.columns], elem_id="device-support-table", elem_classes="large-table", interactive=False, ) # Hidden dataframe to store the original data hidden_support_df = gr.Dataframe(value=support_data, visible=False) def filter_support_data(df, columns, model_query, exclude_models): filtered_df = df.copy() # Filter models based on query if model_query: filtered_df = filtered_df[ filtered_df["Model"].str.contains( "|".join(q.strip() for q in model_query.split(";")), case=False, regex=True, ) ] # Exclude specified models if exclude_models: exclude_list = [ re.escape(m.strip()) for m in exclude_models.split(";") ] filtered_df = filtered_df[ ~filtered_df["Model"].str.contains( "|".join(exclude_list), case=False, regex=True ) ] # Select columns selected_columns = ["Model"] + [ col for col in columns if col in df.columns ] filtered_df = filtered_df[selected_columns] return filtered_df def select_all_support_columns(): return support_data.columns.tolist()[1:] # Exclude 'Model' column def deselect_all_support_columns(): return [] # Connect the filter function to the input components filter_inputs = [ hidden_support_df, support_shown_columns, filter_support_models, exclude_support_models, ] filter_support_models.change( filter_support_data, filter_inputs, device_support_table ) exclude_support_models.change( filter_support_data, filter_inputs, device_support_table ) support_shown_columns.change( filter_support_data, filter_inputs, device_support_table ) # Connect select all and deselect all buttons select_all_support_button.click( select_all_support_columns, inputs=[], outputs=support_shown_columns, ) deselect_all_support_button.click( deselect_all_support_columns, inputs=[], outputs=support_shown_columns, ) # Methodology Tab with gr.TabItem("Methodology", elem_id="methodology", id=7): gr.Markdown(METHODOLOGY_TEXT, elem_id="methodology-text") # Citation section with gr.Accordion("📙 Citation", open=False): citation_button = gr.Textbox( value=CITATION_BUTTON_TEXT, label=CITATION_BUTTON_LABEL, lines=7, elem_id="citation-button", show_copy_button=True, ) # Launch the Gradio interface demo.launch(debug=True, share=True, ssr_mode=False)