# test-pfe / app.py
# Nidhal-ch's picture
# Update app.py
# c3200ea verified
import streamlit as st
import pandas as pd
import numpy as np
import plotly.express as px
# Configure the page with the "wide" layout before any other UI is emitted.
# NOTE(review): Streamlit's documentation asks for set_page_config to be the
# first Streamlit command in the script — keep it ahead of the st.markdown
# calls further down.
st.set_page_config(layout="wide")
import streamlit as st
import pandas as pd
import matplotlib.pyplot as plt
import streamlit as st
import pandas as pd
import matplotlib.pyplot as plt
import streamlit as st
import pandas as pd
import matplotlib.pyplot as plt
import streamlit as st
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
# Page-wide stylesheet: resets padding/margins on every element and declares
# the helper classes used by the raw HTML snippets emitted later in the script
# (.fixed-col, .centered-title, .scroller).
base_css = """
<style>
*{
padding:0;
margin:0;
}
.fixed-col {
position: fixed;
top: 4rem;
right: 0;
width: 30%;
padding-left: 0rem;
background: white;
z-index: 100;
}
body {
margin: 0;
padding: 0;
}
.maint {
margin: auto;
margin-bottom:1.5rem;
}
.centered-title {
text-align: center;
}
.scroller {
margin-top: 2rem; /* Adjust as necessary to avoid overlap */
}
</style>
"""

# Side and top padding for the `.main > div` element
# (presumably Streamlit's main content wrapper — confirm in the rendered DOM).
margins_css = """
<style>
.main > div {
padding-left: 3rem;
padding-right:3rem;
padding-top:0.4rem;
}
</style>
"""

# Inject both stylesheets, in the same order as before, as raw HTML.
for stylesheet in (base_css, margins_css):
    st.markdown(stylesheet, unsafe_allow_html=True)
# Sample data for demonstration purposes

# Options offered by the filter select boxes.
models = ["SSD300", "SSD512", "DETR"]
pruning_methods = ["VIB Pruning", "Transfer Pruning"]
datasets = ["VOC", "SPARK"]

# Rows for the hyperparameter table, looked up as hyperparameters[model][key].
#   * SSD300 / SSD512: key is a pruning method, rows are
#     (model, kl factor, pruning epochs, finetuning epochs).
#   * DETR: key is a dataset, rows are
#     (model, kl backbone, kl transformer, pruning epochs, finetuning epochs).
# "-" is presumably a placeholder for "not applicable" — confirm with the author.
hyperparameters = {
    "SSD300": {
        "Transfer Pruning": [
            ("SSD300-ResNet50", "-", "-", 120),
            ("SSD300-ITPCC-A", "-", "-", 120),
            ("SSD300-ITPCC-B", "-", "-", 120),
            ("SSD300-ITPCC-C", "-", "-", 120),
        ],
        "VIB Pruning": [
            ("SSD300-ResNet50", "-", "-", 120),
            ("SSD300-VIB-v1", "0.0001", 240, 100),
            ("SSD300-VIB-v2", "0.0002", 240, 100),
        ],
    },
    "SSD512": {
        "Transfer Pruning": [
            ("SSD512-ResNet50", "-", "-", 120),
            ("SSD512-ITPCC-A", "-", "-", 120),
            ("SSD512-ITPCC-B", "-", "-", 120),
            ("SSD512-ITPCC-C", "-", "-", 120),
        ],
        "VIB Pruning": [
            ("SSD512-ResNet50", "-", "-", 120),
            ("SSD512-VIB-v1", "0.0003", 200, 100),
        ],
    },
    "DETR": {
        "SPARK": [
            ("DETR-baseline", "-", "-", "-", 20),
            ("DETR-SPARK-A", "-", "-", 30, 40),
            ("DETR-SPARK-B", "-", "-", 30, 40),
        ],
        "VOC": [
            ("DETR-baseline", "-", "-", "-", 130),
            ("DETR-VOC-A", "0.0001", "0.00001", 80, 200),
            ("DETR-VOC-B", "0.00005", "0.0001", 80, 200),
        ],
    },
}
# Benchmark figures behind the results table and the bar chart, looked up as
# results_data[model][key] where key is a pruning method (SSD300 / SSD512) or
# a dataset (DETR). Each entry holds parallel lists, one item per model variant:
#   model   - variant name
#   map     - value for the mAP column
#   flops   - value for the FLOPs column, flopsd its "down by" percentage
#   params  - value for the parameters column, paramsd its "down by" percentage
# Every value is stored as a string.
results_data = {
    "SSD300": {
        "VIB Pruning": {
            "model": ["SSD300-ResNet50", "SSD300-VIB-v1", "SSD300-VIB-v2"],
            "map": ["77.79", "78.71", "77.41"],
            "flops": ["11.1", "5.04", "3.49"],
            "flopsd": ["0.0%", "54.55%", "68.54%"],
            "params": ["49.2", "19.84", "11.18"],
            "paramsd": ["0.0%", "59.68%", "77.28%"],
        },
        "Transfer Pruning": {
            "model": ["SSD300-ResNet50", "SSD300-ITPCC-A", "SSD300-ITPCC-B", "SSD300-ITPCC-C"],
            "map": ["77.79", "77.86", "77.06", "75.08"],
            "flops": ["11.1", "6.85", "5.08", "3.38"],
            "flopsd": ["0.0%", "38.2%", "54.2%", "69.5%"],
            "params": ["49.2", "32.5", "25.7", "19.4"],
            "paramsd": ["0.0%", "33.94%", "47.77%", "60.5%"],
        },
    },
    "SSD512": {
        "VIB Pruning": {
            "model": ["SSD512-ResNet50", "SSD512-VIB-v1"],
            "map": ["80.9", "81.43"],
            "flops": ["46.24", "9.73"],
            "flopsd": ["0.0%", "78.94%"],
            "params": ["58.52", "27.2"],
            "paramsd": ["0.0%", "53.42%"],
        },
        "Transfer Pruning": {
            "model": ["SSD512-ResNet50", "SSD512-ITPCC-A", "SSD512-ITPCC-B", "SSD512-ITPCC-C"],
            "map": ["80.9", "81.05", "80.45", "78.82"],
            "flops": ["46.2", "31.42", "25.6", "20.1"],
            "flopsd": ["0.0%", "31.9%", "44.6%", "56.5%"],
            "params": ["58.5", "41.8", "35.0", "28.7"],
            "paramsd": ["0.0%", "28.5%", "40.17%", "50.1%"],
        },
    },
    "DETR": {
        "SPARK": {
            "model": ["DETR-baseline", "DETR-SPARK-A", "DETR-SPARK-B"],
            "map": ["96.77", "94.5", "95.18"],
            "flops": ["85", "56", "58"],
            "flopsd": ["0.0%", "34.1%", "31.7%"],
            "params": ["41.2", "23.3", "26.6"],
            "paramsd": ["0.0%", "47.3%", "45.4%"],
        },
        "VOC": {
            "model": ["DETR-baseline", "DETR-VOC-A", "DETR-VOC-B"],
            "map": ["79.34", "77.2", "78.0"],
            "flops": ["85", "55", "60"],
            "flopsd": ["0.0%", "35.29%", "29.41%"],
            "params": ["41.2", "21.71", "22.47"],
            "paramsd": ["0.0%", "42.65%", "35.5%"],
        },
    },
}
# Title of the research, centred through the .centered-title CSS class.
title_html = '<h1 class="centered-title">Variational Information bottleneck pruning for Object detection</h1>'
st.markdown(title_html, unsafe_allow_html=True)

# Create two columns with specified widths: col1 (left) holds the results,
# col2 (right) holds the filters and hyperparameters.
column_widths = [5.2, 4.8]
col1, col2 = st.columns(column_widths)
# Right section: Filters and Hyperparameters
with col2:
    # NOTE(review): each st.markdown call is presumably rendered as its own
    # element, so these opening/closing <div> tags may not actually wrap the
    # widgets emitted between them — confirm in the rendered page.
    st.markdown('<div class="fixed-col">', unsafe_allow_html=True)
    st.subheader('Filters')
    model = st.selectbox('Select model:', models)

    # SSD variants are filtered by pruning method, every other model by
    # dataset; the table headers differ accordingly, so pick them here
    # together with the rows instead of branching a second time below.
    if model in ['SSD300', 'SSD512']:
        pruning = st.selectbox('Select pruning method:', pruning_methods)
        hyperparameter_data = hyperparameters[model][pruning]
        hyperparameter_columns = ['Model', 'kl factor', 'Pruning Epochs', 'Finetuning Epochs']
    else:
        dataset = st.selectbox('Select dataset:', datasets)
        hyperparameter_data = hyperparameters[model][dataset]
        hyperparameter_columns = ['Model', 'kl backbone', 'kl transformer', 'Pruning Epochs', 'Finetuning Epochs']

    st.markdown('<div class="scroller">', unsafe_allow_html=True)
    st.subheader('Hyperparameters')
    # Add space between filter and hyperparameters
    st.markdown('<br>', unsafe_allow_html=True)

    # Render the table as static HTML with the DataFrame index hidden.
    df_hyperparams = pd.DataFrame(hyperparameter_data, columns=hyperparameter_columns)
    hyperparams_html = df_hyperparams.style.hide(axis="index").to_html()
    st.markdown(hyperparams_html, unsafe_allow_html=True)

    # Closing tags for the "scroller" and "fixed-col" <div>s opened above.
    st.markdown('</div>', unsafe_allow_html=True)
    st.markdown('</div>', unsafe_allow_html=True)
# Left section: Results and Evolution Graphs
with col1:
    st.subheader('Results')

    # SSD results are keyed by the chosen pruning method, DETR results by the
    # chosen dataset. Resolve that once and reuse it for both table and chart.
    # Only the variable for the active branch exists, which is why a
    # conditional expression (lazy on the unused side) is used here.
    results = results_data[model]
    is_ssd = model in ['SSD300', 'SSD512']
    selected_key = pruning if is_ssd else dataset
    selected = results[selected_key]

    # Display results table (the column headers differ between SSD and DETR).
    if is_ssd:
        map_label, flops_label, params_label = 'mAP (%)', 'GFLOPs', 'Parameters (M)'
    else:
        map_label, flops_label, params_label = 'mAP', 'FLOPs', 'Params'
    df_results = pd.DataFrame({
        'Model': selected['model'],
        map_label: selected['map'],
        flops_label: selected['flops'],
        'down by': selected['flopsd'],
        params_label: selected['params'],
        # The trailing space keeps this key distinct from the FLOPs
        # 'down by' column above; identical keys would collapse into one.
        'down by ': selected['paramsd'],
    })
    st.markdown(df_results.style.hide(axis="index").to_html(), unsafe_allow_html=True)

    # Display evolution graphs
    st.markdown("<div style='margin-top: 15px;'></div>", unsafe_allow_html=True)
    st.markdown(
        """
<h2 id="evolution-graphs" style="margin-bottom: 0px; padding-bottom: 0px;">
Evolution Graphs
</h2>
""",
        unsafe_allow_html=True
    )

    # The figures are stored as strings; convert them to floats so the bar
    # heights are numeric without depending on Plotly's implicit
    # string-to-number coercion.
    model_names = selected['model']
    flops_values = [float(value) for value in selected['flops']]
    params_values = [float(value) for value in selected['params']]

    fig = go.Figure()
    # Add FLOPs bar trace
    fig.add_trace(go.Bar(
        x=model_names,
        y=flops_values,
        name='FLOPs',
        marker_color='orange'
    ))
    # Add Params bar trace
    fig.add_trace(go.Bar(
        x=model_names,
        y=params_values,
        name='Params',
        marker_color='green'
    ))
    # Update the layout: one group of two bars per model variant.
    fig.update_layout(
        barmode='group',
        title='FLOPs and Params per model',
        xaxis_title='Model',
        yaxis_title='Count',
        legend_title='Metric',
        height=280,
        width=500,
    )
    # Show the plot using Streamlit
    st.plotly_chart(fig, use_container_width=True)