CatoEr committed
Commit 7216ad1
1 Parent(s): 56bf313

Update app.py

Files changed (1)
  1. app.py +42 -16
app.py CHANGED
@@ -27,15 +27,47 @@ class RaceClassifier(nn.Module):
         return self.out(output)
 
 
-labels = {
+race_labels = {
     0: "African American",
     1: "Asian",
     2: "Latin",
     3: "White"
 }
+
+orientation_labels = {
+    0: "Heterosexual",
+    1: "LGBTQ"
+}
+
 model_race = RaceClassifier(n_classes=4)
 model_race.to(device)
-model_race.load_state_dict(torch.load('best_model_race.pt', map_location=torch.device('cpu')))
+model_race.load_state_dict(torch.load('best_model_race.pt'))
+
+model_orientation = RaceClassifier(n_classes=2)
+model_orientation.to(device)
+model_orientation.load_state_dict(torch.load('best_model_orientation_last.pt'))
+
+
+def evaluate(model, input, mask):
+    model.eval()
+    with torch.no_grad():
+        outputs = model(input, mask)
+        probs = torch.nn.functional.softmax(outputs, dim=1)
+        predictions = torch.argmax(outputs, dim=1)
+        predictions = predictions.cpu().numpy()
+    return probs, predictions
+
+
+def write_output(probs, predictions, title, labels):
+    output_string = f"{title.upper()}\n Probabilities:\n"
+    for i, prob in enumerate(probs[0]):
+        print(f"{labels[i]} = {round(prob.item() * 100, 2)}%")
+        output_string += f"{labels[i]} = {round(prob.item() * 100, 2)}%\n"
+
+    output_string += f"Predicted as: {labels[predictions[0]]}\n"
+
+    return output_string
+
 
 def predict(*text):
     tweets = [tweet for tweet in text if tweet]
@@ -55,22 +87,16 @@ def predict(*text):
     input_ids = encoded_sentences["input_ids"].to(device)
     attention_mask = encoded_sentences["attention_mask"].to(device)
 
-    model_race.eval()
-    with torch.no_grad():
-        outputs = model_race(input_ids, attention_mask)
-        probs = torch.nn.functional.softmax(outputs, dim=1)
-        predictions = torch.argmax(outputs, dim=1)
-        predictions = predictions.cpu().numpy()
-
-    output_string = "RACE\n Probabilities:\n"
-    for i, prob in enumerate(probs[0]):
-        print(f"{labels[i]} = {round(prob.item() * 100, 2)}%")
-        output_string += f"{labels[i]} = {round(prob.item() * 100, 2)}%\n"
+    race_probs, race_predictions = evaluate(model_race, input_ids, attention_mask)
+    orientation_probs, orientation_predictions = evaluate(model_orientation, input_ids, attention_mask)
 
-    print(labels[predictions[0]])
-    output_string += f"Predicted as: {labels[predictions[0]]}"
+    final_output = str()
+    final_output += write_output(race_probs, race_predictions, "race", race_labels)
+    final_output += "\n"
+    final_output += write_output(orientation_probs, orientation_predictions, "sexual orientation", orientation_labels)
+    final_output += "\n"
 
-    return output_string
+    return final_output
 
 
 max_textboxes = 20
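
The hunk header references class RaceClassifier(nn.Module), whose body lies outside this diff; the same class is now instantiated with n_classes=2 for the orientation model. For context, here is a minimal sketch consistent with what the diff does show (an n_classes constructor argument, a forward(input_ids, attention_mask) call, and raw logits returned from self.out so that evaluate() can apply the softmax). The encoder checkpoint, dropout rate, and pooling choice are assumptions, not taken from the repository.

import torch.nn as nn
from transformers import AutoModel

class RaceClassifier(nn.Module):
    def __init__(self, n_classes):
        super().__init__()
        # Encoder name is an assumption; any BERT-style model with a pooled output fits this sketch.
        self.bert = AutoModel.from_pretrained("bert-base-uncased")
        self.drop = nn.Dropout(p=0.3)
        # Head sized by n_classes, so one class serves both the 4-way race model
        # and the 2-way orientation model loaded above.
        self.out = nn.Linear(self.bert.config.hidden_size, n_classes)

    def forward(self, input_ids, attention_mask):
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        output = self.drop(encoded.pooler_output)
        return self.out(output)  # raw logits; evaluate() applies the softmax

Reusing the class this way implies best_model_orientation_last.pt was trained with the same architecture, differing only in the size of the out layer.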