Emily McMilin committed
Commit 4547bf4 · 1 Parent(s): 68fec63

is_masked optimization, rem dep args

Files changed (1): app.py (+11, -5)
app.py CHANGED
@@ -346,7 +346,8 @@ def predict_gender_pronouns(
         male_gendered_token_ids,
         female_gendered_token_ids
     )
-    num_preds = torch.sum(tokenized['ids'][0] == MASK_TOKEN_ID).item()
+    initial_is_masked = tokenized['ids'][0] == MASK_TOKEN_ID
+    num_preds = torch.sum(initial_is_masked).item()
 
     female_dfs = []
     male_dfs = []
@@ -359,7 +360,10 @@ def predict_gender_pronouns(
     female_pronoun_preds = []
     male_pronoun_preds = []
     for indie_var_idx in range(len(tokenized['ids'])):
-        is_masked = tokenized['ids'][indie_var_idx] == MASK_TOKEN_ID
+        if dataset == WIKIBIO:
+            is_masked = initial_is_masked  # injected text all same token length
+        else:
+            is_masked = tokenized['ids'][indie_var_idx] == MASK_TOKEN_ID
 
         ids = tokenized["ids"][indie_var_idx]
         atten_mask = tokenized["atten_mask"][indie_var_idx]
@@ -474,15 +478,17 @@ gr.Interface(
     outputs=[
         gr.outputs.Textbox(
             type="auto", label="Sample target text fed to model"),
-        gr.outputs.Plot(type="auto", label="Plot of softmax probability pronouns predicted female."),
-        gr.outputs.Plot(type="auto", label="Plot of softmax probability pronouns predicted male."),
+        gr.Plot(type="auto", label="Plot of softmax probability pronouns predicted female."),
+        gr.Plot(type="auto", label="Plot of softmax probability pronouns predicted male."),
         gr.outputs.Dataframe(
+            show_label=True,
             overflow_row_behaviour="show_ends",
             label="Table of softmax probability pronouns predicted female",
         ),
         gr.outputs.Dataframe(
+            show_label=True,
             overflow_row_behaviour="show_ends",
             label="Table of softmax probability pronouns predicted male",
         ),
     ],
-).launch(debug=True)
+).launch(debug=True, share=True)
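
The first two hunks are the is_masked optimization named in the commit message: the [MASK]-position vector is computed once from the first tokenized row and reused on every loop iteration for the WIKIBIO dataset, where each injected text tokenizes to the same length and therefore has identical mask positions; other datasets still recompute it per row. A minimal sketch of the idea, assuming toy values for MASK_TOKEN_ID, WIKIBIO, and the tokenized batch (only the identifier names are taken from the diff):

import torch

MASK_TOKEN_ID = 103        # e.g. BERT's [MASK] id; value assumed for illustration
WIKIBIO = "wiki_bio"       # dataset flag, mirroring the identifier in the diff

# Toy batch: every row has the same token layout, as for the WIKIBIO injected text.
tokenized = {"ids": [torch.tensor([101, 7592, 103, 102]),
                     torch.tensor([101, 2293, 103, 102])]}

# Compute the mask positions once from the first row ...
initial_is_masked = tokenized["ids"][0] == MASK_TOKEN_ID
num_preds = torch.sum(initial_is_masked).item()

dataset = WIKIBIO
for indie_var_idx in range(len(tokenized["ids"])):
    if dataset == WIKIBIO:
        # ... and reuse it: the injected texts all tokenize to the same length,
        # so the [MASK] positions are identical in every row.
        is_masked = initial_is_masked
    else:
        # Rows may differ in length/layout, so recompute per row.
        is_masked = tokenized["ids"][indie_var_idx] == MASK_TOKEN_ID
    print(indie_var_idx, is_masked.tolist(), num_preds)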
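The third hunk moves the two plot outputs from the deprecated gr.outputs.Plot wrapper to the top-level gr.Plot component, adds show_label=True to the Dataframe outputs, and launches with a public share link. A minimal, self-contained sketch of that output-and-launch pattern, assuming a Gradio 3.x-era API; the plotting function, labels, and probabilities below are illustrative stand-ins, not the app's own code:

import gradio as gr
import matplotlib.pyplot as plt


def plot_probs(text):
    # Illustrative stand-in for the app's prediction code: returns the echoed
    # text plus one matplotlib figure for gr.Plot to render.
    fig, ax = plt.subplots()
    ax.bar(["she", "he"], [0.6, 0.4])
    ax.set_ylabel("softmax probability")
    return text, fig


demo = gr.Interface(
    fn=plot_probs,
    inputs=gr.Textbox(label="Input text"),
    outputs=[
        gr.Textbox(label="Sample target text fed to model"),
        gr.Plot(label="Plot of softmax probability pronouns predicted female."),
    ],
)

if __name__ == "__main__":
    # share=True asks Gradio to create a temporary public URL in addition to
    # the local server; debug=True keeps the process attached so logs stream
    # to the console.
    demo.launch(debug=True, share=True)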