import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from matplotlib.ticker import MaxNLocator
from transformers import AutoModelForTokenClassification, AutoTokenizer
from transformers import pipeline

# DATASETS
REDDIT = 'reddit_finetuned'
WIKIBIO = 'wikibio_finetuned'
BASE = 'BERT_base'

# Play with me, consts
SUBREDDIT_CONDITIONING_VARIABLES = ["none", "subreddit"]
WIKIBIO_CONDITIONING_VARIABLES = ['none', 'birth_date']
BERT_LIKE_MODELS = ["bert", "distilbert"]
MAX_TOKEN_LENGTH = 32

# Internal markers for rendering
BASELINE_MARKER = 'baseline'
REDDIT_BASELINE_TEXT = ' '
WIKIBIO_BASELINE_TEXT = 'date'

## Internal constants from training
GENDER_OPTIONS = ['female', 'male']
DECIMAL_PLACES = 1
MULTITOKEN_WOMAN_WORD = 'policewoman'
MULTITOKEN_MAN_WORD = 'spiderman'
# Picked ints that will pop out visually during debug
NON_GENDERED_TOKEN_ID = 30
LABEL_DICT = {GENDER_OPTIONS[0]: 9, GENDER_OPTIONS[1]: -9}
CLASSES = list(LABEL_DICT.keys())
NON_LOSS_TOKEN_ID = -100
EPS = 1e-5  # to avoid /0 errors

# Wikibio consts
START_YEAR = 1800
STOP_YEAR = 1999
SPLIT_KEY = "DATE"

# Reddit consts
# List of randomly selected (tending towards those with seemingly more gender-neutral words)
# in order of increasing self-identified female participation.
# See , Minimum subreddit size: 400000
# NOTE(review): the reference after "See" is missing from this comment, and the
# list below contains one empty-string entry (between "worldnews" and
# "interestingasfuck") -- presumably a name that was lost; confirm and restore.
SUBREDDITS = [
    "GlobalOffensive",
    "pcmasterrace",
    "nfl",
    "sports",
    "The_Donald",
    "leagueoflegends",
    "Overwatch",
    "gonewild",
    "Futurology",
    "space",
    "technology",
    "gaming",
    "Jokes",
    "dataisbeautiful",
    "woahdude",
    "askscience",
    "wow",
    "anime",
    "BlackPeopleTwitter",
    "politics",
    "pokemon",
    "worldnews",
    "",
    "interestingasfuck",
    "videos",
    "nottheonion",
    "television",
    "science",
    "atheism",
    "movies",
    "gifs",
    "Music",
    "trees",
    "EarthPorn",
    "GetMotivated",
    "pokemongo",
    "news",
    "fffffffuuuuuuuuuuuu",
    "Fitness",
    "Showerthoughts",
    "OldSchoolCool",
    "explainlikeimfive",
    "todayilearned",
    "gameofthrones",
    "AdviceAnimals",
    "DIY",
    "WTF",
    "IAmA",
    "cringepics",
    "tifu",
    "mildlyinteresting",
    "funny",
    "pics",
    "LifeProTips",
    "creepy",
    "personalfinance",
    "food",
    "AskReddit",
    "books",
    "aww",
    "sex",
    "relationships",
]


# Fire up the models
# `models_paths` / `models` are keyed by (dataset, variable) tuples; values of
# `models` are token-classification models for the fine-tuned entries and
# fill-mask pipelines for the BASE entries.
models_paths = dict()
models = dict()

base_path = "emilylearning/"

# reddit finetuned models:
for var in SUBREDDIT_CONDITIONING_VARIABLES:
    models_paths[(REDDIT, var)] = base_path + f'cond_ft_{var}_on_reddit__prcnt_100__test_run_False'
    models[(REDDIT, var)] = AutoModelForTokenClassification.from_pretrained(
        models_paths[(REDDIT, var)]
    )

# wikibio finetuned models:
for var in WIKIBIO_CONDITIONING_VARIABLES:
    models_paths[(WIKIBIO, var)] = base_path + f"cond_ft_{var}_on_wiki_bio__prcnt_100__test_run_False"
    models[(WIKIBIO, var)] = AutoModelForTokenClassification.from_pretrained(
        models_paths[(WIKIBIO, var)]
    )

# BERT-like models:
for bert_like in BERT_LIKE_MODELS:
    models_paths[(BASE, bert_like)] = f"{bert_like}-base-uncased"
    models[(BASE, bert_like)] = pipeline(
        "fill-mask", model=models_paths[(BASE, bert_like)])

# Tokenizers same for each model, so just grabbing one of them
tokenizer = AutoTokenizer.from_pretrained(
    models_paths[(BASE, BERT_LIKE_MODELS[0])], add_prefix_space=True
)
MASK_TOKEN_ID = tokenizer.mask_token_id


def get_gendered_token_ids(tokenizer):
    """Return the vocabulary ids of the gendered tokens that get masked.

    Args:
        tokenizer: Tokenizer providing `convert_tokens_to_ids`, `encode` and
            `unk_token_id`.

    Returns:
        Tuple `(male_gendered_token_ids, female_gendered_token_ids)`, each a
        list of token ids: one per word pair below plus the second sub-token
        of MULTITOKEN_MAN_WORD / MULTITOKEN_WOMAN_WORD respectively.

    Raises:
        AssertionError: If a multi-token word does not encode to exactly two
            tokens, or if any gendered word maps to the unknown token.
    """
    ## Set up gendered token constants
    # Each pair is [male_word, female_word].
    gendered_lists = [
        ['he', 'she'],
        ['him', 'her'],
        ['his', 'hers'],
        ["himself", "herself"],
        ['male', 'female'],
        ['man', 'woman'],
        ['men', 'women'],
        ["husband", "wife"],
        ['father', 'mother'],
        ['boyfriend', 'girlfriend'],
        ['brother', 'sister'],
        ["actor", "actress"],
    ]
    male_words = [pair[0] for pair in gendered_lists]
    female_words = [pair[1] for pair in gendered_lists]

    male_gendered_token_ids = tokenizer.convert_tokens_to_ids(male_words)
    female_gendered_token_ids = tokenizer.convert_tokens_to_ids(female_words)

    # Below technique is used to grab second token in a multi-token word
    # There must be a better way...
    multiword_woman_token_ids = tokenizer.encode(
        MULTITOKEN_WOMAN_WORD, add_special_tokens=False)
    assert len(multiword_woman_token_ids) == 2
    subword_woman_token_id = multiword_woman_token_ids[1]

    multiword_man_token_ids = tokenizer.encode(
        MULTITOKEN_MAN_WORD, add_special_tokens=False)
    assert len(multiword_man_token_ids) == 2
    subword_man_token_id = multiword_man_token_ids[1]

    male_gendered_token_ids.append(subword_man_token_id)
    female_gendered_token_ids.append(subword_woman_token_id)

    # Confirming all tokens are in vocab
    assert tokenizer.unk_token_id not in male_gendered_token_ids
    assert tokenizer.unk_token_id not in female_gendered_token_ids

    return male_gendered_token_ids, female_gendered_token_ids


def tokenize_and_append_metadata(text, tokenizer, female_gendered_token_ids, male_gendered_token_ids):
    """Tokenize text and mask/flag 'gendered_tokens_ids' in token_ids and labels.

    Args:
        text: Text to tokenize; truncated/padded to MAX_TOKEN_LENGTH tokens.
        tokenizer: Callable tokenizer returning "input_ids" (and whatever else
            it normally returns).
        female_gendered_token_ids: Token ids treated as female-gendered.
        male_gendered_token_ids: Token ids treated as male-gendered.

    Returns:
        The tokenizer output with "input_ids" replaced by a tensor in which
        every gendered token is MASK_TOKEN_ID, plus a "labels" tensor holding
        the class index (position in LABEL_DICT) at gendered positions and
        NON_LOSS_TOKEN_ID elsewhere.
    """
    label_list = list(LABEL_DICT.values())
    assert label_list[0] == LABEL_DICT["female"], "LABEL_DICT not an ordered dict"
    label2id = {label: idx for idx, label in enumerate(label_list)}

    tokenized = tokenizer(
        text,
        truncation=True,
        padding='max_length',
        max_length=MAX_TOKEN_LENGTH,
    )

    # Finding the gender pronouns in the tokens
    token_ids = tokenized["input_ids"]
    female_tags = torch.tensor(
        [
            LABEL_DICT["female"]
            if token_id in female_gendered_token_ids
            else NON_GENDERED_TOKEN_ID
            for token_id in token_ids
        ]
    )
    male_tags = torch.tensor(
        [
            LABEL_DICT["male"]
            if token_id in male_gendered_token_ids
            else NON_GENDERED_TOKEN_ID
            for token_id in token_ids
        ]
    )
    is_female = female_tags == LABEL_DICT["female"]
    is_male = male_tags == LABEL_DICT["male"]

    # Labeling and masking out occurrences of gendered pronouns
    labels = torch.where(
        is_female,
        label2id[LABEL_DICT["female"]],
        NON_LOSS_TOKEN_ID,
    )
    labels = torch.where(is_male, label2id[LABEL_DICT["male"]], labels)

    masked_token_ids = torch.where(
        is_female, MASK_TOKEN_ID, torch.tensor(token_ids)
    )
    masked_token_ids = torch.where(is_male, MASK_TOKEN_ID, masked_token_ids)

    tokenized["input_ids"] = masked_token_ids
    tokenized["labels"] = labels

    return tokenized


def get_tokenized_text_with_metadata(input_text, indie_vars, dataset, male_gendered_token_ids, female_gendered_token_ids):
    """Construct dict of tokenized texts with each independent variable injected.

    For WIKIBIO, each value replaces SPLIT_KEY in `input_text` (or is inserted
    after a prepended 'Born in ' when SPLIT_KEY is absent). For any other
    dataset, each value is prepended as `r/<value>: `. BASELINE_MARKER is
    swapped for the dataset's baseline text before injection.

    Args:
        input_text: User-provided text.
        indie_vars: Values to inject (may include BASELINE_MARKER).
        dataset: WIKIBIO selects date injection; anything else selects the
            subreddit prefix.
        male_gendered_token_ids: Token ids treated as male-gendered.
        female_gendered_token_ids: Token ids treated as female-gendered.

    Returns:
        Dict with parallel lists under 'ids', 'atten_mask', 'toks' and
        'labels', one entry per value in `indie_vars`.
    """
    if dataset == WIKIBIO:
        text_portions = input_text.split(SPLIT_KEY)
        # If no SPLIT_KEY found in text, add space for metadata and whitespaces
        if len(text_portions) == 1:
            text_portions = ['Born in ', f" {text_portions[0]}"]

    tokenized_w_metadata = {'ids': [], 'atten_mask': [], 'toks': [], 'labels': []}
    for indie_var in indie_vars:
        if dataset == WIKIBIO:
            if indie_var == BASELINE_MARKER:
                indie_var = WIKIBIO_BASELINE_TEXT
            target_text = f"{indie_var}".join(text_portions)
        else:
            if indie_var == BASELINE_MARKER:
                indie_var = REDDIT_BASELINE_TEXT
            target_text = f"r/{indie_var}: {input_text}"

        # Keyword arguments on purpose: the callee takes the female ids before
        # the male ids, the reverse of this function's own parameter order.
        tokenized_sample = tokenize_and_append_metadata(
            target_text,
            tokenizer,
            female_gendered_token_ids=female_gendered_token_ids,
            male_gendered_token_ids=male_gendered_token_ids,
        )

        tokenized_w_metadata['ids'].append(tokenized_sample["input_ids"])
        tokenized_w_metadata['atten_mask'].append(
            torch.tensor(tokenized_sample["attention_mask"]))
        tokenized_w_metadata['toks'].append(
            tokenizer.convert_ids_to_tokens(tokenized_sample["input_ids"]))
        tokenized_w_metadata['labels'].append(tokenized_sample["labels"])

    return tokenized_w_metadata


def get_avg_prob_from_finetuned_outputs(outputs, is_masked, num_preds, gender):
    """Average the fine-tuned model's softmax probability of `gender` over masked tokens.

    Args:
        outputs: Model output; `outputs[0][0]` is softmaxed over dim 1
            (presumably per-token logits of the single sample -- confirm).
        is_masked: Boolean tensor selecting the masked token positions.
        num_preds: Number of masked positions used as the averaging divisor.
        gender: One of CLASSES; selects the probability column.

    Returns:
        Percentage (0-100) rounded to DECIMAL_PLACES. EPS in the divisor keeps
        the result finite when `num_preds` is 0.
    """
    preds = torch.softmax(outputs[0][0].cpu(), dim=1, dtype=torch.double)
    pronoun_preds = torch.where(is_masked, preds[:, CLASSES.index(gender)], 0.0)
    return round(torch.sum(pronoun_preds).item() / (EPS + num_preds) * 100, DECIMAL_PLACES)


def get_avg_prob_from_pipeline_outputs(mask_filled_text, gendered_token_ids, num_preds):
    """Average the fill-mask scores that landed on the given gendered tokens.

    Args:
        mask_filled_text: One list of prediction dicts (with "score" and
            "token" keys) per masked position.
        gendered_token_ids: Token ids whose scores are counted.
        num_preds: Number of masked positions used as the averaging divisor.

    Returns:
        Percentage (0-100) rounded to DECIMAL_PLACES.
    """
    pronoun_preds = [
        sum(
            pronoun["score"] if pronoun["token"] in gendered_token_ids else 0.0
            for pronoun in top_preds
        )
        for top_preds in mask_filled_text
    ]
    return round(sum(pronoun_preds) / (EPS + num_preds) * 100, DECIMAL_PLACES)


def get_figure(results, dataset, gender, indie_var_name, include_baseline=True):
    """Plot one line per model column, optionally with dashed baseline levels.

    Args:
        results: DataFrame indexed by the independent variable; must contain
            a BASELINE_MARKER row. Each column is one model's predictions.
        dataset: REDDIT selects the subreddit x-axis label and tick thinning;
            anything else gets the date label.
        gender: Gender name used in the title and y-axis label.
        indie_var_name: Name of the independent variable, used in the title.
        include_baseline: When truthy, draw each column's baseline value as a
            dashed horizontal line in the matching colour.

    Returns:
        The matplotlib figure.
    """
    colors = ['b', 'g', 'c', 'm', 'y', 'r', 'k']

    # Grab then remove baselines from df
    results_to_plot = results.drop(index=BASELINE_MARKER)

    fig, ax = plt.subplots()
    for i, col in enumerate(results.columns):
        # Cycle colours rather than raising IndexError on more than 7 columns.
        ax.plot(results_to_plot[col], color=colors[i % len(colors)])

    if include_baseline:
        baseline = results.loc[BASELINE_MARKER]
        for i, (name, value) in enumerate(baseline.items()):
            if name == indie_var_name:
                continue
            ax.axhline(value, ls='--', color=colors[i % len(colors)])

    if dataset == REDDIT:
        ax.set_xlabel("Subreddit prepended to input text")
        ax.xaxis.set_major_locator(MaxNLocator(6))
    else:
        ax.set_xlabel("Date injected into input text")
    ax.set_title(f"Softmax probability of pronouns predicted {gender}\n by model type vs {indie_var_name}.")
    ax.set_ylabel(f"Avg softmax prob for {gender} pronouns")
    ax.legend(list(results_to_plot.columns))
    return fig


def predict_gender_pronouns(
    dataset,
    bert_like_models,
    normalizing,
    include_baseline,
    input_text,
):
    """Run inference on input_text for each model type, returning df and plots of
    percentage of gender pronouns predicted as female and male in each target text.

    Args:
        dataset: REDDIT or WIKIBIO; selects the fine-tuned models and the
            values injected into the text.
        bert_like_models: Names from BERT_LIKE_MODELS to also evaluate.
        normalizing: When truthy, rescale the BERT-like models' female/male
            percentages so they sum to 100 (gendered predictions only).
        include_baseline: When truthy, plots include dashed baseline lines.
        input_text: User-provided text.

    Returns:
        Tuple of (sample masked text, female figure, female DataFrame,
        male figure, male DataFrame).
    """
    male_gendered_token_ids, female_gendered_token_ids = get_gendered_token_ids(tokenizer)
    if dataset == REDDIT:
        indie_vars = [BASELINE_MARKER] + SUBREDDITS
        conditioning_variables = SUBREDDIT_CONDITIONING_VARIABLES
        indie_var_name = 'subreddit'
    else:
        indie_vars = [BASELINE_MARKER] + np.linspace(START_YEAR, STOP_YEAR, 20).astype(int).tolist()
        conditioning_variables = WIKIBIO_CONDITIONING_VARIABLES
        indie_var_name = 'date'

    tokenized = get_tokenized_text_with_metadata(
        input_text,
        indie_vars,
        dataset,
        male_gendered_token_ids,
        female_gendered_token_ids
    )
    initial_is_masked = tokenized['ids'][0] == MASK_TOKEN_ID
    num_preds = torch.sum(initial_is_masked).item()

    female_dfs = [pd.DataFrame({indie_var_name: indie_vars})]
    male_dfs = [pd.DataFrame({indie_var_name: indie_vars})]

    for var in conditioning_variables:
        prefix = f"{var}_metadata"
        model = models[(dataset, var)]

        female_pronoun_preds = []
        male_pronoun_preds = []
        for ids, atten_mask in zip(tokenized['ids'], tokenized['atten_mask']):
            if dataset == WIKIBIO:
                is_masked = initial_is_masked  # injected text all same token length
            else:
                is_masked = ids == MASK_TOKEN_ID

            with torch.no_grad():
                outputs = model(ids.unsqueeze(dim=0), atten_mask.unsqueeze(dim=0))

            female_pronoun_preds.append(
                get_avg_prob_from_finetuned_outputs(outputs, is_masked, num_preds, "female")
            )
            male_pronoun_preds.append(
                get_avg_prob_from_finetuned_outputs(outputs, is_masked, num_preds, "male")
            )

        female_dfs.append(pd.DataFrame({prefix: female_pronoun_preds}))
        male_dfs.append(pd.DataFrame({prefix: male_pronoun_preds}))

    for bert_like in bert_like_models:
        prefix = f"base_{bert_like}"
        model = models[(BASE, bert_like)]

        female_pronoun_preds = []
        male_pronoun_preds = []
        for toks in tokenized["toks"]:
            target_text_for_bert = ' '.join(toks[1:-1])  # Removing [CLS] and [SEP]

            mask_filled_text = model(target_text_for_bert)
            # Quick hack as realized return type based on how many MASKs in text.
            if not isinstance(mask_filled_text[0], list):
                mask_filled_text = [mask_filled_text]

            female_pronoun_preds.append(get_avg_prob_from_pipeline_outputs(
                mask_filled_text,
                female_gendered_token_ids,
                num_preds
            ))
            male_pronoun_preds.append(get_avg_prob_from_pipeline_outputs(
                mask_filled_text,
                male_gendered_token_ids,
                num_preds
            ))

        if normalizing:
            total_gendered_probs = np.add(female_pronoun_preds, male_pronoun_preds)
            # A sample with no gendered prediction at all has a zero total; the
            # ratio is undefined there and is deliberately left as NaN, without
            # emitting numpy's divide/invalid RuntimeWarning.
            with np.errstate(divide='ignore', invalid='ignore'):
                female_pronoun_preds = np.around(
                    np.divide(female_pronoun_preds, total_gendered_probs) * 100,
                    decimals=DECIMAL_PLACES
                )
                male_pronoun_preds = np.around(
                    np.divide(male_pronoun_preds, total_gendered_probs) * 100,
                    decimals=DECIMAL_PLACES
                )

        female_dfs.append(pd.DataFrame({prefix: female_pronoun_preds}))
        male_dfs.append(pd.DataFrame({prefix: male_pronoun_preds}))

    # Pick a sample to display to user as an example
    toks = tokenized["toks"][3]
    target_text_w_masks = ' '.join(toks[1:-1])  # Removing [CLS] and [SEP]

    # Plots / dataframe for display to users
    female_results = pd.concat(female_dfs, axis=1).set_index(indie_var_name)
    male_results = pd.concat(male_dfs, axis=1).set_index(indie_var_name)

    female_fig = get_figure(female_results, dataset, "female", indie_var_name, include_baseline)
    female_results.reset_index(inplace=True)  # Gradio Dataframe doesn't 'see' index?

    male_fig = get_figure(male_results, dataset, "male", indie_var_name, include_baseline)
    male_results.reset_index(inplace=True)  # Gradio Dataframe doesn't 'see' index?

    return (
        target_text_w_masks,
        female_fig,
        female_results,
        male_fig,
        male_results,
    )


title = "Causing Gender Pronouns"
# NOTE(review): several markdown link targets / references in the text below
# (and the URL at the end of `article`) are missing from this file; the text is
# kept as found -- restore the targets if they are known.
description = """
## Intro

This work investigates how we can cause LLMs to change their gender pronoun predictions.

We do this by first considering plausible data generating processes for the type of datasets upon which the LLMs were pretrained. The data generating process is usually not revealed by the dataset alone, and instead requires (ideally well-informed) assumptions about what may have caused both the features and the labels to appear in the dataset.

An example of an assumed data generating process for the [wiki-bio dataset]( is shown in the form of a causal DAG in [causing_gender_pronouns](, an earlier but better documented version of this Space.

Once we have a causal DAG, we can identify likely confounding variables that have causal influences on both the features and the labels in a model. We can include those variables in our model train-time and/or at inference-time to produce spurious correlations, exposing potentially surprising learned relationships between the features and labels.

## This demo

Here we can experiment with these spurious correlations in both BERT and BERT-like pre-trained models as well as two types of fine-tuned models. These fine-tuned models were trained with a specific gender-pronoun-predicting task, and with potentially confounding metadata either excluded (`none_metadata` variants) or included (`birth_date_metadata` and `subreddit_metadata` variants) in the text samples at train time.
See [source code]( for more details.

For the gender-pronoun-predicting task, the following non-gender-neutral terms are `[MASKED]` for gender-prediction.
```
gendered_lists = [
    ['he', 'she'],
    ['him', 'her'],
    ['his', 'hers'],
    ["himself", "herself"],
    ['male', 'female'],
    ['man', 'woman'],
    ['men', 'women'],
    ["husband", "wife"],
    ['father', 'mother'],
    ['boyfriend', 'girlfriend'],
    ['brother', 'sister'],
    ["actor", "actress"],
    ["##man", "##woman"]]
```

What we are looking for in this demo is a dose-response relationship, where a larger intervention in the treatment (the text injected in the inference sample, displayed on the x-axis) produces a larger response in the output (the average softmax probability of a gendered pronoun, displayed on the y-axis).

For the `wiki-bio` models the x-axis is simply the `date`, ranging from 1800 - 1999, which is injected into the text. For the `reddit` models, it is the `subreddit` name, which is prepended to the inference text samples, with subreddits that have a larger percentage of self-reported female commentors increasing to the right (following the methodology in, we just copied over the entire list of subreddits that had a Minimum subreddit size of 400,000).

## What you can do:

- Pick a fine-tuned model type.
- Pick optional BERT, and/or BERT-like model.
- Decide if you want to see BERT-like model’s predictions normalized to only those predictions that are gendered (ignoring their gender-neutral predictions).
    - Note, DistilBERT in particular does a great job at predicting gender-neutral terms, so this normalization can look pretty noisy.
    - This normalization is not required for our fine-tuned models, which are forced to make a binary prediction.
- Decide if you want to see the baseline prediction (from neutral or no text injection into your text sample) in the plot.
- Come up with a text sample!
    - Any term included that is from the `gendered_lists` above will be masked out for prediction.
    - In the case of `wiki-bio`, any appearance of the word `DATE` will be replaced with the year shown on the x-axis. (If no `DATE` is included, the phrase `Born in DATE…` will be prepended to your text sample.)
    - In the case of `reddit`, the `subreddit` names shown on the x-axis (or shown more clearly in the associated dataframe) will be prepended to your text sample).
- Don’t forget to hit the [Submit] button!
    - Using the provided examples at the bottom may result in a pre-cached dataframe being loaded, but the plot will only be calculated after you hit [Submit].

Note: if app seems frozen, refreshing webpage may help. Sorry for the inconvenience. Will debug soon.
"""

article = "The source code to generate the fine-tuned models can be found/reproduced here:"

# Each example lists values in the order of the Interface inputs below:
# dataset, BERT-like models, normalize?, include baseline?, input text.
scientist_example = [
    REDDIT,
    [BERT_LIKE_MODELS[0]],
    "True",
    "True",
    'She was a very well regarded scientist and her work won many awards.',
]
building_example = [
    WIKIBIO,
    [BERT_LIKE_MODELS[0]],
    "True",
    "True",
    "She always walked past the building built in DATE on her way to her job as an elementary school teacher.",
]
death_date_example = [
    WIKIBIO,
    BERT_LIKE_MODELS,
    "False",
    "True",
    'Died in DATE, she was recognized for her great accomplishments to the field of teaching.'
]
neg_reddit_example = [
    REDDIT,
    [BERT_LIKE_MODELS[0]],
    "False",
    "True",
    'She is not good at anything. The work she does is always subpar.'
]

gr.Interface(
    fn=predict_gender_pronouns,
    inputs=[
        gr.Radio(
            [REDDIT, WIKIBIO],
            type="value",
            label="Pick 'conditionally' fine-tuned model.",
        ),
        gr.CheckboxGroup(
            BERT_LIKE_MODELS,
            type="value",
            label="Pick optional BERT base uncased model.",
        ),
        # type="index" hands the callback 0 for "False" and 1 for "True".
        gr.Dropdown(
            ["False", "True"],
            label="Normalize BERT-like model's predictions to gendered-only?",
            type="index",
        ),
        gr.Dropdown(
            ["False", "True"],
            label="Include baseline predictions (dashed-lines)?",
            type="index",
        ),
        gr.Textbox(
            lines=5,
            label="Input Text: Sentence about a single person using some gendered pronouns to refer to them.",
        ),
    ],
    outputs=[
        gr.Textbox(
            type="auto", label="Sample target text fed to model"),
        gr.Plot(type="auto", label="Plot of softmax probability pronouns predicted female."),
        gr.Dataframe(
            show_label=True,
            overflow_row_behaviour="show_ends",
            label="Table of softmax probability pronouns predicted female",
        ),
        gr.Plot(type="auto", label="Plot of softmax probability pronouns predicted male."),
        gr.Dataframe(
            show_label=True,
            overflow_row_behaviour="show_ends",
            label="Table of softmax probability pronouns predicted male",
        ),
    ],
    title=title,
    description=description,
    article=article,
    examples=[scientist_example, building_example, death_date_example, neg_reddit_example]
).launch(debug=True)