CatoEr's picture
Update app.py
6e257b4 verified
raw
history blame
No virus
2.22 kB
import functools

import gradio as gr
import torch
from torch import nn
from transformers import AutoModel, AutoTokenizer
# Inference is pinned to the CPU; the model and every input tensor are moved
# here explicitly with .to(device).
device = torch.device("cpu")
class RaceClassifier(nn.Module):
    """Pretrained transformer encoder with a dropout + linear classification head.

    The final hidden state of the first token is used as the sequence
    representation and projected to ``n_classes`` logits.

    Args:
        n_classes: Number of output classes (size of the linear head).
        dropout: Dropout probability applied before the head. Defaults to
            0.3, the value previously hard-coded here.
        pretrained_name: Hugging Face model id or local path passed to
            ``AutoModel.from_pretrained``. Defaults to "vinai/bertweet-base".
    """

    def __init__(self, n_classes, dropout=0.3, pretrained_name="vinai/bertweet-base"):
        super().__init__()
        self.bert = AutoModel.from_pretrained(pretrained_name)
        self.drop = nn.Dropout(p=dropout)
        # Linear head: encoder hidden size -> one logit per class.
        self.out = nn.Linear(self.bert.config.hidden_size, n_classes)

    def forward(self, input_ids, attention_mask):
        """Return raw (unnormalised) class logits.

        Args:
            input_ids: Token-id tensor, assumed (batch, seq_len) — TODO confirm;
                the ``[:, 0]`` indexing below needs a batch and a token axis.
            attention_mask: Mask tensor matching ``input_ids``.

        Returns:
            Logits from the linear head, presumably (batch, n_classes).
        """
        bert_output = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # Element 0 of the encoder output is the last hidden state; the first
        # token's vector is used as the pooled sequence representation.
        last_hidden_state = bert_output[0]
        pooled_output = last_hidden_state[:, 0]
        return self.out(self.drop(pooled_output))
# Classifier output index -> human-readable class name. The order is assumed
# to match the label encoding used when best_model_race.pt was trained — verify.
labels = {
    0: "African American",
    1: "Asian",
    2: "Latin",
    3: "White",
}

# Size the head from the label map so the two cannot drift apart (4 classes).
model_race = RaceClassifier(n_classes=len(labels))
model_race.to(device)
# map_location remaps tensors that were saved on another device (e.g. a GPU)
# onto `device`; without it a CUDA-saved checkpoint fails to load on a
# CPU-only host.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
model_race.load_state_dict(torch.load('best_model_race.pt', map_location=device))
@functools.lru_cache(maxsize=1)
def _get_tokenizer():
    """Load the BERTweet tokenizer once and reuse it for every request."""
    # normalization=True is kept from the original per-request load; presumably
    # it matches the preprocessing used at training time — verify.
    return AutoTokenizer.from_pretrained("vinai/bertweet-base", normalization=True)


def predict(text):
    """Classify ``text`` and return a human-readable probability report.

    Each class line and the predicted label are also printed to stdout.

    Args:
        text: Raw input string (from the Gradio text box).

    Returns:
        One line per class in ``labels`` order, formatted as
        ``"<label>: %<percent>"`` with the percentage rounded to 2 decimals,
        followed by ``"Predicted as: <label>"`` for the highest-scoring class.
    """
    tokenizer = _get_tokenizer()
    encoded = tokenizer(
        [text],
        padding=True,
        truncation=True,
        return_tensors="pt",
        max_length=128,
    )
    input_ids = encoded["input_ids"].to(device)
    attention_mask = encoded["attention_mask"].to(device)

    # Make sure dropout is disabled even if the model was left in train mode.
    model_race.eval()
    with torch.no_grad():
        logits = model_race(input_ids, attention_mask)
    # Only one sentence is encoded, so row 0 holds the result.
    probs = torch.nn.functional.softmax(logits, dim=1)[0]
    predicted = torch.argmax(logits, dim=1)[0].item()

    report_lines = []
    for i, prob in enumerate(probs):
        line = f"{labels[i]}: %{round(prob.item() * 100, 2)}"
        print(line)
        report_lines.append(line)
    print(labels[predicted])
    report_lines.append(f"Predicted as: {labels[predicted]}")
    return "\n".join(report_lines)
# Minimal web UI around predict(): one text box in, one text box out.
# launch() is called at module level, so executing this file starts the app.
demo = gr.Interface(fn=predict, inputs=["text"], outputs=["text"])
demo.launch()