Spaces:
Runtime error
Runtime error
Commit
·
1c2e20b
1
Parent(s):
1708735
Rem deprecated gradio attributes; add esp to avoid /0 error; return examples.
Browse files
app.py
CHANGED
@@ -34,6 +34,7 @@ NON_GENDERED_TOKEN_ID = 30
|
|
34 |
LABEL_DICT = {GENDER_OPTIONS[0]: 9, GENDER_OPTIONS[1]: -9}
|
35 |
CLASSES = list(LABEL_DICT.keys())
|
36 |
NON_LOSS_TOKEN_ID = -100
|
|
|
37 |
|
38 |
# Wikibio conts
|
39 |
START_YEAR = 1800
|
@@ -289,7 +290,7 @@ def get_tokenized_text_with_metadata(input_text, indie_vars, dataset, male_gende
|
|
289 |
def get_avg_prob_from_finetuned_outputs(outputs, is_masked, num_preds, gender):
|
290 |
preds = torch.softmax(outputs[0][0].cpu(), dim=1, dtype=torch.double)
|
291 |
pronoun_preds = torch.where(is_masked, preds[:,CLASSES.index(gender)], 0.0)
|
292 |
-
return round(torch.sum(pronoun_preds).item() / num_preds * 100, DECIMAL_PLACES)
|
293 |
|
294 |
|
295 |
def get_avg_prob_from_pipeline_outputs(mask_filled_text, gendered_token_ids, num_preds):
|
@@ -298,7 +299,7 @@ def get_avg_prob_from_pipeline_outputs(mask_filled_text, gendered_token_ids, num
|
|
298 |
for pronoun in top_preds])
|
299 |
for top_preds in mask_filled_text
|
300 |
]
|
301 |
-
return round(sum(pronoun_preds) / num_preds * 100, DECIMAL_PLACES)
|
302 |
|
303 |
def get_figure(results, dataset, gender, indie_var_name, include_baseline=True):
|
304 |
colors = ['b', 'g', 'c', 'm', 'y', 'r', 'k'] # assert no
|
@@ -521,7 +522,13 @@ scientist_example = [
|
|
521 |
"True",
|
522 |
'She was a very well regarded scientist and her work won many awards.',
|
523 |
]
|
524 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
525 |
death_date_example = [
|
526 |
WIKIBIO,
|
527 |
BERT_LIKE_MODELS,
|
@@ -529,8 +536,6 @@ death_date_example = [
|
|
529 |
"True",
|
530 |
'Died in DATE, she was recognized for her great accomplishments to the field of teaching.'
|
531 |
]
|
532 |
-
|
533 |
-
|
534 |
neg_reddit_example = [
|
535 |
REDDIT,
|
536 |
[BERT_LIKE_MODELS[0]],
|
@@ -539,39 +544,32 @@ neg_reddit_example = [
|
|
539 |
'She is not good at anything. The work she does is always subpar.'
|
540 |
]
|
541 |
|
542 |
-
|
543 |
gr.Interface(
|
544 |
fn=predict_gender_pronouns,
|
545 |
inputs=[
|
546 |
-
gr.
|
547 |
[REDDIT, WIKIBIO],
|
548 |
-
default=WIKIBIO,
|
549 |
type="value",
|
550 |
label="Pick 'conditionally' fine-tuned model.",
|
551 |
-
optional=False,
|
552 |
),
|
553 |
-
gr.
|
554 |
BERT_LIKE_MODELS,
|
555 |
-
default=[BERT_LIKE_MODELS[0]],
|
556 |
type="value",
|
557 |
label="Pick optional BERT base uncased model.",
|
558 |
),
|
559 |
-
gr.
|
560 |
["False", "True"],
|
561 |
label="Normalize BERT-like model's predictions to gendered-only?",
|
562 |
-
default = "True",
|
563 |
type="index",
|
564 |
),
|
565 |
-
gr.
|
566 |
["False", "True"],
|
567 |
label="Include baseline predictions (dashed-lines)?",
|
568 |
-
default = "True",
|
569 |
type="index",
|
570 |
),
|
571 |
-
gr.
|
572 |
lines=5,
|
573 |
label="Input Text: Sentence about a single person using some gendered pronouns to refer to them.",
|
574 |
-
default="She always walked past the building built in DATE on her way to her job as an elementary school teacher.",
|
575 |
),
|
576 |
],
|
577 |
outputs=[
|
@@ -593,5 +591,5 @@ gr.Interface(
|
|
593 |
title=title,
|
594 |
description=description,
|
595 |
article=article,
|
596 |
-
|
597 |
-
).launch()
|
|
|
34 |
LABEL_DICT = {GENDER_OPTIONS[0]: 9, GENDER_OPTIONS[1]: -9}
|
35 |
CLASSES = list(LABEL_DICT.keys())
|
36 |
NON_LOSS_TOKEN_ID = -100
|
37 |
+
EPS = 1e-5 # to avoid /0 errors
|
38 |
|
39 |
# Wikibio conts
|
40 |
START_YEAR = 1800
|
|
|
290 |
def get_avg_prob_from_finetuned_outputs(outputs, is_masked, num_preds, gender):
|
291 |
preds = torch.softmax(outputs[0][0].cpu(), dim=1, dtype=torch.double)
|
292 |
pronoun_preds = torch.where(is_masked, preds[:,CLASSES.index(gender)], 0.0)
|
293 |
+
return round(torch.sum(pronoun_preds).item() / (EPS + num_preds) * 100, DECIMAL_PLACES)
|
294 |
|
295 |
|
296 |
def get_avg_prob_from_pipeline_outputs(mask_filled_text, gendered_token_ids, num_preds):
|
|
|
299 |
for pronoun in top_preds])
|
300 |
for top_preds in mask_filled_text
|
301 |
]
|
302 |
+
return round(sum(pronoun_preds) / (EPS + num_preds) * 100, DECIMAL_PLACES)
|
303 |
|
304 |
def get_figure(results, dataset, gender, indie_var_name, include_baseline=True):
|
305 |
colors = ['b', 'g', 'c', 'm', 'y', 'r', 'k'] # assert no
|
|
|
522 |
"True",
|
523 |
'She was a very well regarded scientist and her work won many awards.',
|
524 |
]
|
525 |
+
building_example = [
|
526 |
+
WIKIBIO,
|
527 |
+
[BERT_LIKE_MODELS[0]],
|
528 |
+
"True",
|
529 |
+
"True",
|
530 |
+
"She always walked past the building built in DATE on her way to her job as an elementary school teacher.",
|
531 |
+
]
|
532 |
death_date_example = [
|
533 |
WIKIBIO,
|
534 |
BERT_LIKE_MODELS,
|
|
|
536 |
"True",
|
537 |
'Died in DATE, she was recognized for her great accomplishments to the field of teaching.'
|
538 |
]
|
|
|
|
|
539 |
neg_reddit_example = [
|
540 |
REDDIT,
|
541 |
[BERT_LIKE_MODELS[0]],
|
|
|
544 |
'She is not good at anything. The work she does is always subpar.'
|
545 |
]
|
546 |
|
|
|
547 |
gr.Interface(
|
548 |
fn=predict_gender_pronouns,
|
549 |
inputs=[
|
550 |
+
gr.Radio(
|
551 |
[REDDIT, WIKIBIO],
|
|
|
552 |
type="value",
|
553 |
label="Pick 'conditionally' fine-tuned model.",
|
|
|
554 |
),
|
555 |
+
gr.CheckboxGroup(
|
556 |
BERT_LIKE_MODELS,
|
|
|
557 |
type="value",
|
558 |
label="Pick optional BERT base uncased model.",
|
559 |
),
|
560 |
+
gr.Dropdown(
|
561 |
["False", "True"],
|
562 |
label="Normalize BERT-like model's predictions to gendered-only?",
|
|
|
563 |
type="index",
|
564 |
),
|
565 |
+
gr.Dropdown(
|
566 |
["False", "True"],
|
567 |
label="Include baseline predictions (dashed-lines)?",
|
|
|
568 |
type="index",
|
569 |
),
|
570 |
+
gr.Textbox(
|
571 |
lines=5,
|
572 |
label="Input Text: Sentence about a single person using some gendered pronouns to refer to them.",
|
|
|
573 |
),
|
574 |
],
|
575 |
outputs=[
|
|
|
591 |
title=title,
|
592 |
description=description,
|
593 |
article=article,
|
594 |
+
examples=[scientist_example, building_example, death_date_example, neg_reddit_example]
|
595 |
+
).launch(debug=True)
|