emilylearning commited on
Commit
1c2e20b
·
1 Parent(s): 1708735

Rem deprecated gradio attributes; add esp to avoid /0 error; return examples.

Browse files
Files changed (1) hide show
  1. app.py +17 -19
app.py CHANGED
@@ -34,6 +34,7 @@ NON_GENDERED_TOKEN_ID = 30
34
  LABEL_DICT = {GENDER_OPTIONS[0]: 9, GENDER_OPTIONS[1]: -9}
35
  CLASSES = list(LABEL_DICT.keys())
36
  NON_LOSS_TOKEN_ID = -100
 
37
 
38
  # Wikibio conts
39
  START_YEAR = 1800
@@ -289,7 +290,7 @@ def get_tokenized_text_with_metadata(input_text, indie_vars, dataset, male_gende
289
  def get_avg_prob_from_finetuned_outputs(outputs, is_masked, num_preds, gender):
290
  preds = torch.softmax(outputs[0][0].cpu(), dim=1, dtype=torch.double)
291
  pronoun_preds = torch.where(is_masked, preds[:,CLASSES.index(gender)], 0.0)
292
- return round(torch.sum(pronoun_preds).item() / num_preds * 100, DECIMAL_PLACES)
293
 
294
 
295
  def get_avg_prob_from_pipeline_outputs(mask_filled_text, gendered_token_ids, num_preds):
@@ -298,7 +299,7 @@ def get_avg_prob_from_pipeline_outputs(mask_filled_text, gendered_token_ids, num
298
  for pronoun in top_preds])
299
  for top_preds in mask_filled_text
300
  ]
301
- return round(sum(pronoun_preds) / num_preds * 100, DECIMAL_PLACES)
302
 
303
  def get_figure(results, dataset, gender, indie_var_name, include_baseline=True):
304
  colors = ['b', 'g', 'c', 'm', 'y', 'r', 'k'] # assert no
@@ -521,7 +522,13 @@ scientist_example = [
521
  "True",
522
  'She was a very well regarded scientist and her work won many awards.',
523
  ]
524
-
 
 
 
 
 
 
525
  death_date_example = [
526
  WIKIBIO,
527
  BERT_LIKE_MODELS,
@@ -529,8 +536,6 @@ death_date_example = [
529
  "True",
530
  'Died in DATE, she was recognized for her great accomplishments to the field of teaching.'
531
  ]
532
-
533
-
534
  neg_reddit_example = [
535
  REDDIT,
536
  [BERT_LIKE_MODELS[0]],
@@ -539,39 +544,32 @@ neg_reddit_example = [
539
  'She is not good at anything. The work she does is always subpar.'
540
  ]
541
 
542
-
543
  gr.Interface(
544
  fn=predict_gender_pronouns,
545
  inputs=[
546
- gr.inputs.Radio(
547
  [REDDIT, WIKIBIO],
548
- default=WIKIBIO,
549
  type="value",
550
  label="Pick 'conditionally' fine-tuned model.",
551
- optional=False,
552
  ),
553
- gr.inputs.CheckboxGroup(
554
  BERT_LIKE_MODELS,
555
- default=[BERT_LIKE_MODELS[0]],
556
  type="value",
557
  label="Pick optional BERT base uncased model.",
558
  ),
559
- gr.inputs.Dropdown(
560
  ["False", "True"],
561
  label="Normalize BERT-like model's predictions to gendered-only?",
562
- default = "True",
563
  type="index",
564
  ),
565
- gr.inputs.Dropdown(
566
  ["False", "True"],
567
  label="Include baseline predictions (dashed-lines)?",
568
- default = "True",
569
  type="index",
570
  ),
571
- gr.inputs.Textbox(
572
  lines=5,
573
  label="Input Text: Sentence about a single person using some gendered pronouns to refer to them.",
574
- default="She always walked past the building built in DATE on her way to her job as an elementary school teacher.",
575
  ),
576
  ],
577
  outputs=[
@@ -593,5 +591,5 @@ gr.Interface(
593
  title=title,
594
  description=description,
595
  article=article,
596
- # examples=[scientist_example, death_date_example, neg_reddit_example] # Rem for debug
597
- ).launch()
 
34
  LABEL_DICT = {GENDER_OPTIONS[0]: 9, GENDER_OPTIONS[1]: -9}
35
  CLASSES = list(LABEL_DICT.keys())
36
  NON_LOSS_TOKEN_ID = -100
37
+ EPS = 1e-5 # to avoid /0 errors
38
 
39
  # Wikibio conts
40
  START_YEAR = 1800
 
290
  def get_avg_prob_from_finetuned_outputs(outputs, is_masked, num_preds, gender):
291
  preds = torch.softmax(outputs[0][0].cpu(), dim=1, dtype=torch.double)
292
  pronoun_preds = torch.where(is_masked, preds[:,CLASSES.index(gender)], 0.0)
293
+ return round(torch.sum(pronoun_preds).item() / (EPS + num_preds) * 100, DECIMAL_PLACES)
294
 
295
 
296
  def get_avg_prob_from_pipeline_outputs(mask_filled_text, gendered_token_ids, num_preds):
 
299
  for pronoun in top_preds])
300
  for top_preds in mask_filled_text
301
  ]
302
+ return round(sum(pronoun_preds) / (EPS + num_preds) * 100, DECIMAL_PLACES)
303
 
304
  def get_figure(results, dataset, gender, indie_var_name, include_baseline=True):
305
  colors = ['b', 'g', 'c', 'm', 'y', 'r', 'k'] # assert no
 
522
  "True",
523
  'She was a very well regarded scientist and her work won many awards.',
524
  ]
525
+ building_example = [
526
+ WIKIBIO,
527
+ [BERT_LIKE_MODELS[0]],
528
+ "True",
529
+ "True",
530
+ "She always walked past the building built in DATE on her way to her job as an elementary school teacher.",
531
+ ]
532
  death_date_example = [
533
  WIKIBIO,
534
  BERT_LIKE_MODELS,
 
536
  "True",
537
  'Died in DATE, she was recognized for her great accomplishments to the field of teaching.'
538
  ]
 
 
539
  neg_reddit_example = [
540
  REDDIT,
541
  [BERT_LIKE_MODELS[0]],
 
544
  'She is not good at anything. The work she does is always subpar.'
545
  ]
546
 
 
547
  gr.Interface(
548
  fn=predict_gender_pronouns,
549
  inputs=[
550
+ gr.Radio(
551
  [REDDIT, WIKIBIO],
 
552
  type="value",
553
  label="Pick 'conditionally' fine-tuned model.",
 
554
  ),
555
+ gr.CheckboxGroup(
556
  BERT_LIKE_MODELS,
 
557
  type="value",
558
  label="Pick optional BERT base uncased model.",
559
  ),
560
+ gr.Dropdown(
561
  ["False", "True"],
562
  label="Normalize BERT-like model's predictions to gendered-only?",
 
563
  type="index",
564
  ),
565
+ gr.Dropdown(
566
  ["False", "True"],
567
  label="Include baseline predictions (dashed-lines)?",
 
568
  type="index",
569
  ),
570
+ gr.Textbox(
571
  lines=5,
572
  label="Input Text: Sentence about a single person using some gendered pronouns to refer to them.",
 
573
  ),
574
  ],
575
  outputs=[
 
591
  title=title,
592
  description=description,
593
  article=article,
594
+ examples=[scientist_example, building_example, death_date_example, neg_reddit_example]
595
+ ).launch(debug=True)