import torch
from torch import nn
from transformers import AutoModel, AutoTokenizer
import gradio as gr
|

# Run on the GPU when available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|


class RaceClassifier(nn.Module):
    """BERTweet encoder with a dropout-regularized linear classification head."""

    def __init__(self, n_classes):
        super().__init__()
        self.bert = AutoModel.from_pretrained("vinai/bertweet-base")
        self.drop = nn.Dropout(p=0.3)
        self.out = nn.Linear(self.bert.config.hidden_size, n_classes)

    def forward(self, input_ids, attention_mask):
        bert_output = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        # Pool the hidden state of the first (<s>/CLS) token as the
        # sentence representation, then classify it.
        last_hidden_state = bert_output[0]
        pooled_output = last_hidden_state[:, 0]
        output = self.drop(pooled_output)
        return self.out(output)
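
# For reference, the tensor shapes through RaceClassifier.forward
# (B = batch size, T = sequence length; hidden_size is 768 for bertweet-base):
#   input_ids, attention_mask: (B, T)
#   last_hidden_state:         (B, T, hidden_size)
#   pooled_output:             (B, hidden_size)
#   returned logits:           (B, n_classes)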
|


# Mapping from model output indices to class names.
labels = {
    0: "African American",
    1: "Asian",
    2: "Latin",
    3: "White",
}

model_race = RaceClassifier(n_classes=4)
model_race.to(device)
# map_location lets a checkpoint saved on GPU also load on a CPU-only machine.
model_race.load_state_dict(torch.load("best_model_race.pt", map_location=device))
|


# Load the tokenizer once at startup instead of on every request.
tokenizer = AutoTokenizer.from_pretrained("vinai/bertweet-base", normalization=True)


def predict(text):
    sentences = [text]

    encoded_sentences = tokenizer(
        sentences,
        padding=True,
        truncation=True,
        return_tensors="pt",
        max_length=128,
    )
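
    # encoded_sentences is a BatchEncoding holding "input_ids" and
    # "attention_mask" tensors of shape (batch_size, sequence_length).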
|
    input_ids = encoded_sentences["input_ids"].to(device)
    attention_mask = encoded_sentences["attention_mask"].to(device)

    # Run inference without tracking gradients.
    model_race.eval()
    with torch.no_grad():
        outputs = model_race(input_ids, attention_mask)
        probs = torch.nn.functional.softmax(outputs, dim=1)
        predictions = torch.argmax(outputs, dim=1)
        predictions = predictions.cpu().numpy()
|

    # Report each class probability, then the predicted label.
    output_string = ""
    for i, prob in enumerate(probs[0]):
        output_string += f"{labels[i]}: {round(prob.item() * 100, 2)}%\n"

    output_string += f"Predicted as: {labels[predictions[0]]}"

    return output_string
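

# Quick local smoke test (hypothetical example input; assumes the checkpoint
# file is present). Uncomment to exercise predict() without the web UI:
# print(predict("just setting up my twttr"))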
|


demo = gr.Interface(
    fn=predict,
    inputs=["text"],
    outputs=["text"],
)
|

demo.launch()
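# Note: launch() serves on localhost by default; demo.launch(share=True)
# would additionally create a temporary public Gradio link.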