# app.py from CatoEr's Hugging Face Space
# Commit 6eeb6f8 ("Update app.py"), 5.57 kB
import torch
from torch import nn
from transformers import AutoModel, AutoTokenizer
import gradio as gr
# Single device on which every model and every input tensor is placed.
_DEVICE_NAME = "cpu"
device = torch.device(_DEVICE_NAME)
class RaceClassifier(nn.Module):
    """Pretrained transformer encoder followed by a dropout + linear classification head.

    Despite its name, this same architecture is instantiated for every
    attribute in this app (race, age, education, gender, orientation); only
    ``n_classes`` and the checkpoint loaded into it differ.

    Args:
        n_classes: Number of output logits (size of the label set).
        pretrained_name: Model id or local path passed to
            ``AutoModel.from_pretrained``. Defaults to the original
            ``"vinai/bertweet-base"``.
        dropout: Dropout probability applied to the pooled representation
            before the linear head. Defaults to the original ``0.3``.

    The submodule attribute names (``bert``, ``drop``, ``out``) determine the
    state-dict keys, so they must not be renamed or saved checkpoints will no
    longer load.
    """

    def __init__(self, n_classes, pretrained_name="vinai/bertweet-base", dropout=0.3):
        super().__init__()
        self.bert = AutoModel.from_pretrained(pretrained_name)
        self.drop = nn.Dropout(p=dropout)
        # Linear head mapping the encoder's hidden size to one logit per class.
        self.out = nn.Linear(self.bert.config.hidden_size, n_classes)

    def forward(self, input_ids, attention_mask):
        """Return raw (unnormalized) class logits.

        Args:
            input_ids: Token ids from the matching tokenizer.
            attention_mask: Mask marking real (non-padding) tokens.

        Returns:
            Logits from the linear head; presumably shaped (batch, n_classes),
            since callers apply softmax/argmax over dim=1.
        """
        bert_output = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        last_hidden_state = bert_output[0]
        # The hidden state at the first token position serves as the sequence summary.
        pooled_output = last_hidden_state[:, 0]
        return self.out(self.drop(pooled_output))
# Index -> human-readable label maps, one per classifier head. Index i names
# the i-th output logit of the corresponding model, so the order must line up
# with how that model was trained (not verifiable from this file).
race_labels = dict(enumerate(["African American", "Asian", "Latin", "White"]))
age_labels = dict(enumerate(["Adult", "Elderly", "Young"]))
education_labels = dict(enumerate(["Educated", "Uneducated"]))
gender_labels = dict(enumerate(["Female", "Male", "Non-Binary", "Transgender"]))
orientation_labels = dict(enumerate(["Heterosexual", "LGBTQ"]))
def _load_classifier(labels, checkpoint_path):
    """Build a classifier sized for ``labels``, place it on ``device``, and load weights.

    Args:
        labels: Index -> label map for this head; its length sets ``n_classes``
            so the output layer always matches the labels used for display.
        checkpoint_path: Path to a saved ``state_dict`` for this head.

    Returns:
        The loaded ``RaceClassifier`` instance.
    """
    model = RaceClassifier(n_classes=len(labels))
    model.to(device)
    # NOTE(review): torch.load unpickles the file. These checkpoints ship with
    # the app, but never point this at untrusted files (consider
    # weights_only=True on torch >= 1.13).
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    return model


model_race = _load_classifier(race_labels, "best_model_race.pt")
model_age = _load_classifier(age_labels, "best_model_age_last.pt")
model_education = _load_classifier(education_labels, "best_model_education_last.pt")
model_gender = _load_classifier(gender_labels, "best_model_gender_last.pt")
model_orientation = _load_classifier(orientation_labels, "best_model_orientation_last.pt")
def evaluate(model, input, mask):
    """Run ``model`` in inference mode and return per-row probabilities and argmax indices.

    Args:
        model: A ``torch.nn.Module`` called as ``model(input, mask)``.
        input: Token-id tensor passed as the first argument.
        mask: Attention-mask tensor passed as the second argument.

    Returns:
        ``(probs, predictions)``: ``probs`` is the softmax over dim=1 of the
        model output (a tensor), and ``predictions`` is a NumPy array holding
        the dim=1 argmax of each row.
    """
    model.eval()  # disables dropout for deterministic inference
    with torch.no_grad():
        logits = model(input, mask)
        class_probs = logits.softmax(dim=1)
        best_idx = logits.argmax(dim=1).cpu().numpy()
    return class_probs, best_idx
def write_output(probs, predictions, title, labels):
    """Format the first row of ``probs`` and ``predictions`` as a report section.

    Each per-class probability line is also printed to stdout.

    Args:
        probs: 2-D probabilities; only ``probs[0]`` is formatted, and each
            entry must support ``.item()``.
        predictions: Sequence of predicted indices; only ``predictions[0]`` is used.
        title: Section heading (rendered upper-case).
        labels: Index -> label map.

    Returns:
        The section text, ending in a newline.
    """
    pieces = [f"{title.upper()}\n Probabilities:\n"]
    for class_idx, class_prob in enumerate(probs[0]):
        line = f"{labels[class_idx]} = {round(class_prob.item() * 100, 2)}%"
        print(line)
        pieces.append(line + "\n")
    pieces.append(f"Predicted as: {labels[predictions[0]]}\n")
    return "".join(pieces)
# Loaded once at import time rather than on every button click.
_TOKENIZER = AutoTokenizer.from_pretrained("vinai/bertweet-base", normalization=True)


def _aggregate(probs):
    """Average per-tweet probabilities into one (1, C) row and return it with its argmax."""
    mean_probs = probs.mean(dim=0, keepdim=True)
    predictions = torch.argmax(mean_probs, dim=1).cpu().numpy()
    return mean_probs, predictions


def predict(*text):
    """Profile a user from one or more tweets and return a plain-text report.

    Args:
        *text: Contents of every tweet textbox, in UI order. Empty, ``None``
            and whitespace-only entries are ignored.

    Returns:
        Sections for race, age, education, gender and sexual orientation,
        separated by blank lines. Each section shows the class probabilities
        averaged over all submitted tweets and the label with the highest
        average. These are model estimates, not facts about the author.
        If no tweet was entered, a short prompt is returned instead.
    """
    tweets = [tweet for tweet in text if tweet and tweet.strip()]
    if not tweets:
        return "Please enter at least one tweet."

    encoded_sentences = _TOKENIZER(
        tweets,
        padding=True,
        truncation=True,
        return_tensors="pt",
        max_length=128,
    )
    input_ids = encoded_sentences["input_ids"].to(device)
    attention_mask = encoded_sentences["attention_mask"].to(device)

    tasks = (
        (model_race, "race", race_labels),
        (model_age, "age", age_labels),
        (model_education, "education", education_labels),
        (model_gender, "gender", gender_labels),
        (model_orientation, "sexual orientation", orientation_labels),
    )
    sections = []
    for model, title, labels in tasks:
        probs, _ = evaluate(model, input_ids, attention_mask)
        # Use every tweet, not just the first row of the batch.
        mean_probs, predictions = _aggregate(probs)
        sections.append(write_output(mean_probs, predictions, title, labels))
    return "\n".join(sections)
# Number of tweet textboxes the UI creates; also the slider's maximum.
max_textboxes = 20


def update_textboxes(k):
    """Return one ``gr.update`` per textbox, showing the first ``k`` and hiding the rest.

    Hidden boxes are also cleared. Every textbox is wired as an input to
    ``predict``, so text left in a hidden box would otherwise still be submitted.

    Args:
        k: Number of boxes to show (slider value); ``None`` is treated as 0.

    Returns:
        A list of ``max_textboxes`` update objects, in textbox order.
    """
    if k is None:
        k = 0
    return [
        gr.update(visible=True) if i < k else gr.update(visible=False, value="")
        for i in range(max_textboxes)
    ]
def clear_textboxes():
    """Return one update per tweet textbox that blanks its contents."""
    blank = {"value": ''}
    return [gr.update(**blank) for _box in range(max_textboxes)]


def clear_output_box():
    """Return an update that blanks the report textbox."""
    empty_text = ''
    return gr.update(value=empty_text)
# Gradio UI: tweet inputs and buttons in the left column, the report on the right.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column(scale=1):
            # The slider value controls how many of the textboxes are visible.
            s = gr.Slider(1, max_textboxes, value=1, step=1, label="How many tweets do you want to enter:")
            # All textboxes are created up front; only the first one starts visible.
            textboxes = [gr.Textbox(label=f"Tweet {i + 1}", visible=(i == 0)) for i in range(max_textboxes)]
            s.change(fn=update_textboxes, inputs=s, outputs=textboxes)
            btn = gr.Button("Predict")
            btn_clear = gr.Button("Clear")
        with gr.Column(scale=1):
            output = gr.Textbox(label="Profile of User")
    # NOTE(review): every textbox, visible or hidden, is passed to predict as an input.
    btn.click(fn=predict, inputs=textboxes, outputs=output)
    # Clear has two handlers: one blanks the tweet boxes, the other blanks the report.
    btn_clear.click(fn=clear_textboxes, outputs=textboxes)
    btn_clear.click(fn=clear_output_box, outputs=output)
# NOTE(review): launches at import time; presumably this file is always run as a
# script. Consider an `if __name__ == "__main__":` guard if it is ever imported.
demo.launch()