"""Streamlit app for multi-label toxicity classification of text.

The user picks one of three transformer checkpoints stored in local
directories, enters text, and sees a probability for each toxicity label.
On every run the app also scores a few random comments taken from the
Jigsaw toxic-comment-classification test CSV.
"""
import functools
import random

import pandas as pd
import streamlit as st
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Label names, one per model output column (assumed to match the label order
# the checkpoints were trained with -- confirm against the training code).
classifiers = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']

# Local checkpoint directory for each selectable base model; any other name
# (including "roberta-base") uses the RoBERTa checkpoint.
_MODEL_DIRS = {
    "bert-base-cased": "./bert/_bert_model",
    "distilbert-base-cased": "./distilbert/_distilbert_model",
}
_DEFAULT_MODEL_DIR = "./roberta/_roberta_model"

TEST_CSV_PATH = "./jigsaw-toxic-comment-classification-challenge/test.csv"
# How many random test-set comments to score on each run.
NUM_RANDOM_COMMENTS = 3

# Streamlit re-executes this whole script on every widget interaction, so the
# expensive loaders below are cached: with st.cache_resource (shared across
# reruns) when this Streamlit version provides it, otherwise with an
# in-process functools.lru_cache.
_cache_resource = getattr(st, "cache_resource", None) or functools.lru_cache(maxsize=None)


def reset_scores():
    """Reset the module-level ``scores_df`` to an empty comment/score table."""
    global scores_df
    scores_df = pd.DataFrame(columns=['Comment'] + classifiers)


@_cache_resource
def _load_model(model_dir):
    """Load and return the sequence-classification model stored at ``model_dir``."""
    return AutoModelForSequenceClassification.from_pretrained(model_dir)


@_cache_resource
def _load_tokenizer(model_base):
    """Load and return the tokenizer for the named base model."""
    return AutoTokenizer.from_pretrained(model_base)


@_cache_resource
def _load_test_comments(path):
    """Read and return the test-set CSV at ``path`` as a DataFrame."""
    return pd.read_csv(path)


def get_score(model_base, text):
    """Return per-label toxicity probabilities for ``text``.

    Args:
        model_base: Base model name; selects the tokenizer and the local
            checkpoint directory (unknown names use the RoBERTa checkpoint).
        text: The string to classify; tokenized input is truncated to 512
            tokens.

    Returns:
        A tensor of sigmoid probabilities with one row for the single input;
        the column count comes from the checkpoint's config (presumably one
        per entry of ``classifiers`` -- confirm).
    """
    model_dir = _MODEL_DIRS.get(model_base, _DEFAULT_MODEL_DIR)
    model = _load_model(model_dir)
    tokenizer = _load_tokenizer(model_base)
    inputs = tokenizer(
        text, max_length=512, truncation=True, padding=True, return_tensors='pt')
    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        outputs = model(**inputs)
    return torch.sigmoid(outputs.logits)


# Ask user for input, return scores
st.title("Toxic Comment Classifier")
text_input = st.text_input("Enter text for toxicity classification", "I hope you die")
submit_btn = st.button("Submit")

# Drop down menu for model selection, default is roberta
model_base = st.selectbox("Select a pretrained model",
                          ["roberta-base", "bert-base-cased", "distilbert-base-cased"])

if submit_btn and text_input.strip():
    result = get_score(model_base, text_input)
    df = pd.DataFrame([result[0].tolist()], columns=classifiers)
    df = df.round(2)  # Round the values to 2 decimal places
    # Format the values as whole-number percentages.
    df = df.apply(lambda col: col.map('{:.0%}'.format))
    st.table(df)

# Score a few random comments from the test dataset.
test_df = _load_test_comments(TEST_CSV_PATH)
sample_df = test_df.sample(n=min(NUM_RANDOM_COMMENTS, len(test_df)))

rows = []
for comment in sample_df['comment_text']:
    scores = get_score(model_base, comment)[0].tolist()
    rows.append([comment] + scores)
# Build the table in one go so the score columns get a float dtype, which is
# what DataFrame.round() acts on (it leaves object-dtype columns untouched).
scores_df = pd.DataFrame(rows, columns=['Comment'] + classifiers)
scores_df = scores_df.round(2)  # Round the values to 2 decimal places

st.subheader("Toxicity Scores for Random Comments")
st.table(scores_df)

# Clicking the button makes Streamlit rerun this script, and the rerun draws a
# fresh random sample above; the button only needs to confirm that.
if st.button("Refresh Random Comments"):
    st.success("New comments have been loaded!")