CatoEr commited on
Commit
f91bc10
1 Parent(s): 70dd143

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -1
app.py CHANGED
@@ -34,6 +34,24 @@ race_labels = {
34
  3: "White"
35
  }
36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  orientation_labels = {
38
  0: "Heterosexual",
39
  1: "LGBTQ"
@@ -43,6 +61,18 @@ model_race = RaceClassifier(n_classes=4)
43
  model_race.to(device)
44
  model_race.load_state_dict(torch.load('best_model_race.pt', map_location=torch.device('cpu')))
45
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  model_orientation = RaceClassifier(n_classes=2)
47
  model_orientation.to(device)
48
  model_orientation.load_state_dict(torch.load('best_model_orientation_last.pt', map_location=torch.device('cpu')))
@@ -88,13 +118,21 @@ def predict(*text):
88
  attention_mask = encoded_sentences["attention_mask"].to(device)
89
 
90
  race_probs, race_predictions = evaluate(model_race, input_ids, attention_mask)
 
 
 
91
  orientation_probs, orientation_predictions = evaluate(model_orientation, input_ids, attention_mask)
92
 
93
  final_output = str()
94
  final_output += write_output(race_probs, race_predictions, "race", race_labels)
95
  final_output += "\n"
96
- final_output += write_output(orientation_probs, orientation_predictions, "sexual orientation", orientation_labels)
97
  final_output += "\n"
 
 
 
 
 
98
 
99
  return final_output
100
 
 
34
  3: "White"
35
  }
36
 
37
+ age_labels = {
38
+ 0: "Adult",
39
+ 1: "Elderly",
40
+ 2: "Young"
41
+ }
42
+
43
+ education_labels = {
44
+ 0: "Educated",
45
+ 1: "Uneducated"
46
+ }
47
+
48
+ gender_labels = {
49
+ 0: "Female",
50
+ 1: "Male",
51
+ 2: "Non-Binary",
52
+ 3: "Transgender"
53
+ }
54
+
55
  orientation_labels = {
56
  0: "Heterosexual",
57
  1: "LGBTQ"
 
61
  model_race.to(device)
62
  model_race.load_state_dict(torch.load('best_model_race.pt', map_location=torch.device('cpu')))
63
 
64
+ model_age = RaceClassifier(n_classes=3)
65
+ model_age.to(device)
66
+ model_age.load_state_dict(torch.load('best_model_age_last.pt', map_location=torch.device('cpu')))
67
+
68
+ model_education = RaceClassifier(n_classes=2)
69
+ model_education.to(device)
70
+ model_education.load_state_dict(torch.load('best_model_education_last.pt', map_location=torch.device('cpu')))
71
+
72
+ model_gender = RaceClassifier(n_classes=4)
73
+ model_gender.to(device)
74
+ model_gender.load_state_dict(torch.load('best_model_gender_last.pt', map_location=torch.device('cpu')))
75
+
76
  model_orientation = RaceClassifier(n_classes=2)
77
  model_orientation.to(device)
78
  model_orientation.load_state_dict(torch.load('best_model_orientation_last.pt', map_location=torch.device('cpu')))
 
118
  attention_mask = encoded_sentences["attention_mask"].to(device)
119
 
120
  race_probs, race_predictions = evaluate(model_race, input_ids, attention_mask)
121
+ age_probs, age_predictions = evaluate(model_age, input_ids, attention_mask)
122
+ education_probs, education_predictions = evaluate(model_education, input_ids, attention_mask)
123
+ gender_probs, gender_predictions = evaluate(model_gender, input_ids, attention_mask)
124
  orientation_probs, orientation_predictions = evaluate(model_orientation, input_ids, attention_mask)
125
 
126
  final_output = str()
127
  final_output += write_output(race_probs, race_predictions, "race", race_labels)
128
  final_output += "\n"
129
+ final_output += write_output(age_probs, age_predictions,"age",age_labels)
130
  final_output += "\n"
131
+ final_output += write_output(education_probs,education_predictions,"education", education_labels)
132
+ final_output += "\n"
133
+ final_output += write_output(gender_probs, gender_predictions, "gender", gender_labels)
134
+ final_output += "\n"
135
+ final_output += write_output(orientation_probs, orientation_predictions, "sexual orientation", orientation_labels)
136
 
137
  return final_output
138