import gradio as gr
import torch
from transformers import AutoModelForTokenClassification, AutoTokenizer
from transformers import pipeline
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator

# DATASETS
REDDIT = 'reddit_finetuned'
WIKIBIO = 'wikibio_finetuned'
BASE = 'BERT_base'

# Play-with-me consts (safe to tweak)
SUBREDDIT_CONDITIONING_VARIABLES = ["none", "subreddit"]
WIKIBIO_CONDITIONING_VARIABLES = ['none', 'birth_date']
BERT_LIKE_MODELS = ["bert", "distilbert"]
MAX_TOKEN_LENGTH = 32

# Internal markers for rendering
BASELINE_MARKER = 'baseline'
REDDIT_BASELINE_TEXT = ' '
WIKIBIO_BASELINE_TEXT = 'date'

## Internal constants from training
GENDER_OPTIONS = ['female', 'male']
DECIMAL_PLACES = 1
MULTITOKEN_WOMAN_WORD = 'policewoman'
MULTITOKEN_MAN_WORD = 'spiderman'
# Picked ints that will pop out visually during debug
NON_GENDERED_TOKEN_ID = 30
LABEL_DICT = {GENDER_OPTIONS[0]: 9, GENDER_OPTIONS[1]: -9}
CLASSES = list(LABEL_DICT.keys())
NON_LOSS_TOKEN_ID = -100

# Wiki-bio consts
START_YEAR = 1800
STOP_YEAR = 1999
SPLIT_KEY = "DATE"

# Reddit consts
# List of randomly selected subreddits (tending towards those with seemingly more
# gender-neutral words), in order of increasing self-identified female participation.
# See http://bburky.com/subredditgenderratios/ , minimum subreddit size: 400000.
SUBREDDITS = [
    "GlobalOffensive", "pcmasterrace", "nfl", "sports", "The_Donald",
    "leagueoflegends", "Overwatch", "gonewild", "Futurology", "space",
    "technology", "gaming", "Jokes", "dataisbeautiful", "woahdude",
    "askscience", "wow", "anime", "BlackPeopleTwitter", "politics",
    "pokemon", "worldnews", "reddit.com", "interestingasfuck", "videos",
    "nottheonion", "television", "science", "atheism", "movies",
    "gifs", "Music", "trees", "EarthPorn", "GetMotivated",
    "pokemongo", "news", "fffffffuuuuuuuuuuuu", "Fitness", "Showerthoughts",
    "OldSchoolCool", "explainlikeimfive", "todayilearned", "gameofthrones", "AdviceAnimals",
    "DIY", "WTF", "IAmA", "cringepics", "tifu",
    "mildlyinteresting", "funny", "pics", "LifeProTips", "creepy",
    "personalfinance", "food", "AskReddit", "books", "aww",
    "sex", "relationships",
]

# Fire up the models
models_paths = dict()
models = dict()

base_path = "emilylearning/"

# Reddit fine-tuned models:
for var in SUBREDDIT_CONDITIONING_VARIABLES:
    models_paths[(REDDIT, var)] = base_path + f'cond_ft_{var}_on_reddit__prcnt_100__test_run_False'
    models[(REDDIT, var)] = AutoModelForTokenClassification.from_pretrained(
        models_paths[(REDDIT, var)]
    )

# Wiki-bio fine-tuned models:
for var in WIKIBIO_CONDITIONING_VARIABLES:
    models_paths[(WIKIBIO, var)] = base_path + f"cond_ft_{var}_on_wiki_bio__prcnt_100__test_run_False"
    models[(WIKIBIO, var)] = AutoModelForTokenClassification.from_pretrained(
        models_paths[(WIKIBIO, var)]
    )

# BERT-like models:
for bert_like in BERT_LIKE_MODELS:
    models_paths[(BASE, bert_like)] = f"{bert_like}-base-uncased"
    models[(BASE, bert_like)] = pipeline(
        "fill-mask", model=models_paths[(BASE, bert_like)])

# Tokenizer is the same for each model, so just grab one of them
tokenizer = AutoTokenizer.from_pretrained(
    models_paths[(BASE, BERT_LIKE_MODELS[0])], add_prefix_space=True
)
MASK_TOKEN_ID = tokenizer.mask_token_id


def get_gendered_token_ids(tokenizer):
    ## Set up gendered token constants
    gendered_lists = [
        ['he', 'she'],
        ['him', 'her'],
        ['his', 'hers'],
        ["himself", "herself"],
        ['male', 'female'],
        ['man', 'woman'],
        ['men', 'women'],
        ["husband", "wife"],
        ['father', 'mother'],
        ['boyfriend', 'girlfriend'],
        ['brother', 'sister'],
["actor", "actress"], ] # Generating dicts here for potential later token reconstruction of predictions male_gendered_dict = {list[0]: list for list in gendered_lists} female_gendered_dict = {list[1]: list for list in gendered_lists} male_gendered_token_ids = tokenizer.convert_tokens_to_ids( list(male_gendered_dict.keys())) female_gendered_token_ids = tokenizer.convert_tokens_to_ids( list(female_gendered_dict.keys()) ) # Below technique is used to grab second token in a multi-token word # There must be a better way... multiword_woman_token_ids = tokenizer.encode( MULTITOKEN_WOMAN_WORD, add_special_tokens=False) assert len(multiword_woman_token_ids) == 2 subword_woman_token_id = multiword_woman_token_ids[1] multiword_man_token_ids = tokenizer.encode( MULTITOKEN_MAN_WORD, add_special_tokens=False) assert len(multiword_man_token_ids) == 2 subword_man_token_id = multiword_man_token_ids[1] male_gendered_token_ids.append(subword_man_token_id) female_gendered_token_ids.append(subword_woman_token_id) assert tokenizer.unk_token_id not in male_gendered_token_ids assert tokenizer.unk_token_id not in female_gendered_token_ids return male_gendered_token_ids, female_gendered_token_ids def tokenize_and_append_metadata(text, tokenizer, female_gendered_token_ids, male_gendered_token_ids): """Tokenize text and mask/flag 'gendered_tokens_ids' in token_ids and labels.""" label_list = list(LABEL_DICT.values()) assert label_list[0] == LABEL_DICT["female"], "LABEL_DICT not an ordered dict" label2id = {label: idx for idx, label in enumerate(label_list)} tokenized = tokenizer( text, truncation=True, padding='max_length', max_length=MAX_TOKEN_LENGTH, ) # Finding the gender pronouns in the tokens token_ids = tokenized["input_ids"] female_tags = torch.tensor( [ LABEL_DICT["female"] if id in female_gendered_token_ids else NON_GENDERED_TOKEN_ID for id in token_ids ] ) male_tags = torch.tensor( [ LABEL_DICT["male"] if id in male_gendered_token_ids else NON_GENDERED_TOKEN_ID for id in token_ids ] ) # Labeling and masking out occurrences of gendered pronouns labels = torch.tensor([NON_LOSS_TOKEN_ID] * len(token_ids)) labels = torch.where( female_tags == LABEL_DICT["female"], label2id[LABEL_DICT["female"]], NON_LOSS_TOKEN_ID, ) labels = torch.where( male_tags == LABEL_DICT["male"], label2id[LABEL_DICT["male"]], labels ) masked_token_ids = torch.where( female_tags == LABEL_DICT["female"], MASK_TOKEN_ID, torch.tensor( token_ids) ) masked_token_ids = torch.where( male_tags == LABEL_DICT["male"], MASK_TOKEN_ID, masked_token_ids ) tokenized["input_ids"] = masked_token_ids tokenized["labels"] = labels return tokenized def get_tokenized_text_with_metadata(input_text, indie_vars, dataset, male_gendered_token_ids, female_gendered_token_ids): """Construct dict of tokenized texts with each year injected into the text.""" if dataset == WIKIBIO: text_portions = input_text.split(SPLIT_KEY) # If no SPLIT_KEY found in text, add space for metadata and whitespaces if len(text_portions) == 1: text_portions = ['Born in ', f" {text_portions[0]}"] tokenized_w_metadata = {'ids': [], 'atten_mask': [], 'toks': [], 'labels': []} for indie_var in indie_vars: if dataset == WIKIBIO: if indie_var == BASELINE_MARKER: indie_var = WIKIBIO_BASELINE_TEXT target_text = f"{indie_var}".join(text_portions) else: if indie_var == BASELINE_MARKER: indie_var = REDDIT_BASELINE_TEXT target_text = f"r/{indie_var}: {input_text}" tokenized_sample = tokenize_and_append_metadata( target_text, tokenizer, male_gendered_token_ids, female_gendered_token_ids ) 
        tokenized_w_metadata['ids'].append(tokenized_sample["input_ids"])
        tokenized_w_metadata['atten_mask'].append(
            torch.tensor(tokenized_sample["attention_mask"]))
        tokenized_w_metadata['toks'].append(
            tokenizer.convert_ids_to_tokens(tokenized_sample["input_ids"]))
        tokenized_w_metadata['labels'].append(tokenized_sample["labels"])

    return tokenized_w_metadata


def get_avg_prob_from_finetuned_outputs(outputs, is_masked, num_preds, gender):
    """Average the fine-tuned model's softmax probability for `gender` over the masked tokens."""
    preds = torch.softmax(outputs[0][0].cpu(), dim=1, dtype=torch.double)
    pronoun_preds = torch.where(is_masked, preds[:, CLASSES.index(gender)], 0.0)
    return round(torch.sum(pronoun_preds).item() / num_preds * 100, DECIMAL_PLACES)


def get_avg_prob_from_pipeline_outputs(mask_filled_text, gendered_token_ids, num_preds):
    """Average the fill-mask pipeline's score for `gendered_token_ids` over the masked tokens."""
    pronoun_preds = [
        sum([
            pronoun["score"] if pronoun["token"] in gendered_token_ids else 0.0
            for pronoun in top_preds
        ])
        for top_preds in mask_filled_text
    ]
    return round(sum(pronoun_preds) / num_preds * 100, DECIMAL_PLACES)


def get_figure(results, dataset, gender, indie_var_name, include_baseline=True):
    """Plot prediction percentages per model, with optional dashed baseline lines."""
    colors = ['b', 'g', 'c', 'm', 'y', 'r', 'k']  # one color per plotted column

    # Grab, then remove, the baseline row from the df
    baseline = results.loc[BASELINE_MARKER]
    results.drop(index=BASELINE_MARKER, inplace=True)

    fig, ax = plt.subplots()
    for i, col in enumerate(results.columns):
        ax.plot(results[col], color=colors[i])

    if include_baseline:
        for i, (name, value) in enumerate(baseline.items()):
            if name == indie_var_name:
                continue
            ax.axhline(value, ls='--', color=colors[i])

    if dataset == REDDIT:
        ax.set_xlabel("Subreddit prepended to input text")
        ax.xaxis.set_major_locator(MaxNLocator(6))
    else:
        ax.set_xlabel("Date injected into input text")

    ax.set_title(f"Softmax probability of pronouns predicted {gender}\n by model type vs {indie_var_name}.")
    ax.set_ylabel(f"Avg softmax prob for {gender} pronouns")
    ax.legend(list(results.columns))
    return fig


def predict_gender_pronouns(
    dataset,
    bert_like_models,
    normalizing,
    include_baseline,
    input_text,
):
    """Run inference on input_text for each model type, returning dataframes and plots of the
    percentage of gendered tokens predicted female and male in each target text.
""" male_gendered_token_ids, female_gendered_token_ids = get_gendered_token_ids(tokenizer) if dataset == REDDIT: indie_vars = [BASELINE_MARKER] + SUBREDDITS conditioning_variables = SUBREDDIT_CONDITIONING_VARIABLES indie_var_name = 'subreddit' else: indie_vars = [BASELINE_MARKER] + np.linspace(START_YEAR, STOP_YEAR, 20).astype(int).tolist() conditioning_variables = WIKIBIO_CONDITIONING_VARIABLES indie_var_name = 'date' tokenized = get_tokenized_text_with_metadata( input_text, indie_vars, dataset, male_gendered_token_ids, female_gendered_token_ids ) initial_is_masked = tokenized['ids'][0] == MASK_TOKEN_ID num_preds = torch.sum(initial_is_masked).item() female_dfs = [] male_dfs = [] female_dfs.append(pd.DataFrame({indie_var_name: indie_vars})) male_dfs.append(pd.DataFrame({indie_var_name: indie_vars})) for var in conditioning_variables: prefix = f"{var}_metadata" model = models[(dataset, var)] female_pronoun_preds = [] male_pronoun_preds = [] for indie_var_idx in range(len(tokenized['ids'])): if dataset == WIKIBIO: is_masked = initial_is_masked # injected text all same token length else: is_masked = tokenized['ids'][indie_var_idx] == MASK_TOKEN_ID ids = tokenized["ids"][indie_var_idx] atten_mask = tokenized["atten_mask"][indie_var_idx] labels = tokenized["labels"][indie_var_idx] with torch.no_grad(): outputs = model(ids.unsqueeze(dim=0), atten_mask.unsqueeze(dim=0)) female_pronoun_preds.append( get_avg_prob_from_finetuned_outputs(outputs,is_masked, num_preds, "female") ) male_pronoun_preds.append( get_avg_prob_from_finetuned_outputs(outputs,is_masked, num_preds, "male") ) female_dfs.append(pd.DataFrame({prefix : female_pronoun_preds})) male_dfs.append(pd.DataFrame({prefix : male_pronoun_preds})) for bert_like in bert_like_models: prefix = f"base_{bert_like}" model = models[(BASE, bert_like)] female_pronoun_preds = [] male_pronoun_preds = [] for indie_var_idx in range(len(tokenized['ids'])): toks = tokenized["toks"][indie_var_idx] target_text_for_bert = ' '.join( toks[1:-1]) # Removing [CLS] and [SEP] mask_filled_text = model(target_text_for_bert) # Quick hack as realized return type based on how many MASKs in text. if type(mask_filled_text[0]) is not list: mask_filled_text = [mask_filled_text] female_pronoun_preds.append(get_avg_prob_from_pipeline_outputs( mask_filled_text, female_gendered_token_ids, num_preds )) male_pronoun_preds.append(get_avg_prob_from_pipeline_outputs( mask_filled_text, male_gendered_token_ids, num_preds )) if normalizing: total_gendered_probs = np.add(female_pronoun_preds, male_pronoun_preds) female_pronoun_preds = np.around( np.divide(female_pronoun_preds, total_gendered_probs)*100, decimals=DECIMAL_PLACES ) male_pronoun_preds = np.around( np.divide(male_pronoun_preds, total_gendered_probs)*100, decimals=DECIMAL_PLACES ) female_dfs.append(pd.DataFrame({prefix : female_pronoun_preds})) male_dfs.append(pd.DataFrame({prefix : male_pronoun_preds})) # Pick a sample to display to user as an example toks = tokenized["toks"][3] target_text_w_masks = ' '.join(toks[1:-1]) # Removing [CLS] and [SEP] # Plots / dataframe for display to users female_results = pd.concat(female_dfs, axis=1).set_index(indie_var_name) male_results = pd.concat(male_dfs, axis=1).set_index(indie_var_name) female_fig = get_figure(female_results, dataset, "female", indie_var_name, include_baseline) male_fig = get_figure(male_results, dataset, "male", indie_var_name, include_baseline) female_results.reset_index(inplace=True) # Gradio Dataframe doesn't 'see' index? 
    male_results.reset_index(inplace=True)  # Gradio Dataframe doesn't 'see' the index

    return (
        target_text_w_masks,
        female_fig,
        male_fig,
        female_results,
        male_results,
    )


gr.Interface(
    fn=predict_gender_pronouns,
    inputs=[
        gr.inputs.Radio(
            [REDDIT, WIKIBIO],
            default=WIKIBIO,
            type="value",
            label="Pick 'conditionally' fine-tuned model.",
            optional=False,
        ),
        gr.inputs.CheckboxGroup(
            BERT_LIKE_MODELS,
            default=[BERT_LIKE_MODELS[0]],
            type="value",
            label="Pick optional BERT base uncased model.",
        ),
        gr.inputs.Dropdown(
            ["False", "True"],
            label="Normalize BERT-like model's predictions to gendered-only?",
            default="True",
            type="index",
        ),
        gr.inputs.Dropdown(
            ["False", "True"],
            label="Include baseline predictions (dashed lines)?",
            default="True",
            type="index",
        ),
        gr.inputs.Textbox(
            lines=5,
            label="Input text: a sentence about a single person, using gendered pronouns to refer to them.",
            default="She always walked past the building built in DATE on her way to her job as an elementary school teacher.",
        ),
    ],
    outputs=[
        gr.Textbox(
            type="auto",
            label="Sample target text fed to model"),
        gr.Plot(type="auto", label="Plot of softmax probability of pronouns predicted female."),
        gr.Plot(type="auto", label="Plot of softmax probability of pronouns predicted male."),
        gr.Dataframe(
            show_label=True,
            overflow_row_behaviour="show_ends",
            label="Table of softmax probability of pronouns predicted female",
        ),
        gr.Dataframe(
            show_label=True,
            overflow_row_behaviour="show_ends",
            label="Table of softmax probability of pronouns predicted male",
        ),
    ],
).launch(debug=True, share=True)
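
# A minimal, commented-out sketch of calling predict_gender_pronouns() directly,
# bypassing the Gradio UI (e.g. for quick local debugging). The argument values
# simply mirror the Interface defaults above; the boolean flags stand in for the
# Dropdowns' "index" outputs (0/1), which the function only uses as truthy/falsy.
#
# sample_text, female_fig, male_fig, female_df, male_df = predict_gender_pronouns(
#     dataset=WIKIBIO,
#     bert_like_models=[BERT_LIKE_MODELS[0]],
#     normalizing=True,
#     include_baseline=True,
#     input_text=(
#         "She always walked past the building built in DATE on her way to "
#         "her job as an elementary school teacher."
#     ),
# )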