CatoEr commited on
Commit
bea74aa
1 Parent(s): 202e06e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +80 -6
app.py CHANGED
@@ -1,13 +1,87 @@
 
 
 
1
  import gradio as gr
2
 
 
 
 
 
 
3
 
4
- def greet(name):
5
- result = "Label Probabilities:\n" + f"African American: {str(round(0.797795832157135,2)*100)}\n"
6
- + f"Asian: {str(round(0.17413224279880524,2)*100)}\n"+ f"Latin: {str(round(0.0132269160822033,2)*100)}\n"+ f"White: {str(round(0.14844958670437336,2)*100)}"
7
-
8
- return result
9
 
 
10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
13
  demo.launch()
 
 
1
+ import torch
2
+ from torch import nn
3
+ from transformers import AutoModel, AutoTokenizer
4
  import gradio as gr
5
 
6
+ # Check if CUDA is available
7
+ if torch.cuda.is_available():
8
+ device = torch.device("cuda")
9
+ else:
10
+ device = torch.device("cpu")
11
 
 
 
 
 
 
12
 
13
+ class RaceClassifier(nn.Module):
14
 
15
+ def __init__(self, n_classes):
16
+ super(RaceClassifier, self).__init__()
17
+ self.bert = AutoModel.from_pretrained("vinai/bertweet-base")
18
+ self.drop = nn.Dropout(p=0.3) # can be changed in future
19
+ self.out = nn.Linear(self.bert.config.hidden_size,
20
+ n_classes) # linear layer for the output with the number of classes
21
+
22
+ def forward(self, input_ids, attention_mask):
23
+ bert_output = self.bert(
24
+ input_ids=input_ids,
25
+ attention_mask=attention_mask
26
+ )
27
+ last_hidden_state = bert_output[0]
28
+ pooled_output = last_hidden_state[:, 0]
29
+ output = self.drop(pooled_output)
30
+ return self.out(output)
31
+
32
+
33
+ labels = {
34
+ 0: "African American",
35
+ 1: "Asian",
36
+ 2: "Latin",
37
+ 3: "White"
38
+ }
39
+ model_race = RaceClassifier(n_classes=4)
40
+ model_race.to(device)
41
+ model_race.load_state_dict(torch.load('best_model_race.pt'))
42
+
43
+
44
+ def predict(text):
45
+ sentences = [
46
+ text
47
+ ]
48
+
49
+ tokenizer = AutoTokenizer.from_pretrained("vinai/bertweet-base", normalization=True)
50
+
51
+ encoded_sentences = tokenizer(
52
+ sentences,
53
+ padding=True,
54
+ truncation=True,
55
+ return_tensors='pt',
56
+ max_length=128,
57
+ )
58
+
59
+ input_ids = encoded_sentences["input_ids"].to(device)
60
+ attention_mask = encoded_sentences["attention_mask"].to(device)
61
+
62
+ model_race.eval()
63
+ with torch.no_grad():
64
+ outputs = model_race(input_ids, attention_mask)
65
+ probs = torch.nn.functional.softmax(outputs, dim=1)
66
+ predictions = torch.argmax(outputs, dim=1)
67
+ predictions = predictions.cpu().numpy()
68
+
69
+ output_string = ""
70
+ for i, prob in enumerate(probs[0]):
71
+ print(f"{labels[i]}: %{round(prob.item() * 100, 2)}")
72
+ output_string += f"{labels[i]}: %{round(prob.item() * 100, 2)}\n"
73
+
74
+ print(labels[predictions[0]])
75
+ output_string += f"Predicted as: {labels[predictions[0]]}"
76
+
77
+ return output_string
78
+
79
+
80
+ demo = gr.Interface(
81
+ fn=predict,
82
+ inputs=["text"],
83
+ outputs=["text"],
84
+ )
85
 
 
86
  demo.launch()
87
+