|
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import streamlit as st

# NOTE(review): np, px and plt are not used by the visible code; they are kept
# in case other parts of the app rely on them.

# Streamlit documents set_page_config as needing to run before any other
# Streamlit command, so it stays directly after the (now de-duplicated) imports.
st.set_page_config(layout="wide")
|
|
|
|
|
|
|
# Page-wide style overrides injected once at start-up.
# `.centered-title`, `.fixed-col` and `.scroller` are used by HTML emitted
# later in this script; `.maint` is not referenced by the visible code.
_PAGE_CSS = """
<style>
*{
    padding:0;
    margin:0;
}
.fixed-col {
    position: fixed;
    top: 4rem;
    right: 0;
    width: 30%;
    padding-left: 0rem;
    background: white;
    z-index: 100;
}
body {
    margin: 0;
    padding: 0;
}
.maint {
    margin: auto;
    margin-bottom:1.5rem;
}
.centered-title {
    text-align: center;
}
.scroller {
    margin-top: 2rem; /* Adjust as necessary to avoid overlap */
}
</style>
"""
st.markdown(_PAGE_CSS, unsafe_allow_html=True)

# Tighter horizontal/top padding around the main content area.
# NOTE(review): `.main > div` targets Streamlit's generated DOM, whose class
# names may change between Streamlit releases — confirm it still matches.
margins_css = """
<style>
.main > div {
    padding-left: 3rem;
    padding-right:3rem;
    padding-top:0.4rem;
}
</style>
"""
st.markdown(margins_css, unsafe_allow_html=True)
|
|
|
# Choices for the "Select model" selectbox in the filter column.
models = [
    'SSD300',
    'SSD512',
    'DETR',
]

# Secondary filter offered for the SSD models.
pruning_methods = [
    'VIB Pruning',
    'Transfer Pruning',
]

# Secondary filter offered for DETR.
datasets = [
    'VOC',
    'SPARK',
]
|
|
|
# Placeholder shown in the hyperparameter table where a setting does not apply.
_NA = '-'

# Hyperparameter rows keyed by model, then by the secondary filter
# (pruning method for the SSD models, dataset for DETR).
#   SSD rows:  (model, kl factor, pruning epochs, finetuning epochs)
#   DETR rows: (model, kl backbone, kl transformer, pruning epochs, finetuning epochs)
hyperparameters = {
    'SSD300': {
        'Transfer Pruning': [
            ('SSD300-ResNet50', _NA, _NA, 120),
            ('SSD300-ITPCC-A', _NA, _NA, 120),
            ('SSD300-ITPCC-B', _NA, _NA, 120),
            ('SSD300-ITPCC-C', _NA, _NA, 120),
        ],
        'VIB Pruning': [
            ('SSD300-ResNet50', _NA, _NA, 120),
            ('SSD300-VIB-v1', '0.0001', 240, 100),
            ('SSD300-VIB-v2', '0.0002', 240, 100),
        ],
    },
    'SSD512': {
        'Transfer Pruning': [
            ('SSD512-ResNet50', _NA, _NA, 120),
            ('SSD512-ITPCC-A', _NA, _NA, 120),
            ('SSD512-ITPCC-B', _NA, _NA, 120),
            ('SSD512-ITPCC-C', _NA, _NA, 120),
        ],
        'VIB Pruning': [
            ('SSD512-ResNet50', _NA, _NA, 120),
            ('SSD512-VIB-v1', '0.0003', 200, 100),
        ],
    },
    'DETR': {
        'SPARK': [
            ('DETR-baseline', _NA, _NA, _NA, 20),
            ('DETR-SPARK-A', _NA, _NA, 30, 40),
            ('DETR-SPARK-B', _NA, _NA, 30, 40),
        ],
        'VOC': [
            ('DETR-baseline', _NA, _NA, _NA, 130),
            ('DETR-VOC-A', '0.0001', '0.00001', 80, 200),
            ('DETR-VOC-B', '0.00005', '0.0001', 80, 200),
        ],
    },
}
|
|
|
# Column names of each per-filter results table, in display order.
_RESULT_FIELDS = ('model', 'map', 'flops', 'flopsd', 'params', 'paramsd')


def _columns(*rows):
    """Transpose per-model row tuples into the column-oriented dict the UI reads.

    Each row is (model, mAP, FLOPs, FLOPs reduction, params, params reduction);
    the result maps every name in ``_RESULT_FIELDS`` to a list of that column.
    """
    return {field: list(column) for field, column in zip(_RESULT_FIELDS, zip(*rows))}


# Results keyed by model, then by the secondary filter (pruning method for SSD,
# dataset for DETR). Values are kept as display strings.
# NOTE(review): the DETR 'paramsd' entries do not match their 'params' values
# (e.g. 23.3 vs 41.2 is a ~43.4% reduction, not 47.3%; 21.71 vs 41.2 is ~47.3%,
# not 42.65%) — they look swapped between SPARK and VOC. SSD512-VIB-v1
# (27.2 vs 58.52 -> ~53.5%) and SSD512-ITPCC-C (28.7 vs 58.5 -> ~50.9%) are also
# slightly off. Confirm against the original experiment results.
results_data = {
    'SSD300': {
        'VIB Pruning': _columns(
            ('SSD300-ResNet50', "77.79", "11.1", '0.0%', "49.2", '0.0%'),
            ('SSD300-VIB-v1', "78.71", "5.04", '54.55%', "19.84", '59.68%'),
            ('SSD300-VIB-v2', "77.41", "3.49", '68.54%', "11.18", '77.28%'),
        ),
        'Transfer Pruning': _columns(
            ('SSD300-ResNet50', "77.79", "11.1", '0.0%', "49.2", '0.0%'),
            ('SSD300-ITPCC-A', "77.86", "6.85", '38.2%', "32.5", '33.94%'),
            ('SSD300-ITPCC-B', "77.06", "5.08", '54.2%', "25.7", '47.77%'),
            ('SSD300-ITPCC-C', "75.08", "3.38", '69.5%', "19.4", '60.5%'),
        ),
    },
    'SSD512': {
        'VIB Pruning': _columns(
            ('SSD512-ResNet50', "80.9", "46.24", '0.0%', "58.52", '0.0%'),
            ('SSD512-VIB-v1', "81.43", "9.73", '78.94%', "27.2", '53.42%'),
        ),
        'Transfer Pruning': _columns(
            ('SSD512-ResNet50', "80.9", "46.2", '0.0%', "58.5", '0.0%'),
            ('SSD512-ITPCC-A', "81.05", "31.42", '31.9%', "41.8", '28.5%'),
            ('SSD512-ITPCC-B', "80.45", "25.6", '44.6%', "35.0", '40.17%'),
            ('SSD512-ITPCC-C', "78.82", "20.1", '56.5%', "28.7", '50.1%'),
        ),
    },
    'DETR': {
        'SPARK': _columns(
            ('DETR-baseline', "96.77", "85", '0.0%', "41.2", '0.0%'),
            ('DETR-SPARK-A', "94.5", "56", '34.1%', "23.3", '47.3%'),
            ('DETR-SPARK-B', "95.18", "58", '31.7%', "26.6", '45.4%'),
        ),
        'VOC': _columns(
            ('DETR-baseline', "79.34", "85", '0.0%', "41.2", '0.0%'),
            ('DETR-VOC-A', "77.2", "55", '35.29%', "21.71", '42.65%'),
            ('DETR-VOC-B', "78.0", "60", '29.41%', "22.47", '35.5%'),
        ),
    },
}
|
|
|
|
|
|
|
# Page heading, centred by the `.centered-title` CSS rule injected above.
_title_html = (
    '<h1 class="centered-title">'
    'Variational Information bottleneck pruning for Object detection'
    '</h1>'
)
st.markdown(_title_html, unsafe_allow_html=True)

# Two-column layout: col1 (results) is slightly wider than col2 (filters).
col1, col2 = st.columns([5.2, 4.8])
|
|
|
with col2:
    # Right-hand column: model / filter selection plus the hyperparameter table.
    # NOTE(review): each st.markdown call is rendered as its own element, so this
    # opening tag (and the '.scroller' one below) most likely does not wrap the
    # widgets that follow, and the closing '</div>' calls at the end are separate
    # elements too. Verify the intended fixed-column layout in the browser DOM.
    st.markdown('<div class="fixed-col">', unsafe_allow_html=True)
    st.subheader('Filters')

    model = st.selectbox('Select model:', models)
    uses_pruning_filter = model in ['SSD300', 'SSD512']

    # SSD models are narrowed by pruning method, DETR by dataset; the table
    # layout differs accordingly (one KL factor vs. backbone/transformer KLs).
    if not uses_pruning_filter:
        dataset = st.selectbox('Select dataset:', datasets)
        filter_key = dataset
        hp_columns = ['Model', 'kl backbone', 'kl transformer', 'Pruning Epochs', 'Finetuning Epochs']
    else:
        pruning = st.selectbox('Select pruning method:', pruning_methods)
        filter_key = pruning
        hp_columns = ['Model', 'kl factor', 'Pruning Epochs', 'Finetuning Epochs']
    hyperparameter_data = hyperparameters[model][filter_key]

    st.markdown('<div class="scroller">', unsafe_allow_html=True)
    st.subheader('Hyperparameters')
    st.markdown('<br>', unsafe_allow_html=True)

    # Render the table as raw HTML without the DataFrame index column.
    df_hyperparams = pd.DataFrame(hyperparameter_data, columns=hp_columns)
    hyperparams_html = df_hyperparams.style.hide(axis="index").to_html()
    st.markdown(hyperparams_html, unsafe_allow_html=True)

    st.markdown('</div>', unsafe_allow_html=True)
    st.markdown('</div>', unsafe_allow_html=True)
|
|
|
|
|
with col1:
    # Left-hand column: results table and FLOPs/params chart for the model and
    # filter chosen in the right-hand column.
    st.subheader('Results')

    results = results_data[model]

    # SSD results are keyed by pruning method, DETR results by dataset. Both
    # sub-dicts share the same schema, so a single code path serves every model.
    if model in ['SSD300', 'SSD512']:
        ff = pruning
    else:
        ff = dataset
    selected = results[ff]

    # Cells keep the original strings so they display exactly as written.
    # The two reduction columns differ only by a trailing space so that both
    # can carry a visible "down by" header.
    # NOTE(review): assumes the DETR 'flops'/'params' values use the same units
    # (GFLOPs, millions of parameters) as the SSD ones — confirm.
    df_results = pd.DataFrame({
        'Model': selected['model'],
        'mAP (%)': selected['map'],
        'GFLOPs': selected['flops'],
        'down by': selected['flopsd'],
        'Parameters (M)': selected['params'],
        'down by ': selected['paramsd'],
    })
    st.markdown(df_results.style.hide(axis="index").to_html(), unsafe_allow_html=True)

    st.markdown("<div style='margin-top: 15px;'></div>", unsafe_allow_html=True)
    st.markdown(
        '<h2 id="evolution-graphs" style="margin-bottom: 0px; padding-bottom: 0px;">'
        'Evolution Graphs'
        '</h2>',
        unsafe_allow_html=True,
    )

    # NOTE(review): not used by the visible code; kept in case later code reads it.
    epochs = [1, 2, 3, 4]

    # results_data stores numbers as strings; convert them so the bars are
    # drawn on a numeric axis rather than relying on Plotly's type inference.
    flops_values = [float(v) for v in selected['flops']]
    params_values = [float(v) for v in selected['params']]

    fig = go.Figure()
    fig.add_trace(go.Bar(
        x=selected['model'],
        y=flops_values,
        name='GFLOPs',
        marker_color='orange',
    ))
    fig.add_trace(go.Bar(
        x=selected['model'],
        y=params_values,
        name='Params (M)',
        marker_color='green',
    ))

    # The two series have different units, so the y-axis names both instead
    # of presenting them as a generic count.
    fig.update_layout(
        barmode='group',
        title='FLOPs and Params per model',
        xaxis_title='Model',
        yaxis_title='GFLOPs / Params (M)',
        legend_title='Metric',
        height=280,
        width=500,
    )

    st.plotly_chart(fig, use_container_width=True)
|
|
|
|
|
|
|
|
|
|
|
|