# draco_streamlit / app.py
# Commit 88cec06 ("final_submit_message_correction") by nsourlos
import streamlit as st
import pandas as pd
import matplotlib.pyplot as plt
import os
from huggingface_hub import HfFileSystem
# Identifier of the Hugging Face Space this app belongs to.
REPO_ID = "nsourlos/draco_streamlit"
# Access token used when writing the data back to the Space; None if unset.
HF_TOKEN = os.environ.get("HF_TOKEN")
# Local CSV file that holds the editable table.
data_path = 'data.csv'
# Load the CSV file
def load_data(file):
    """Read a CSV file and return it as a DataFrame.

    Args:
        file: Path or file-like object accepted by ``pandas.read_csv``.

    Returns:
        DataFrame whose index is taken from the ``id`` column.
    """
    return pd.read_csv(file, index_col="id")
# Save the CSV file
def save_data(df, filename):
    """Write *df* to *filename* as CSV.

    The index is written as the leading column, so a frame indexed by
    ``id`` can be read back with ``load_data``.
    """
    df.to_csv(path_or_buf=filename)
# Function to calculate accuracy for each unique text attribute
def calculate_accuracy(df):
    """Return the mean ``label`` value for each distinct ``text`` value.

    With 0/1 labels (the options offered when adding a row) the mean is the
    fraction of rows labelled 1, which the leaderboard plots as "accuracy".

    Args:
        df: DataFrame with ``text`` and ``label`` columns.

    Returns:
        dict mapping each distinct ``text`` to the mean of its ``label``
        values, in groupby's default (sorted) key order. Rows whose ``text``
        is missing are excluded, as groupby drops NaN keys by default.
    """
    # A single vectorised groupby-mean replaces the per-group Python loop.
    return df.groupby('text')['label'].mean().to_dict()
# Initialize session state variables.
# Each key is seeded only when it is absent, so a value already stored
# (e.g. by an earlier rerun) is left untouched.
for _state_key, _initial in (
    ('data', None),
    ('new_rows', []),
    ('file_path', None),
    ('add_row_clicked', False),
    ('rerun_count', 0),
    ('finished', False),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _initial
# Drop the loop variables so they do not linger as module globals.
del _state_key, _initial
# Function to add new row
def add_row(new_text, new_label):
    """Append a row to the table, persist it to disk and rerun the app.

    Args:
        new_text: Model name entered by the user; stored in the ``text`` column.
        new_label: Label chosen by the user; stored in the ``label`` column.

    A blank or whitespace-only name is rejected with a warning instead of
    being saved. An empty CSV field is read back as NaN, and such a row would
    then be silently dropped from the per-text accuracy computation.
    """
    if not (new_text and new_text.strip()):
        st.warning("Please enter a model name before adding a row.")
        return

    data = st.session_state.data
    # Ids continue after the current maximum; an empty table starts at 0.
    new_id = data.index.max() + 1 if not data.empty else 0
    new_row = {'id': new_id, 'text': new_text, 'label': new_label, 'checked': False}
    st.session_state.new_rows.append(new_row)

    updated_data = pd.concat([data, pd.DataFrame([new_row]).set_index('id')])
    file_path = st.session_state.file_path
    save_data(updated_data, file_path)
    # Re-read from disk so the session copy matches what was just written.
    st.session_state.data = load_data(file_path)

    st.session_state.add_row_clicked = False  # hide the add-row inputs again
    # A nonzero count stops the main script from reloading the CSV on rerun.
    st.session_state.rerun_count += 1
    st.rerun()
# Streamlit app
st.title("Interactive DataFrame Editor")

# Data comes from the bundled CSV at data_path instead of an upload widget.
# Previously: st.file_uploader("Upload your CSV file", type="csv")
uploaded_file = data_path
def _plot_leaderboard(accuracy_dict):
    """Build a scatter plot of accuracy per text value and return the figure.

    Args:
        accuracy_dict: mapping of ``text`` value to its mean label value.
    """
    fig, ax = plt.subplots(figsize=(10, 4))
    ax.scatter(list(accuracy_dict.keys()), list(accuracy_dict.values()))
    ax.set_xlabel('Text')
    ax.set_ylabel('Accuracy')
    ax.set_title('Accuracy of Labels for Each Text Attribute')
    # Rotate x-axis labels for readability; set on this Axes explicitly
    # rather than relying on pyplot's implicit "current figure".
    ax.tick_params(axis='x', labelrotation=90)
    return fig


def _save_to_space(df):
    """Write *df* as CSV to the Space repository via HfFileSystem.

    Shows an error instead of crashing when HF_TOKEN is not set.
    """
    if not HF_TOKEN:
        st.error("HF_TOKEN is not set, so the data cannot be saved to the Space.")
        return
    st.success('Saving.... Space will restart soon....')
    # Remove any double-quote characters from the token (presumably the
    # secret may have been stored with quotes around it).
    fs = HfFileSystem(token=HF_TOKEN.replace("\"", ""))
    with fs.open(f"spaces/{REPO_ID}/{data_path}", 'w') as f:
        f.write(df.to_csv())


if uploaded_file is not None:
    st.session_state.file_path = uploaded_file
    # The CSV is re-read on every run until add_row has bumped rerun_count;
    # after that, add_row keeps st.session_state.data up to date itself.
    if st.session_state.rerun_count == 0:
        st.session_state.data = load_data(uploaded_file)

    st.subheader("DataFrame")

    if st.session_state.data is not None:
        edited_data = st.data_editor(st.session_state.data)
        if edited_data is not None:
            changed = not edited_data.equals(st.session_state.data)
            st.session_state.data = edited_data
            # Only touch the file when the user actually edited something.
            if changed:
                save_data(edited_data, st.session_state.file_path)

        if st.button("Add Row"):
            st.session_state.add_row_clicked = True

        if st.session_state.add_row_clicked:
            # Inputs for adding a new row
            new_text = st.text_input("Enter model name for new row:")
            new_label = st.selectbox("Select label for new row:", options=[0, 1])
            if st.button("Confirm Add Row"):
                add_row(new_text, new_label)

        st.subheader("Leaderboard")
        fig = _plot_leaderboard(calculate_accuracy(st.session_state.data))
        st.pyplot(fig)
        # pyplot keeps every figure alive until closed; without this each
        # rerun would leak one figure.
        plt.close(fig)

        # Finish: push the current table back to the Space repository.
        if st.button('Finish'):
            st.session_state.finished = True
            _save_to_space(st.session_state.data)
else:
    st.write("Please upload a CSV file to get started.")