Spaces:
Sleeping
Sleeping
import streamlit as st | |
import pandas as pd | |
import matplotlib.pyplot as plt | |
import os | |
from huggingface_hub import HfFileSystem | |
# Hugging Face Space repository that hosts this app.
REPO_ID = "nsourlos/draco_streamlit"
# Access token taken from the environment (None when the variable is unset).
HF_TOKEN = os.environ.get("HF_TOKEN")
# Local CSV file the app loads, edits and saves.
data_path = 'data.csv'
def load_data(file):
    """Read a CSV file into a DataFrame indexed by its ``id`` column.

    Args:
        file: Path or file-like object accepted by ``pandas.read_csv``.

    Returns:
        The parsed DataFrame, with the ``id`` column used as the index.
    """
    return pd.read_csv(file, index_col='id')
def save_data(df, filename):
    """Write *df* to *filename* as CSV, index included.

    Args:
        df: DataFrame to persist.
        filename: Destination path or writable buffer for ``DataFrame.to_csv``.
    """
    df.to_csv(path_or_buf=filename)
def calculate_accuracy(df):
    """Return the mean ``label`` value for every distinct ``text`` value.

    In this app labels are entered as 0 or 1, so the mean is the fraction of
    rows for that text that are labelled 1.

    Args:
        df: DataFrame with ``text`` and ``label`` columns.

    Returns:
        dict mapping each text value to the mean of its labels, ordered by
        sorted text value; an empty dict when *df* has no rows.
    """
    # One vectorised groupby-mean replaces the manual per-group loop.
    return df.groupby('text')['label'].mean().to_dict()
# Initialize session state variables. Each key receives its default only when
# it is missing, so values survive the script reruns Streamlit performs.
for _key, _default in (
    ('data', None),              # DataFrame currently shown in the editor
    ('new_rows', []),            # rows appended by add_row
    ('file_path', None),         # CSV path the app reads and writes
    ('add_row_clicked', False),  # whether the add-row inputs are shown
    ('rerun_count', 0),          # incremented by add_row before it reruns
    ('finished', False),         # set when the user clicks "Finish"
):
    if _key not in st.session_state:
        st.session_state[_key] = _default
del _key, _default
def add_row(new_text, new_label):
    """Append one row to the dataset, persist it, and rerun the script.

    The new id is one past the current maximum index (0 for an empty frame).
    The combined frame is written to ``st.session_state.file_path`` and then
    read back, so ``st.session_state.data`` mirrors the file on disk.

    Args:
        new_text: Value for the ``text`` column (the model name typed in the UI).
        new_label: Value for the ``label`` column.
    """
    state = st.session_state
    current = state['data']
    next_id = 0 if current.empty else current.index.max() + 1
    row = {'id': next_id, 'text': new_text, 'label': new_label, 'checked': False}
    state.new_rows.append(row)

    row_frame = pd.DataFrame([row]).set_index('id')
    target = state.file_path
    save_data(pd.concat([state.data, row_frame]), target)
    state.data = load_data(target)

    state.add_row_clicked = False  # hide the add-row inputs on the next run
    # A non-zero count stops the main script from re-reading the CSV on rerun.
    state.rerun_count += 1
    st.rerun()
# Streamlit app: load the CSV, let the user edit or extend it, plot a
# per-text leaderboard and push the data back to the Space on "Finish".
st.title("Interactive DataFrame Editor")
# uploaded_file = st.file_uploader("Upload your CSV file", type="csv")
# The uploader above is disabled; the app always works on the local data file.
uploaded_file = data_path
if uploaded_file is not None:
    st.session_state.file_path = uploaded_file
    # Read from disk until add_row bumps rerun_count; add_row reloads the file
    # itself after saving, so later reruns keep the in-memory copy.
    if st.session_state.rerun_count == 0:
        st.session_state.data = load_data(uploaded_file)
    st.subheader("DataFrame")
    if st.session_state.data is not None:
        # Editable view of the data.
        edited_data = st.data_editor(st.session_state.data)
        if edited_data is not None:
            # data_editor hands back a frame on every rerun; only rewrite the
            # CSV when its contents actually differ from what we already have.
            changed = not edited_data.equals(st.session_state.data)
            st.session_state.data = edited_data
            if changed:
                save_data(st.session_state.data, st.session_state.file_path)
        if st.button("Add Row"):
            st.session_state.add_row_clicked = True
        if st.session_state.add_row_clicked:
            # Inputs for adding a new row
            new_text = st.text_input("Enter model name for new row:")
            new_label = st.selectbox("Select label for new row:", options=[0, 1])
            if st.button("Confirm Add Row"):
                # Reject blank names so they do not become an empty leaderboard entry.
                if new_text.strip():
                    add_row(new_text.strip(), new_label)
                else:
                    st.warning("Please enter a model name before adding a row.")
        # Leaderboard: mean label per text value.
        accuracy_dict = calculate_accuracy(st.session_state.data)
        texts = list(accuracy_dict.keys())
        accuracies = list(accuracy_dict.values())
        fig, ax = plt.subplots(figsize=(10, 4))
        ax.scatter(texts, accuracies)
        ax.set_xlabel('Text')
        ax.set_ylabel('Accuracy')
        ax.set_title('Accuracy of Labels for Each Text Attribute')
        ax.tick_params(axis='x', labelrotation=90)  # Rotate x-axis labels for readability
        st.subheader("Leaderboard")
        st.pyplot(fig)
        # The script reruns on every interaction; close the figure so matplotlib
        # does not keep one open figure per rerun.
        plt.close(fig)
        # Button to write the data back to the Space repository
        if st.button('Finish'):
            if not HF_TOKEN:
                st.error("HF_TOKEN is not set; cannot save the data to the Space repository.")
            else:
                st.success('Saving.... Space will restart soon....')
                st.session_state.finished = True
                # Drop double quotes from the token (presumably in case the
                # secret was stored quoted) before authenticating.
                fs = HfFileSystem(token=HF_TOKEN.replace("\"", ""))
                with fs.open(f"spaces/{REPO_ID}/data.csv", 'w') as f:
                    f.write(st.session_state.data.to_csv())
else:
    st.write("Please upload a CSV file to get started.")