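"""Gradio demo exploring how the softmax probability of gendered pronouns
predicted at masked positions shifts as conditioning metadata (a birth date
or a subreddit name) injected into the input text is varied, comparing
'conditionally' fine-tuned token-classification models against base
BERT-like fill-mask pipelines."""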
import gradio as gr
import torch
from transformers import AutoModelForTokenClassification, AutoTokenizer, pipeline
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
# DATASETS
REDDIT = 'reddit_finetuned'
WIKIBIO = 'wikibio_finetuned'
BASE = 'BERT_base'
# Play with me: user-adjustable consts
SUBREDDIT_CONDITIONING_VARIABLES = ["none", "subreddit"]
WIKIBIO_CONDITIONING_VARIABLES = ['none', 'birth_date', 'birth_place']
BERT_LIKE_MODELS = ["bert", "distilbert"]
## Internal constants
GENDER_OPTIONS = ['female', 'male']
DECIMAL_PLACES = 1
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MAX_TOKEN_LENGTH = 32
NON_LOSS_TOKEN_ID = -100
# Picked ints that will pop out visually during debug
NON_GENDERED_TOKEN_ID = 30
LABEL_DICT = {GENDER_OPTIONS[0]: 9, GENDER_OPTIONS[1]: -9}
CLASSES = list(LABEL_DICT.keys())
MULTITOKEN_WOMAN_WORD = 'policewoman'
MULTITOKEN_MAN_WORD = 'spiderman'
# Wikibio consts
START_YEAR = 1800
STOP_YEAR = 1999
SPLIT_KEY = "DATE"
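# SPLIT_KEY marks the spot in the input text where each candidate year is injected.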
# Reddit consts
# List of randomly selected subreddits (tending toward those with seemingly more
# gender-neutral words), in order of increasing self-identified female participation.
# See http://bburky.com/subredditgenderratios/ , minimum subreddit size: 100000
# Update: 400000
SUBREDDITS = [
    "GlobalOffensive",
    "pcmasterrace",
    "nfl",
    "sports",
    "The_Donald",
    "leagueoflegends",
    "Overwatch",
    "gonewild",
    "Futurology",
    "space",
    "technology",
    "gaming",
    "Jokes",
    "dataisbeautiful",
    "woahdude",
    "askscience",
    "wow",
    "anime",
    "BlackPeopleTwitter",
    "politics",
    "pokemon",
    "worldnews",
    "reddit.com",
    "interestingasfuck",
    "videos",
    "nottheonion",
    "television",
    "science",
    "atheism",
    "movies",
    "gifs",
    "Music",
    "trees",
    "EarthPorn",
    "GetMotivated",
    "pokemongo",
    "news",
    "fffffffuuuuuuuuuuuu",
    "Fitness",
    "Showerthoughts",
    "OldSchoolCool",
    "explainlikeimfive",
    "todayilearned",
    "gameofthrones",
    "AdviceAnimals",
    "DIY",
    "WTF",
    "IAmA",
    "cringepics",
    "tifu",
    "mildlyinteresting",
    "funny",
    "pics",
    "LifeProTips",
    "creepy",
    "personalfinance",
    "food",
    "AskReddit",
    "books",
    "aww",
    "sex",
    "relationships",
]
# Fire up the models
models_paths = dict()
models = dict()
base_path = "emilylearning/"
# reddit finetuned models:
for var in SUBREDDIT_CONDITIONING_VARIABLES:
    models_paths[(REDDIT, var)] = base_path + f'cond_ft_{var}_on_reddit__prcnt_100__test_run_False'
    models[(REDDIT, var)] = AutoModelForTokenClassification.from_pretrained(
        models_paths[(REDDIT, var)]
    )
# wikibio finetuned models:
for var in WIKIBIO_CONDITIONING_VARIABLES:
    models_paths[(WIKIBIO, var)] = base_path + f"cond_ft_{var}_on_wiki_bio__prcnt_100__test_run_False"
    models[(WIKIBIO, var)] = AutoModelForTokenClassification.from_pretrained(
        models_paths[(WIKIBIO, var)]
    )
# BERT-like models:
for bert_like in BERT_LIKE_MODELS:
    models_paths[(BASE, bert_like)] = f"{bert_like}-base-uncased"
    models[(BASE, bert_like)] = pipeline(
        "fill-mask", model=models_paths[(BASE, bert_like)])
# The tokenizer is the same for all of these models, so just grab one of them
tokenizer = AutoTokenizer.from_pretrained(
    models_paths[(BASE, BERT_LIKE_MODELS[0])], add_prefix_space=True
)
MASK_TOKEN_ID = tokenizer.mask_token_id
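# Note: for bert-base-uncased (and distilbert-base-uncased, which shares its vocab) this is 103.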

def get_gendered_token_ids(tokenizer):
    """Collect the vocab ids of the gendered word pairs below, split into male and female lists."""
    ## Set up gendered token constants
    gendered_lists = [
        ['he', 'she'],
        ['him', 'her'],
        ['his', 'hers'],
        ["himself", "herself"],
        ['male', 'female'],
        ['man', 'woman'],
        ['men', 'women'],
        ["husband", "wife"],
        ['father', 'mother'],
        ['boyfriend', 'girlfriend'],
        ['brother', 'sister'],
        ["actor", "actress"],
    ]
    # Generating dicts here for potential later token reconstruction of predictions
    male_gendered_dict = {pair[0]: pair for pair in gendered_lists}
    female_gendered_dict = {pair[1]: pair for pair in gendered_lists}
    male_gendered_token_ids = tokenizer.convert_tokens_to_ids(
        list(male_gendered_dict.keys()))
    female_gendered_token_ids = tokenizer.convert_tokens_to_ids(
        list(female_gendered_dict.keys())
    )
    # Below technique is used to grab the second token in a multi-token word.
    # There must be a better way...
    multiword_woman_token_ids = tokenizer.encode(
        MULTITOKEN_WOMAN_WORD, add_special_tokens=False)
    assert len(multiword_woman_token_ids) == 2
    subword_woman_token_id = multiword_woman_token_ids[1]
    multiword_man_token_ids = tokenizer.encode(
        MULTITOKEN_MAN_WORD, add_special_tokens=False)
    assert len(multiword_man_token_ids) == 2
    subword_man_token_id = multiword_man_token_ids[1]
    male_gendered_token_ids.append(subword_man_token_id)
    female_gendered_token_ids.append(subword_woman_token_id)
    assert tokenizer.unk_token_id not in male_gendered_token_ids
    assert tokenizer.unk_token_id not in female_gendered_token_ids
    return male_gendered_token_ids, female_gendered_token_ids
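# For example, this tokenizer splits 'policewoman' into two word-pieces (something
# like ['police', '##woman']), so the subword ids appended above also catch gendered
# words that the single-token lists would miss.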

def tokenize_and_append_metadata(text, tokenizer, female_gendered_token_ids, male_gendered_token_ids):
    """Tokenize text, masking gendered tokens in input_ids and flagging them in labels."""
    label_list = list(LABEL_DICT.values())
    assert label_list[0] == LABEL_DICT["female"], "LABEL_DICT not an ordered dict"
    label2id = {label: idx for idx, label in enumerate(label_list)}
    tokenized = tokenizer(
        text,
        truncation=True,
        padding='max_length',
        max_length=MAX_TOKEN_LENGTH,
    )
    # Finding the gendered words in the tokens
    token_ids = tokenized["input_ids"]
    female_tags = torch.tensor(
        [
            LABEL_DICT["female"]
            if id in female_gendered_token_ids
            else NON_GENDERED_TOKEN_ID
            for id in token_ids
        ]
    )
    male_tags = torch.tensor(
        [
            LABEL_DICT["male"]
            if id in male_gendered_token_ids
            else NON_GENDERED_TOKEN_ID
            for id in token_ids
        ]
    )
    # Labeling and masking out occurrences of gendered words
    labels = torch.tensor([NON_LOSS_TOKEN_ID] * len(token_ids))
    labels = torch.where(
        female_tags == LABEL_DICT["female"],
        label2id[LABEL_DICT["female"]],
        labels,
    )
    labels = torch.where(
        male_tags == LABEL_DICT["male"], label2id[LABEL_DICT["male"]], labels
    )
    masked_token_ids = torch.where(
        female_tags == LABEL_DICT["female"], MASK_TOKEN_ID, torch.tensor(
            token_ids)
    )
    masked_token_ids = torch.where(
        male_tags == LABEL_DICT["male"], MASK_TOKEN_ID, masked_token_ids
    )
    tokenized["input_ids"] = masked_token_ids
    tokenized["labels"] = labels
    return tokenized
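# E.g. roughly: "She rode her bike." -> input_ids [CLS] [MASK] rode [MASK] bike . [SEP] [PAD]...,
# with labels 0 at each masked position (0 = female class, 1 = male) and
# NON_LOSS_TOKEN_ID everywhere else.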

def get_tokenized_text_with_metadata(input_text, indie_vars, dataset, male_gendered_token_ids, female_gendered_token_ids):
    """Construct a dict of tokenized texts with each indie var (year or subreddit) injected into the text."""
    if dataset == WIKIBIO:
        text_portions = input_text.split(SPLIT_KEY)
        # If no SPLIT_KEY is found in the text, add a spot for the metadata plus whitespace
        if len(text_portions) == 1:
            text_portions = ['Born in ', f" {text_portions[0]}"]
    tokenized_w_metadata = {'ids': [], 'atten_mask': [], 'toks': [], 'labels': []}
    for indie_var in indie_vars:
        if dataset == WIKIBIO:
            target_text = f"{indie_var}".join(text_portions)
        else:
            target_text = f"r/{indie_var}: {input_text}"
        tokenized_sample = tokenize_and_append_metadata(
            target_text,
            tokenizer,
            female_gendered_token_ids=female_gendered_token_ids,
            male_gendered_token_ids=male_gendered_token_ids,
        )
        tokenized_w_metadata['ids'].append(tokenized_sample["input_ids"])
        tokenized_w_metadata['atten_mask'].append(
            torch.tensor(tokenized_sample["attention_mask"]))
        tokenized_w_metadata['toks'].append(
            tokenizer.convert_ids_to_tokens(tokenized_sample["input_ids"].tolist()))
        tokenized_w_metadata['labels'].append(tokenized_sample["labels"])
    return tokenized_w_metadata
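# Each list in the returned dict has len(indie_vars) entries: 'ids' and 'labels'
# hold MAX_TOKEN_LENGTH-long tensors, 'toks' the corresponding token strings.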

def get_avg_prob_from_finetuned_outputs(outputs, is_masked, num_preds, gender):
    """Average the fine-tuned model's softmax mass on `gender` over the masked positions, as a percentage."""
    preds = torch.softmax(outputs[0][0].cpu(), dim=1, dtype=torch.double)
    pronoun_preds = torch.where(is_masked, preds[:, CLASSES.index(gender)], 0.0)
    return round(torch.sum(pronoun_preds).item() / num_preds * 100, DECIMAL_PLACES)


def get_avg_prob_from_pipeline_outputs(mask_filled_text, gendered_token_ids, num_preds):
    """Average the fill-mask pipeline's score mass on `gendered_token_ids` over the masked positions, as a percentage."""
    pronoun_preds = [sum([
        pronoun["score"] if pronoun["token"] in gendered_token_ids else 0.0
        for pronoun in top_preds])
        for top_preds in mask_filled_text
    ]
    return round(sum(pronoun_preds) / num_preds * 100, DECIMAL_PLACES)

def get_figure(results, dataset, gender, indie_var_name):
    fig, ax = plt.subplots()
    ax.plot(results)
    if dataset == REDDIT:
        ax.set_xlabel("Subreddit prepended to input text")
        ax.xaxis.set_major_locator(MaxNLocator(6))
    else:
        ax.set_xlabel("Date injected into input text")
    ax.set_title(f"Softmax probability of pronouns predicted {gender}\n by model type vs {indie_var_name}.")
    ax.set_ylabel(f"Avg softmax prob for {gender} pronouns")
    ax.legend(list(results.columns))
    return fig

def predict_gender_pronouns(
    dataset,
    bert_like_models,
    normalizing,
    input_text,
):
    """Run inference on input_text for each model type, returning dataframes and plots of the
    percentage of gendered pronouns predicted female and male in each target text.
    """
    male_gendered_token_ids, female_gendered_token_ids = get_gendered_token_ids(tokenizer)
    if dataset == REDDIT:
        indie_vars = SUBREDDITS
        conditioning_variables = SUBREDDIT_CONDITIONING_VARIABLES
        indie_var_name = 'subreddit'
    else:
        indie_vars = np.linspace(START_YEAR, STOP_YEAR, 20).astype(int)
        conditioning_variables = WIKIBIO_CONDITIONING_VARIABLES
        indie_var_name = 'date'
    tokenized = get_tokenized_text_with_metadata(
        input_text,
        indie_vars,
        dataset,
        male_gendered_token_ids,
        female_gendered_token_ids
    )
    initial_is_masked = tokenized['ids'][0] == MASK_TOKEN_ID
    num_preds = torch.sum(initial_is_masked).item()
    female_dfs = []
    male_dfs = []
    female_dfs.append(pd.DataFrame({indie_var_name: indie_vars}))
    male_dfs.append(pd.DataFrame({indie_var_name: indie_vars}))
    for var in conditioning_variables:
        prefix = f"{var}_metadata"
        model = models[(dataset, var)]
        female_pronoun_preds = []
        male_pronoun_preds = []
        for indie_var_idx in range(len(tokenized['ids'])):
            if dataset == WIKIBIO:
                is_masked = initial_is_masked  # injected text all same token length
            else:
                is_masked = tokenized['ids'][indie_var_idx] == MASK_TOKEN_ID
            ids = tokenized["ids"][indie_var_idx]
            atten_mask = tokenized["atten_mask"][indie_var_idx]
            labels = tokenized["labels"][indie_var_idx]
            with torch.no_grad():
                outputs = model(ids.unsqueeze(dim=0),
                                atten_mask.unsqueeze(dim=0))
            female_pronoun_preds.append(
                get_avg_prob_from_finetuned_outputs(outputs, is_masked, num_preds, "female")
            )
            male_pronoun_preds.append(
                get_avg_prob_from_finetuned_outputs(outputs, is_masked, num_preds, "male")
            )
        female_dfs.append(pd.DataFrame({prefix: female_pronoun_preds}))
        male_dfs.append(pd.DataFrame({prefix: male_pronoun_preds}))
    for bert_like in bert_like_models:
        prefix = f"base_{bert_like}"
        model = models[(BASE, bert_like)]
        female_pronoun_preds = []
        male_pronoun_preds = []
        for indie_var_idx in range(len(tokenized['ids'])):
            toks = tokenized["toks"][indie_var_idx]
            target_text_for_bert = ' '.join(
                toks[1:-1])  # Removing [CLS] and [SEP]
            mask_filled_text = model(target_text_for_bert)
            # Quick hack, as the pipeline's return type depends on how many [MASK]s are in the text.
            if not isinstance(mask_filled_text[0], list):
                mask_filled_text = [mask_filled_text]
            female_pronoun_preds.append(get_avg_prob_from_pipeline_outputs(
                mask_filled_text,
                female_gendered_token_ids,
                num_preds
            ))
            male_pronoun_preds.append(get_avg_prob_from_pipeline_outputs(
                mask_filled_text,
                male_gendered_token_ids,
                num_preds
            ))
        if normalizing:
            total_gendered_probs = np.add(female_pronoun_preds, male_pronoun_preds)
            female_pronoun_preds = np.around(
                np.divide(female_pronoun_preds, total_gendered_probs) * 100,
                decimals=DECIMAL_PLACES
            )
            male_pronoun_preds = np.around(
                np.divide(male_pronoun_preds, total_gendered_probs) * 100,
                decimals=DECIMAL_PLACES
            )
        female_dfs.append(pd.DataFrame({prefix: female_pronoun_preds}))
        male_dfs.append(pd.DataFrame({prefix: male_pronoun_preds}))
    # To display to user as an example
    toks = tokenized["toks"][0]
    target_text_w_masks = ' '.join(toks[1:-1])
    # Plots / dataframes for display to users
    female_results = pd.concat(female_dfs, axis=1).set_index(indie_var_name)
    male_results = pd.concat(male_dfs, axis=1).set_index(indie_var_name)
    female_fig = get_figure(female_results, dataset, "female", indie_var_name)
    male_fig = get_figure(male_results, dataset, "male", indie_var_name)
    female_results.reset_index(inplace=True)  # Gradio Dataframe doesn't 'see' the index
    male_results.reset_index(inplace=True)  # Gradio Dataframe doesn't 'see' the index
    return (
        target_text_w_masks,
        female_fig,
        male_fig,
        female_results,
        male_results,
    )

gr.Interface(
    fn=predict_gender_pronouns,
    inputs=[
        gr.Radio(
            [REDDIT, WIKIBIO],
            value=WIKIBIO,
            type="value",
            label="Pick 'conditionally' fine-tuned model.",
        ),
        gr.CheckboxGroup(
            BERT_LIKE_MODELS,
            value=[BERT_LIKE_MODELS[0]],
            type="value",
            label="Pick optional BERT base uncased model.",
        ),
        gr.Dropdown(
            ["False", "True"],
            value="True",
            type="index",
            label="Normalize BERT-like model's predictions to gendered-only?",
        ),
        gr.Textbox(
            lines=5,
            value="She always walked past the building built in DATE on her way to her job as an elementary school teacher.",
            label="Input Text: Sentence about a single person using some gendered pronouns to refer to them.",
        ),
    ],
    outputs=[
        gr.Textbox(label="Sample target text fed to model"),
        gr.Plot(label="Plot of softmax probability pronouns predicted female."),
        gr.Plot(label="Plot of softmax probability pronouns predicted male."),
        gr.Dataframe(
            show_label=True,
            label="Table of softmax probability pronouns predicted female",
        ),
        gr.Dataframe(
            show_label=True,
            label="Table of softmax probability pronouns predicted male",
        ),
    ],
).launch(debug=True, share=True)
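# When run locally, debug=True surfaces errors in the console and share=True
# creates a temporary public gradio.live link to the demo.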