"""Streamlit dashboard comparing pruned object-detection models.

The page is split into two columns: the right one holds the model and
pruning-method/dataset selectors plus a hyperparameter table, the left one
shows the results table and a grouped bar chart of FLOPs and Params for
the current selection. All numbers come from the hard-coded sample data
below.
"""

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import streamlit as st

# Must stay the first Streamlit command executed by the script.
st.set_page_config(layout="wide")

# NOTE(review): every string passed with ``unsafe_allow_html=True`` below
# carries no HTML/CSS in the reviewed copy of this file (only whitespace or
# plain text). They are reproduced as seen; if markup is meant to live in
# them (style rules, heading tags, spacer elements), restore it here.
st.markdown(
    """ """,
    unsafe_allow_html=True
)

margins_css = """ """
st.markdown(margins_css, unsafe_allow_html=True)

# Sample data for demonstration purposes
models = ['SSD300', 'SSD512', 'DETR']
pruning_methods = ['VIB Pruning', 'Transfer Pruning']
datasets = ['VOC', 'SPARK']

# SSD models are filtered by pruning method, every other model by dataset.
SSD_MODELS = ('SSD300', 'SSD512')

# Column headers of the hyperparameter table; the SSD rows are 4-tuples and
# the DETR rows 5-tuples, so each family needs its own header list.
SSD_HYPERPARAMETER_COLUMNS = [
    'Model', 'kl factor', 'Pruning Epochs', 'Finetuning Epochs',
]
DETR_HYPERPARAMETER_COLUMNS = [
    'Model', 'kl backbone', 'kl transformer', 'Pruning Epochs',
    'Finetuning Epochs',
]

hyperparameters = {
    'SSD300': {
        'Transfer Pruning': [
            ('SSD300-ResNet50', '-', '-', 120),
            ('SSD300-ITPCC-A', '-', '-', 120),
            ('SSD300-ITPCC-B', '-', '-', 120),
            ('SSD300-ITPCC-C', '-', '-', 120),
        ],
        'VIB Pruning': [
            ('SSD300-ResNet50', '-', '-', 120),
            ('SSD300-VIB-v1', "0.0001", 240, 100),
            ('SSD300-VIB-v2', "0.0002", 240, 100),
        ],
    },
    'SSD512': {
        'Transfer Pruning': [
            ('SSD512-ResNet50', '-', '-', 120),
            ('SSD512-ITPCC-A', '-', '-', 120),
            ('SSD512-ITPCC-B', '-', '-', 120),
            ('SSD512-ITPCC-C', '-', '-', 120),
        ],
        'VIB Pruning': [
            ('SSD512-ResNet50', '-', '-', 120),
            ('SSD512-VIB-v1', "0.0003", 200, 100),
        ],
    },
    'DETR': {
        'SPARK': [
            ("DETR-baseline", "-", "-", "-", 20),
            ("DETR-SPARK-A", "-", "-", 30, 40),
            ("DETR-SPARK-B", "-", "-", 30, 40),
        ],
        'VOC': [
            ("DETR-baseline", "-", "-", "-", 130),
            ("DETR-VOC-A", "0.0001", "0.00001", 80, 200),
            ("DETR-VOC-B", "0.00005", "0.0001", 80, 200),
        ],
    },
}

# Per model and per pruning method/dataset: parallel lists, one entry per
# model variant. 'flopsd' / 'paramsd' are the strings shown in the
# "down by" columns next to FLOPs and Params.
results_data = {
    'SSD300': {
        'VIB Pruning': {
            'model': ['SSD300-ResNet50', 'SSD300-VIB-v1', 'SSD300-VIB-v2'],
            'map': ["77.79", "78.71", "77.41"],
            'flops': ["11.1", "5.04", "3.49"],
            'flopsd': ['0.0%', '54.55%', '68.54%'],
            'params': ["49.2", "19.84", "11.18"],
            'paramsd': ['0.0%', '59.68%', '77.28%'],
        },
        'Transfer Pruning': {
            'model': ["SSD300-ResNet50", 'SSD300-ITPCC-A', 'SSD300-ITPCC-B',
                      'SSD300-ITPCC-C'],
            'map': ["77.79", "77.86", "77.06", "75.08"],
            'flops': ["11.1", "6.85", "5.08", "3.38"],
            'flopsd': ['0.0%', '38.2%', '54.2%', "69.5%"],
            'params': ["49.2", "32.5", "25.7", "19.4"],
            'paramsd': ['0.0%', '33.94%', '47.77%', "60.5%"],
        },
    },
    'SSD512': {
        'VIB Pruning': {
            'model': ["SSD512-ResNet50", 'SSD512-VIB-v1'],
            'map': ["80.9", "81.43"],
            'flops': ["46.24", "9.73"],
            'flopsd': ['0.0%', '78.94%'],
            'params': ["58.52", "27.2"],
            'paramsd': ['0.0%', '53.42%'],
        },
        'Transfer Pruning': {
            'model': ["SSD512-ResNet50", 'SSD512-ITPCC-A', 'SSD512-ITPCC-B',
                      'SSD512-ITPCC-C'],
            'map': ["80.9", "81.05", "80.45", "78.82"],
            'flops': ["46.2", "31.42", "25.6", "20.1"],
            'flopsd': ['0.0%', '31.9%', '44.6%', "56.5%"],
            'params': ["58.5", "41.8", "35.0", "28.7"],
            'paramsd': ['0.0%', '28.5%', '40.17%', "50.1%"],
        },
    },
    'DETR': {
        'SPARK': {
            'model': ["DETR-baseline", 'DETR-SPARK-A', 'DETR-SPARK-B'],
            'map': ["96.77", "94.5", "95.18"],
            'flops': ["85", "56", "58"],
            'flopsd': ['0.0%', '34.1%', '31.7%'],
            'params': ["41.2", "23.3", "26.6"],
            'paramsd': ['0.0%', '47.3%', '45.4%'],
        },
        'VOC': {
            'model': ["DETR-baseline", 'DETR-VOC-A', 'DETR-VOC-B'],
            'map': ["79.34", "77.2", "78.0"],
            'flops': ["85", "55", "60"],
            'flopsd': ['0.0%', '35.29%', '29.41%'],
            'params': ["41.2", "21.71", "22.47"],
            'paramsd': ['0.0%', '42.65%', '35.5%'],
        },
    },
}

# Text blocks rendered through st.markdown (see the review note above).
SPACER_MARKDOWN = '\n'
TITLE_MARKDOWN = (
    '\n\nVariational Information bottleneck pruning for Object detection\n\n'
)
EVOLUTION_HEADING_MARKDOWN = "\n\nEvolution Graphs\n\n"


def _spacer() -> None:
    """Emit the spacer block used between page sections."""
    st.markdown(SPACER_MARKDOWN, unsafe_allow_html=True)


def _render_table(frame: pd.DataFrame) -> None:
    """Render *frame* as an HTML table with the index column hidden."""
    st.markdown(
        frame.style.hide(axis="index").to_html(),
        unsafe_allow_html=True,
    )


def _as_numbers(values) -> list:
    """Convert the numeric strings of a results column to floats.

    The results are stored as strings for display; the chart gets real
    numbers so its value axis does not depend on Plotly's automatic
    string-to-number conversion.
    """
    return [float(value) for value in values]


# Title of the research
st.markdown(TITLE_MARKDOWN, unsafe_allow_html=True)

# Create two columns with specified widths
col1, col2 = st.columns([5.2, 4.8])

# Right section: Filters and Hyperparameters
with col2:
    _spacer()
    st.subheader('Filters')
    model = st.selectbox('Select model:', models)

    # One selector is shown per model family; its value is the key used for
    # both the hyperparameter and the results lookup further down.
    if model in SSD_MODELS:
        selection = st.selectbox('Select pruning method:', pruning_methods)
        hyperparameter_columns = SSD_HYPERPARAMETER_COLUMNS
    else:
        selection = st.selectbox('Select dataset:', datasets)
        hyperparameter_columns = DETR_HYPERPARAMETER_COLUMNS
    hyperparameter_data = hyperparameters[model][selection]

    _spacer()
    st.subheader('Hyperparameters')
    _spacer()  # Add space between filter and hyperparameters

    df_hyperparams = pd.DataFrame(
        hyperparameter_data,
        columns=hyperparameter_columns,
    )
    _render_table(df_hyperparams)
    _spacer()
    _spacer()

# Left section: Results and Evolution Graphs
with col1:
    st.subheader('Results')

    # Display results table
    results = results_data[model]
    selected_results = results[selection]
    df_results = pd.DataFrame({
        'Model': selected_results['model'],
        'mAP': selected_results['map'],
        'FLOPs': selected_results['flops'],
        'down by': selected_results['flopsd'],
        'Params': selected_results['params'],
        # The trailing space keeps this key distinct from the FLOPs
        # 'down by' column while both headers read the same on screen.
        'down by ': selected_results['paramsd'],
    })
    _render_table(df_results)

    # Display evolution graphs
    st.markdown("\n", unsafe_allow_html=True)
    st.markdown(EVOLUTION_HEADING_MARKDOWN, unsafe_allow_html=True)

    fig = go.Figure()

    # Add FLOPs bar trace
    fig.add_trace(go.Bar(
        x=selected_results['model'],
        y=_as_numbers(selected_results['flops']),
        name='FLOPs',
        marker_color='orange',
    ))

    # Add Params bar trace
    fig.add_trace(go.Bar(
        x=selected_results['model'],
        y=_as_numbers(selected_results['params']),
        name='Params',
        marker_color='green',
    ))

    # Update the layout
    fig.update_layout(
        barmode='group',
        title='FLOPs and Params per model',
        xaxis_title='Model',
        yaxis_title='Count',
        legend_title='Metric',
        height=280,
    )

    # Show the plot using Streamlit
    st.plotly_chart(fig, use_container_width=True)