# Author: Emily McMilin
# Commit: 38542cb ("new API updates")
from typing import Optional
import gradio as gr
import torch
from transformers import AutoModelForTokenClassification, AutoTokenizer
from transformers import pipeline
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
# DATASETS — keys used to select which fine-tuned model family to run
REDDIT = 'reddit_finetuned'
WIKIBIO = 'wikibio_finetuned'
BASE = 'BERT_base'
# Play with me, consts — user-facing choices exposed in the Gradio UI
SUBREDDIT_CONDITIONING_VARIABLES = ["none", "subreddit"]
WIKIBIO_CONDITIONING_VARIABLES = ['none', 'birth_date', 'birth_place']
BERT_LIKE_MODELS = ["bert", "distilbert"]
## Internal constants
GENDER_OPTIONS = ['female', 'male']
DECIMAL_PLACES = 1
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MAX_TOKEN_LENGTH = 32
# Label id ignored by the token-classification loss (HF convention)
NON_LOSS_TOKEN_ID = -100
# Picked ints that will pop out visually during debug
NON_GENDERED_TOKEN_ID = 30
LABEL_DICT = {GENDER_OPTIONS[0]: 9, GENDER_OPTIONS[1]: -9}
CLASSES = list(LABEL_DICT.keys())
# Words known to tokenize into exactly two word-pieces; the second piece is
# used below to catch gendered subwords inside multi-token words.
MULTITOKEN_WOMAN_WORD = 'policewoman'
MULTITOKEN_MAN_WORD = 'spiderman'
# Wikibio consts — range of years injected in place of SPLIT_KEY in the text
START_YEAR = 1800
STOP_YEAR = 1999
SPLIT_KEY = "DATE"
# Reddit consts
# List of randomly selected (tending towards those with seemingly more gender-neutral words)
# in order of increasing self-identified female participation.
# See http://bburky.com/subredditgenderratios/ , Minimum subreddit size: 100000
# Update: 400000
SUBREDDITS = [
"GlobalOffensive",
"pcmasterrace",
"nfl",
"sports",
"The_Donald",
"leagueoflegends",
"Overwatch",
"gonewild",
"Futurology",
"space",
"technology",
"gaming",
"Jokes",
"dataisbeautiful",
"woahdude",
"askscience",
"wow",
"anime",
"BlackPeopleTwitter",
"politics",
"pokemon",
"worldnews",
"reddit.com",
"interestingasfuck",
"videos",
"nottheonion",
"television",
"science",
"atheism",
"movies",
"gifs",
"Music",
"trees",
"EarthPorn",
"GetMotivated",
"pokemongo",
"news",
"fffffffuuuuuuuuuuuu",
"Fitness",
"Showerthoughts",
"OldSchoolCool",
"explainlikeimfive",
"todayilearned",
"gameofthrones",
"AdviceAnimals",
"DIY",
"WTF",
"IAmA",
"cringepics",
"tifu",
"mildlyinteresting",
"funny",
"pics",
"LifeProTips",
"creepy",
"personalfinance",
"food",
"AskReddit",
"books",
"aww",
"sex",
"relationships",
]
# Fire up the models.
# NOTE: runs at import time and downloads weights from the HuggingFace hub —
# keyed by (dataset, conditioning-variable) so predict_gender_pronouns can
# look up the right model per user selection.
models_paths = dict()
models = dict()
base_path = "emilylearning/"
# reddit finetuned models: one token-classification model per conditioning variable
for var in SUBREDDIT_CONDITIONING_VARIABLES:
    models_paths[(REDDIT, var)] = base_path + f'cond_ft_{var}_on_reddit__prcnt_100__test_run_False'
    models[(REDDIT, var)] = AutoModelForTokenClassification.from_pretrained(
        models_paths[(REDDIT, var)]
    )
# wikibio finetuned models:
for var in WIKIBIO_CONDITIONING_VARIABLES:
    models_paths[(WIKIBIO, var)] = base_path + f"cond_ft_{var}_on_wiki_bio__prcnt_100__test_run_False"
    models[(WIKIBIO, var)] = AutoModelForTokenClassification.from_pretrained(
        models_paths[(WIKIBIO, var)]
    )
# BERT-like models: off-the-shelf fill-mask pipelines used as an un-finetuned baseline
for bert_like in BERT_LIKE_MODELS:
    models_paths[(BASE, bert_like)] = f"{bert_like}-base-uncased"
    models[(BASE, bert_like)] = pipeline(
        "fill-mask", model=models_paths[(BASE, bert_like)])
# Tokenizers same for each model, so just grabbing one of them
tokenizer = AutoTokenizer.from_pretrained(
    models_paths[(BASE, BERT_LIKE_MODELS[0])], add_prefix_space=True
)
# Id of the [MASK] token, used to mask out gendered tokens before inference
MASK_TOKEN_ID = tokenizer.mask_token_id
def get_gendered_token_ids(tokenizer):
    """Build the lists of token ids this app treats as gendered.

    Args:
        tokenizer: A HuggingFace tokenizer (BERT-style, WordPiece vocab).

    Returns:
        Tuple ``(male_gendered_token_ids, female_gendered_token_ids)`` —
        lists of vocab ids for a fixed set of gendered word pairs, plus one
        gendered subword piece each (from MULTITOKEN_MAN_WORD /
        MULTITOKEN_WOMAN_WORD) to catch gendered pieces of multi-token words.

    Raises:
        AssertionError: if the multi-token probe words do not split into
            exactly two pieces, or if any word mapped to the UNK token id.
    """
    ## Set up gendered token constants: [male_form, female_form] pairs
    gendered_lists = [
        ['he', 'she'],
        ['him', 'her'],
        ['his', 'hers'],
        ["himself", "herself"],
        ['male', 'female'],
        ['man', 'woman'],
        ['men', 'women'],
        ["husband", "wife"],
        ['father', 'mother'],
        ['boyfriend', 'girlfriend'],
        ['brother', 'sister'],
        ["actor", "actress"],
    ]
    # Generating dicts here for potential later token reconstruction of predictions.
    # (Renamed loop variable: the original shadowed the builtin `list`.)
    male_gendered_dict = {pair[0]: pair for pair in gendered_lists}
    female_gendered_dict = {pair[1]: pair for pair in gendered_lists}

    male_gendered_token_ids = tokenizer.convert_tokens_to_ids(
        list(male_gendered_dict.keys()))
    female_gendered_token_ids = tokenizer.convert_tokens_to_ids(
        list(female_gendered_dict.keys())
    )

    # Below technique is used to grab second token in a multi-token word
    # There must be a better way...
    multiword_woman_token_ids = tokenizer.encode(
        MULTITOKEN_WOMAN_WORD, add_special_tokens=False)
    assert len(multiword_woman_token_ids) == 2
    subword_woman_token_id = multiword_woman_token_ids[1]

    multiword_man_token_ids = tokenizer.encode(
        MULTITOKEN_MAN_WORD, add_special_tokens=False)
    assert len(multiword_man_token_ids) == 2
    subword_man_token_id = multiword_man_token_ids[1]

    male_gendered_token_ids.append(subword_man_token_id)
    female_gendered_token_ids.append(subword_woman_token_id)

    # Every word above must exist in the vocab (no UNK fall-through),
    # otherwise the gender tagging below would silently match UNK tokens.
    assert tokenizer.unk_token_id not in male_gendered_token_ids
    assert tokenizer.unk_token_id not in female_gendered_token_ids

    return male_gendered_token_ids, female_gendered_token_ids
def tokenize_and_append_metadata(text, tokenizer, female_gendered_token_ids, male_gendered_token_ids):
    """Tokenize text and mask/flag 'gendered_tokens_ids' in token_ids and labels.

    Args:
        text: Target sentence to tokenize (truncated/padded to MAX_TOKEN_LENGTH).
        tokenizer: HuggingFace tokenizer providing input_ids / attention_mask.
        female_gendered_token_ids: Vocab ids to be labeled female and masked.
        male_gendered_token_ids: Vocab ids to be labeled male and masked.

    Returns:
        The tokenizer's encoding dict with:
          - "input_ids": tensor where every gendered token is replaced by [MASK],
          - "labels": tensor with the class index for gendered positions and
            NON_LOSS_TOKEN_ID (-100, ignored by the loss) everywhere else.
    """
    label_list = list(LABEL_DICT.values())
    # Relies on dict insertion order: "female" must be first so class indices
    # line up with CLASSES used at prediction time.
    assert label_list[0] == LABEL_DICT["female"], "LABEL_DICT not an ordered dict"
    label2id = {label: idx for idx, label in enumerate(label_list)}

    tokenized = tokenizer(
        text,
        truncation=True,
        padding='max_length',
        max_length=MAX_TOKEN_LENGTH,
    )

    # Finding the gender pronouns in the tokens: per-position tags holding the
    # gender sentinel value at gendered tokens and NON_GENDERED_TOKEN_ID elsewhere.
    token_ids = tokenized["input_ids"]
    female_tags = torch.tensor(
        [
            LABEL_DICT["female"]
            if id in female_gendered_token_ids
            else NON_GENDERED_TOKEN_ID
            for id in token_ids
        ]
    )
    male_tags = torch.tensor(
        [
            LABEL_DICT["male"]
            if id in male_gendered_token_ids
            else NON_GENDERED_TOKEN_ID
            for id in token_ids
        ]
    )

    # Labeling and masking out occurrences of gendered pronouns.
    # (Removed a dead `labels = torch.tensor([NON_LOSS_TOKEN_ID] * ...)` line:
    # it was unconditionally overwritten by the torch.where below.)
    labels = torch.where(
        female_tags == LABEL_DICT["female"],
        label2id[LABEL_DICT["female"]],
        NON_LOSS_TOKEN_ID,
    )
    labels = torch.where(
        male_tags == LABEL_DICT["male"], label2id[LABEL_DICT["male"]], labels
    )

    # Replace gendered tokens with [MASK] so the model must predict them.
    masked_token_ids = torch.where(
        female_tags == LABEL_DICT["female"], MASK_TOKEN_ID, torch.tensor(
            token_ids)
    )
    masked_token_ids = torch.where(
        male_tags == LABEL_DICT["male"], MASK_TOKEN_ID, masked_token_ids
    )

    tokenized["input_ids"] = masked_token_ids
    tokenized["labels"] = labels

    return tokenized
def get_tokenized_text_with_metadata(input_text, indie_vars, dataset, male_gendered_token_ids, female_gendered_token_ids):
    """Construct dict of tokenized texts with each independent variable injected into the text.

    Args:
        input_text: User-provided sentence (may contain SPLIT_KEY for WIKIBIO).
        indie_vars: Values to inject — years (WIKIBIO) or subreddit names (REDDIT).
        dataset: WIKIBIO or REDDIT; selects the injection scheme.
        male_gendered_token_ids: Vocab ids labeled male by tokenization.
        female_gendered_token_ids: Vocab ids labeled female by tokenization.

    Returns:
        Dict with per-indie-var lists: 'ids' (masked input ids), 'atten_mask',
        'toks' (token strings), and 'labels'.
    """
    if dataset == WIKIBIO:
        text_portions = input_text.split(SPLIT_KEY)
        # If no SPLIT_KEY found in text, add space for metadata and whitespaces
        if len(text_portions) == 1:
            text_portions = ['Born in ', f" {text_portions[0]}"]

    tokenized_w_metadata = {'ids': [], 'atten_mask': [], 'toks': [], 'labels': []}
    for indie_var in indie_vars:

        if dataset == WIKIBIO:
            target_text = f"{indie_var}".join(text_portions)
        else:
            target_text = f"r/{indie_var}: {input_text}"

        # BUG FIX: tokenize_and_append_metadata's signature takes the FEMALE id
        # list before the MALE one; the original call passed them in the opposite
        # order, swapping the male/female labels and masks.
        tokenized_sample = tokenize_and_append_metadata(
            target_text,
            tokenizer,
            female_gendered_token_ids,
            male_gendered_token_ids,
        )

        tokenized_w_metadata['ids'].append(tokenized_sample["input_ids"])
        tokenized_w_metadata['atten_mask'].append(
            torch.tensor(tokenized_sample["attention_mask"]))
        tokenized_w_metadata['toks'].append(
            tokenizer.convert_ids_to_tokens(tokenized_sample["input_ids"]))
        tokenized_w_metadata['labels'].append(tokenized_sample["labels"])

    return tokenized_w_metadata
def get_avg_prob_from_finetuned_outputs(outputs, is_masked, num_preds, gender):
    """Average softmax probability (as a percentage, rounded) that the masked
    positions are predicted as `gender` by a fine-tuned token classifier."""
    gender_col = CLASSES.index(gender)
    token_probs = torch.softmax(outputs[0][0].cpu(), dim=1, dtype=torch.double)
    # Zero out non-masked positions so only the [MASK]ed tokens contribute.
    masked_probs = torch.where(is_masked, token_probs[:, gender_col], 0.0)
    return round(masked_probs.sum().item() / num_preds * 100, DECIMAL_PLACES)
def get_avg_prob_from_pipeline_outputs(mask_filled_text, gendered_token_ids, num_preds):
    """Average score (as a percentage, rounded) that a fill-mask pipeline assigns
    to tokens in `gendered_token_ids`, summed over each mask's top predictions."""
    gendered = set(gendered_token_ids)
    total_score = 0.0
    # One inner list of candidate predictions per [MASK] in the text.
    for top_preds in mask_filled_text:
        for candidate in top_preds:
            if candidate["token"] in gendered:
                total_score += candidate["score"]
    return round(total_score / num_preds * 100, DECIMAL_PLACES)
def get_figure(results, dataset, gender, indie_var_name):
    """Plot predicted-pronoun probabilities vs the independent variable.

    One line per model column in `results`; returns the matplotlib Figure.
    """
    fig, axis = plt.subplots()
    axis.plot(results)
    if dataset == REDDIT:
        axis.set_xlabel("Subreddit prepended to input text")
        # Subreddit names are long; cap the number of x-axis ticks.
        axis.xaxis.set_major_locator(MaxNLocator(6))
    else:
        axis.set_xlabel("Date injected into input text")
    axis.set_title(
        f"Softmax probability of pronouns predicted {gender}\n by model type vs {indie_var_name}."
    )
    axis.set_ylabel(f"Avg softmax prob for {gender} pronouns")
    axis.legend(list(results.columns))
    return fig
def predict_gender_pronouns(
    dataset,
    bert_like_models,
    normalizing,
    input_text,
):
    """Run inference on input_text for each model type, returning df and plots of percentage
    of gender pronouns predicted as female and male in each target text.

    Args:
        dataset: REDDIT or WIKIBIO — selects fine-tuned models and injection scheme.
        bert_like_models: Subset of BERT_LIKE_MODELS to also evaluate as baselines.
        normalizing: Truthy -> renormalize baseline scores over gendered tokens only.
        input_text: Sentence about a single person using gendered pronouns.

    Returns:
        (sample masked text, female figure, male figure, female df, male df).
    """
    male_gendered_token_ids, female_gendered_token_ids = get_gendered_token_ids(tokenizer)

    # Independent variable swept across the x-axis: subreddits or 20 years.
    if dataset == REDDIT:
        indie_vars = SUBREDDITS
        conditioning_variables = SUBREDDIT_CONDITIONING_VARIABLES
        indie_var_name = 'subreddit'
    else:
        indie_vars = np.linspace(START_YEAR, STOP_YEAR, 20).astype(int)
        conditioning_variables = WIKIBIO_CONDITIONING_VARIABLES
        indie_var_name = 'date'

    tokenized = get_tokenized_text_with_metadata(
        input_text,
        indie_vars,
        dataset,
        male_gendered_token_ids,
        female_gendered_token_ids
    )

    # Positions of [MASK] in the first sample; num_preds normalizes the averages.
    initial_is_masked = tokenized['ids'][0] == MASK_TOKEN_ID
    num_preds = torch.sum(initial_is_masked).item()

    # First column of each results df is the independent variable itself.
    female_dfs = []
    male_dfs = []
    female_dfs.append(pd.DataFrame({indie_var_name: indie_vars}))
    male_dfs.append(pd.DataFrame({indie_var_name: indie_vars}))

    # --- Fine-tuned token-classification models, one per conditioning variable ---
    for var in conditioning_variables:
        prefix = f"{var}_metadata"
        model = models[(dataset, var)]

        female_pronoun_preds = []
        male_pronoun_preds = []
        for indie_var_idx in range(len(tokenized['ids'])):
            if dataset == WIKIBIO:
                is_masked = initial_is_masked # injected text all same token length
            else:
                # Subreddit names tokenize to different lengths; recompute mask.
                is_masked = tokenized['ids'][indie_var_idx] == MASK_TOKEN_ID

            ids = tokenized["ids"][indie_var_idx]
            atten_mask = tokenized["atten_mask"][indie_var_idx]
            labels = tokenized["labels"][indie_var_idx]

            with torch.no_grad():
                # Batch dimension of 1 added for the forward pass.
                outputs = model(ids.unsqueeze(dim=0),
                                atten_mask.unsqueeze(dim=0))

            female_pronoun_preds.append(
                get_avg_prob_from_finetuned_outputs(outputs,is_masked, num_preds, "female")
            )
            male_pronoun_preds.append(
                get_avg_prob_from_finetuned_outputs(outputs,is_masked, num_preds, "male")
            )

        female_dfs.append(pd.DataFrame({prefix : female_pronoun_preds}))
        male_dfs.append(pd.DataFrame({prefix : male_pronoun_preds}))

    # --- Un-finetuned BERT-like fill-mask baselines ---
    for bert_like in bert_like_models:
        prefix = f"base_{bert_like}"
        model = models[(BASE, bert_like)]

        female_pronoun_preds = []
        male_pronoun_preds = []
        for indie_var_idx in range(len(tokenized['ids'])):
            toks = tokenized["toks"][indie_var_idx]
            target_text_for_bert = ' '.join(
                toks[1:-1]) # Removing [CLS] and [SEP]

            mask_filled_text = model(target_text_for_bert)
            # Quick hack as realized return type based on how many MASKs in text.
            if type(mask_filled_text[0]) is not list:
                mask_filled_text = [mask_filled_text]

            female_pronoun_preds.append(get_avg_prob_from_pipeline_outputs(
                mask_filled_text,
                female_gendered_token_ids,
                num_preds
            ))
            male_pronoun_preds.append(get_avg_prob_from_pipeline_outputs(
                mask_filled_text,
                male_gendered_token_ids,
                num_preds
            ))

        # Optionally renormalize so female% + male% = 100 for the baselines.
        if normalizing:
            total_gendered_probs = np.add(female_pronoun_preds, male_pronoun_preds)
            female_pronoun_preds = np.around(
                np.divide(female_pronoun_preds, total_gendered_probs)*100,
                decimals=DECIMAL_PLACES
            )
            male_pronoun_preds = np.around(
                np.divide(male_pronoun_preds, total_gendered_probs)*100,
                decimals=DECIMAL_PLACES
            )

        female_dfs.append(pd.DataFrame({prefix : female_pronoun_preds}))
        male_dfs.append(pd.DataFrame({prefix : male_pronoun_preds}))

    # To display to user as an example
    toks = tokenized["toks"][0]
    target_text_w_masks = ' '.join(toks[1:-1])

    # Plots / dataframe for display to users
    female_results = pd.concat(female_dfs, axis=1).set_index(indie_var_name)
    male_results = pd.concat(male_dfs, axis=1).set_index(indie_var_name)

    female_fig = get_figure(female_results, dataset, "female", indie_var_name)
    male_fig = get_figure(male_results, dataset, "male", indie_var_name)
    female_results.reset_index(inplace=True) # Gradio Dataframe doesn't 'see' index?
    male_results.reset_index(inplace=True) # Gradio Dataframe doesn't 'see' index?

    return (
        target_text_w_masks,
        female_fig,
        male_fig,
        female_results,
        male_results,
    )
# Wire the prediction function into a Gradio UI and launch it.
# Input order must match predict_gender_pronouns' parameters:
# (dataset, bert_like_models, normalizing, input_text).
# NOTE(review): uses the legacy `gr.inputs.*` API — presumably targets
# Gradio < 3.x; confirm against the pinned gradio version.
gr.Interface(
    fn=predict_gender_pronouns,
    inputs=[
        gr.inputs.Radio(
            [REDDIT, WIKIBIO],
            default=WIKIBIO,
            type="value",
            label="Pick 'conditionally' fine-tuned model.",
            optional=False,
        ),
        gr.inputs.CheckboxGroup(
            BERT_LIKE_MODELS,
            default=[BERT_LIKE_MODELS[0]],
            type="value",
            label="Pick optional BERT base uncased model.",
        ),
        # type="index" passes 0 ("False") or 1 ("True") to `normalizing`.
        gr.inputs.Dropdown(
            ["False", "True"],
            label="Normalize BERT-like model's predictions to gendered-only?",
            default = "True",
            type="index",
        ),
        gr.inputs.Textbox(
            lines=5,
            label="Input Text: Sentence about a single person using some gendered pronouns to refer to them.",
            default="She always walked past the building built in DATE on her way to her job as an elementary school teacher.",
        ),
    ],
    # Outputs mirror predict_gender_pronouns' 5-tuple return value, in order.
    outputs=[
        gr.Textbox(
            type="auto", label="Sample target text fed to model"),
        gr.Plot(type="auto", label="Plot of softmax probability pronouns predicted female."),
        gr.Plot(type="auto", label="Plot of softmax probability pronouns predicted male."),
        gr.Dataframe(
            show_label=True,
            overflow_row_behaviour="show_ends",
            label="Table of softmax probability pronouns predicted female",
        ),
        gr.Dataframe(
            show_label=True,
            overflow_row_behaviour="show_ends",
            label="Table of softmax probability pronouns predicted male",
        ),
    ],
).launch(debug=True, share=True)