{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Watermark Analysis\n", "\n", "Notebook for performing analysis and visualization of the effects of watermarking schemes" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Basic imports\n", "import os\n", "\n", "from tqdm import tqdm\n", "from statistics import mean\n", "\n", "import numpy as np\n", "import pandas as pd\n", "import torch\n", "\n", "import matplotlib.pyplot as plt\n", "\n", "from matplotlib import rc\n", "rc('font', **{'family': 'serif', 'serif': ['Computer Modern']})\n", "rc('text', usetex=True)\n", "\n", "import cmasher as cmr" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from datasets import load_from_disk" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "### Load the processed dataset/frame" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# save_name = \"analysis_ds_1-19_realnews_1-3_v1\" # in figure\n", "# save_name = \"analysis_ds_1-21_greedy_redo\" \n", "# save_name = \"analysis_ds_1-21_greedy_redo_truncated\"\n", "# save_name = \"analysis_ds_1-23_greedy_gamma_0-25_truncated\" \n", "# save_name = \"analysis_ds_1-23_greedy_gamma_0-25_0-5_truncated\" # in figure (not 100% sure this is correct, check)\n", "\n", "# save_name = \"analysis_ds_1-20_more_attack\" # in figure\n", "\n", "# save_name = \"analysis_ds_1-19_realnews_1-3_v1\" # in figure\n", "# save_name = \"analysis_ds_1-23_en_1-3\"\n", "save_name = \"analysis_ds_1-23_pile_1-3\"\n", "\n", "save_dir = f\"input/{save_name}\"" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "raw_data = load_from_disk(save_dir)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### convert to pandas df" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "df = raw_data.to_pandas()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "print(f\"Orig number of rows: {len(df)}\")\n", "df.tail()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "df.columns" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "### \"retokenization\" problem \n", "\n", "current hypo for what matches this criterion is based on the non 1-to-1 aspect of tokenization" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "retok_problematic_rows = df[(df['w_bl_whitelist_fraction'] != -1.0) & (df['w_bl_whitelist_fraction'] != 1.0) & (df['bl_type'] == 'hard')]\n", "print(f\"Num rows that are hard-blacklisted, and measureable, but still have a non-100% WL fraction: {len(retok_problematic_rows)} out of {len(df[df['bl_type'] == 'hard'])}\")\n", "# retok_problematic_rows" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "#### Replace or drop the the specially marked -1 rows since these are unmeasureable due to short length" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "orig_len = len(df)\n", "\n", "# df['no_bl_whitelist_fraction'].mask(df['no_bl_whitelist_fraction'] == -1.0, pd.NA, inplace=True)\n", "# df['w_bl_whitelist_fraction'].mask(df['w_bl_whitelist_fraction'] == -1.0, pd.NA, inplace=True)\n", "\n", "df = df[df[\"no_bl_whitelist_fraction\"] != -1.0]\n", "df = 
df[df[\"w_bl_whitelist_fraction\"] != -1.0]\n", "\n", "print(f\"Dropped {orig_len-len(df)} rows, new len {len(df)}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Drop rows where there weren't enough tokens to measure ppl in one or both of the output cases" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "orig_len = len(df)\n", "# df = df[df[\"no_bl_ppl\"].isna()]\n", "# df = df[df[\"w_bl_ppl\"].isna()]\n", "df = df[~(df[\"no_bl_ppl\"].isna() | df[\"w_bl_ppl\"].isna())]\n", "print(f\"Dropped {orig_len-len(df)} rows, new len {len(df)}\")" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "#### drop rows with really large bias, as 100.0 is $\\simeq \\infty$" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "orig_len = len(df)\n", "\n", "df = df[df[\"bl_logit_bias\"] <= 100.0]\n", "\n", "print(f\"Dropped {orig_len-len(df)} rows, new len {len(df)}\")" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "#### drop rows where using sampling but also beam search, not considering at this time" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "orig_len = len(df)\n", "\n", "# df = df[df[\"bl_hparams\"].apply(lambda tup: (tup[0] == False and tup[2] != 1) or (tup[0] == True and tup[2] == 1) or (tup[0] == False))]\n", "df = df[((df[\"use_sampling\"]==True) & (df[\"num_beams\"] == 1)) | (df[\"use_sampling\"]==False)]\n", "\n", "print(f\"Dropped {orig_len-len(df)} rows, new len {len(df)}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### correct the sampling temp column" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "df.loc[df[\"use_sampling\"]==False,\"sampling_temp\"] = df.loc[df[\"use_sampling\"]==False,\"sampling_temp\"].fillna(0.0)\n", "df.loc[df[\"use_sampling\"]==True,\"sampling_temp\"] = df.loc[df[\"use_sampling\"]==True,\"sampling_temp\"].fillna(1.0)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "#### marking the hard blacklist rows as having inf/very large bias\n", "\n", "(after the > 100.0 bias drop)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "df.loc[df[\"bl_type\"]==\"hard\",\"bl_logit_bias\"] = np.inf\n", "# df.loc[df[\"bl_type\"]==\"hard\",\"bl_logit_bias\"] = 10000 # crosscheck with whats hardcoded in the bl processor" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "#### Rename some parameters" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "df[\"delta\"] = df[\"bl_logit_bias\"].values\n", "df[\"gamma\"] = 1 - df[\"bl_proportion\"].values\n", "df[\"gamma\"] = df[\"gamma\"].round(3)\n", "\n", "df[\"no_bl_act_num_wl_tokens\"] = np.round(df[\"no_bl_whitelist_fraction\"].values*df[\"no_bl_num_tokens_generated\"],1) # round to 1 for sanity\n", "df[\"w_bl_act_num_wl_tokens\"] = np.round(df[\"w_bl_whitelist_fraction\"].values*df[\"w_bl_num_tokens_generated\"],1) # round to 1 for sanity\n", "\n", "df[\"w_bl_std_num_wl_tokens\"] = np.sqrt(df[\"w_bl_var_num_wl_tokens\"].values)\n", "\n", "if \"real_completion_length\":\n", " df[\"baseline_num_tokens_generated\"] = df[\"real_completion_length\"].values\n", "\n", "if \"actual_attacked_ratio\" in df.columns:\n", " df[\"actual_attacked_fraction\"] = 
df[\"actual_attacked_ratio\"].values*df[\"replace_ratio\"].values\n", "\n", "if \"meta\" in df.columns:\n", " df[\"pile_set_name\"] = df[\"meta\"].apply(lambda dict: dict[\"pile_set_name\"])\n", "\n", "df[\"baseline_hit_list_length\"] = df[\"baseline_hit_list\"].apply(len)\n", "df[\"no_bl_hit_list_length\"] = df[\"no_bl_hit_list\"].apply(len)\n", "df[\"w_bl_hit_list_length\"] = df[\"w_bl_hit_list\"].apply(len)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# for pile outlier filtering\n", "df[\"w_bl_space_count\"] = df[\"w_bl_output\"].apply(lambda string: string.count(\" \"))\n", "df[\"no_bl_space_count\"] = df[\"no_bl_output\"].apply(lambda string: string.count(\" \"))\n", "df[\"baseline_space_count\"] = df[\"baseline_completion\"].apply(lambda string: string.count(\" \"))\n", "\n", "df[\"w_bl_space_frac\"] = df[\"w_bl_space_count\"].values / df[\"w_bl_hit_list_length\"]\n", "df[\"no_bl_space_frac\"] = df[\"no_bl_space_count\"].values / df[\"no_bl_hit_list_length\"]\n", "df[\"baseline_space_frac\"] = df[\"baseline_space_count\"].values / df[\"baseline_hit_list_length\"]" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "### Filter for the generation lengths we want to look at" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "orig_len = len(df)\n", "\n", "# # main filters\n", "# # df = df[(df[\"real_completion_length\"] == 200) & (df[\"w_bl_num_tokens_generated\"] == 200)]\n", "# df = df[(df[\"gamma\"] == 0.1) | (df[\"gamma\"] == 0.25) | (df[\"gamma\"] == 0.5)]\n", "# df = df[(df[\"delta\"] == 1.0) | (df[\"delta\"] == 2.0) | (df[\"delta\"] == 10.0)]\n", "# df = df[(df[\"use_sampling\"] == True)]\n", "# df = df[(df[\"bl_type\"] == \"soft\")]\n", "\n", "# df = df[(df[\"real_completion_length\"] == 200) & (df[\"no_bl_num_tokens_generated\"] == 200) & (df[\"w_bl_num_tokens_generated\"] == 200)] # now also applies to the truncated version\n", "# df = df[(df[\"no_bl_num_tokens_generated\"] >= 500) & (df[\"w_bl_num_tokens_generated\"] >= 500)] # all gas noop\n", "\n", "# # # attack specific\n", "# df = df[(df[\"real_completion_length\"] == 200) & (df[\"no_bl_num_tokens_generated\"] == 200) & (df[\"w_bl_num_tokens_generated\"] == 200)]\n", "# df = df[(df[\"replace_ratio\"] <= 0.7)]\n", "\n", "# NOTE pile only\n", "df = df[df[\"w_bl_space_frac\"] <= 0.9]\n", "df = df[df[\"no_bl_space_frac\"] <= 0.9]\n", "# df = df[df[\"pile_set_name\"] != \"Github\"]\n", "\n", "upper_T = 205\n", "lower_T = 195\n", "df = df[(df[\"baseline_hit_list_length\"] >= lower_T) & (df[\"no_bl_hit_list_length\"] >= lower_T) & (df[\"w_bl_hit_list_length\"] >= lower_T)] # now also applies to the truncated version\n", "df = df[(df[\"baseline_hit_list_length\"] <= upper_T) & (df[\"no_bl_hit_list_length\"] <= upper_T) & (df[\"w_bl_hit_list_length\"] <= upper_T)] # now also applies to the truncated version\n", "\n", "\n", "print(f\"Dropped {orig_len-len(df)} rows, new len {len(df)}\")" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "# Add z-scores (convert the raw watermark measurement, fraction, to a z-score )" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from math import sqrt\n", "import scipy.stats\n", "def compute_z_score(observed_wl_frac, T, gamma):\n", " numer = observed_wl_frac - gamma\n", " denom = sqrt(gamma*(1-gamma)/T)\n", " z = numer/denom\n", " return z\n", "\n", "def compute_wl_for_z(z, T, gamma):\n", " denom 
= sqrt(gamma*(1-gamma)/T)\n", " numer = ((z*denom)+gamma)*T\n", " return numer\n", "\n", "def compute_p_value(z):\n", " p_value = scipy.stats.norm.sf(abs(z))\n", " return p_value\n", "\n", "df[\"baseline_z_score\"] = df[[\"baseline_whitelist_fraction\", \"baseline_num_tokens_generated\", \"gamma\"]].apply(lambda tup: compute_z_score(*tup), axis=1)\n", "df[\"no_bl_z_score\"] = df[[\"no_bl_whitelist_fraction\", \"no_bl_num_tokens_generated\", \"gamma\"]].apply(lambda tup: compute_z_score(*tup), axis=1)\n", "df[\"w_bl_z_score\"] = df[[\"w_bl_whitelist_fraction\", \"w_bl_num_tokens_generated\", \"gamma\"]].apply(lambda tup: compute_z_score(*tup), axis=1)\n", "\n", "if \"w_bl_attacked_whitelist_fraction\" in df.columns:\n", " df[\"w_bl_attacked_z_score\"] = df[[\"w_bl_attacked_whitelist_fraction\", \"w_bl_attacked_num_tokens_generated\", \"gamma\"]].apply(lambda tup: compute_z_score(*tup), axis=1)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# if attacked in df\n", "if \"w_bl_attacked_whitelist_fraction\" in df.columns:\n", " df[\"w_bl_attacked_act_num_wl_tokens\"] = np.round(df[\"w_bl_attacked_whitelist_fraction\"].values*df[\"w_bl_attacked_num_tokens_generated\"],1) # round to 1 for sanity\n", "\n", " df[\"w_bl_attacked_z_score\"] = df[[\"w_bl_attacked_whitelist_fraction\", \"w_bl_attacked_num_tokens_generated\", \"gamma\"]].apply(lambda tup: compute_z_score(*tup), axis=1)\n", "\n", " df[[\"bl_proportion\",\"w_bl_attacked_whitelist_fraction\", \"w_bl_attacked_num_tokens_generated\",\"w_bl_attacked_act_num_wl_tokens\", \"w_bl_attacked_z_score\"]]" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "# Prepare groupby (decide which hyperparameters to groups the rows by)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# groupby_fields = ['num_beams', 'max_new_tokens']\n", "# groupby_fields = ['use_sampling','num_beams', 'max_new_tokens']\n", "# groupby_fields = ['use_sampling','num_beams', 'max_new_tokens', 'bl_logit_bias']\n", "# groupby_fields = ['use_sampling','num_beams', 'max_new_tokens', 'bl_type','bl_logit_bias']\n", "# groupby_fields = ['use_sampling','sampling_temp','num_beams', 'max_new_tokens', 'bl_type','bl_logit_bias']\n", "# groupby_fields = ['use_sampling','sampling_temp','num_beams', 'max_new_tokens', 'bl_type','bl_logit_bias','bl_proportion']\n", "# groupby_fields = ['use_sampling','num_beams','bl_type','bl_logit_bias','bl_proportion']\n", "\n", "if \"w_bl_attacked_whitelist_fraction\" in df.columns: \n", " groupby_fields = ['use_sampling','num_beams','gamma','delta', 'replace_ratio'] # attack grouping\n", "else:\n", " groupby_fields = ['use_sampling','num_beams','delta','gamma'] # regular grouping\n", " # groupby_fields = ['use_sampling','delta','gamma'] # regular grouping, but no beam variation\n", " # groupby_fields = ['delta','gamma'] # regular grouping, but no beam variation, and all sampling" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "### narrowing in on IQ range (not generally used)\n", "\n", "(removing outliers by subsetting to rows near the mean etc.)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# tmp_grped_25 = df.groupby(groupby_fields, as_index= False)['avg_spike_entropy'].quantile(q=0.25).rename(columns={'avg_spike_entropy': 'avg_spike_entropy_25th'})\n", "# tmp_grped_50 = df.groupby(groupby_fields, as_index= 
False)['avg_spike_entropy'].quantile(q=0.5).rename(columns={'avg_spike_entropy': 'avg_spike_entropy_50th'})\n", "# tmp_grped_75 = df.groupby(groupby_fields, as_index= False)['avg_spike_entropy'].quantile(q=0.75).rename(columns={'avg_spike_entropy': 'avg_spike_entropy_75th'})\n", "# df = df.merge(tmp_grped_25, on = groupby_fields)\n", "# df = df.merge(tmp_grped_50, on = groupby_fields)\n", "# df = df.merge(tmp_grped_75, on = groupby_fields)\n", "\n", "# # tmp_grped_mean = df.groupby(groupby_fields, as_index= False)['avg_spike_entropy'].mean().rename(columns={'avg_spike_entropy': 'avg_spike_entropy_mean'})\n", "# # tmp_grped_median = df.groupby(groupby_fields, as_index= False)['avg_spike_entropy'].median().rename(columns={'avg_spike_entropy': 'avg_spike_entropy_median'})\n", "# # df = df.merge(tmp_grped_mean, on = groupby_fields)\n", "# # df = df.merge(tmp_grped_median, on = groupby_fields)\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# # eps = 0.001\n", "# eps = 0.005\n", "# df[\"avg_spike_entropy_mean_minus_eps\"] = df['avg_spike_entropy_mean']-eps\n", "# df[\"avg_spike_entropy_mean_plus_eps\"] = df['avg_spike_entropy_mean']+eps\n", "\n", "# df[\"avg_spike_entropy_median_minus_eps\"] = df['avg_spike_entropy_median']-eps\n", "# df[\"avg_spike_entropy_median_plus_eps\"] = df['avg_spike_entropy_median']+eps\n", "# print(df.columns)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# # df[[\"avg_spike_entropy_25th\",\"avg_spike_entropy_75th\"]]\n", "# df[[\"avg_spike_entropy_mean_minus_eps\",\"avg_spike_entropy_mean\",\"avg_spike_entropy_mean_plus_eps\"]]\n", "# df[[\"avg_spike_entropy_median_minus_eps\",\"avg_spike_entropy_median\",\"avg_spike_entropy_median_plus_eps\"]]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# orig_len = len(df)\n", "\n", "# subdf = df[(df[\"avg_spike_entropy\"] >= df[\"avg_spike_entropy_25th\"]) & (df[\"avg_spike_entropy\"] <= df[\"avg_spike_entropy_75th\"])]\n", "\n", "# # subdf = df[(df[\"avg_spike_entropy\"] >= df[\"avg_spike_entropy_mean_minus_eps\"]) & (df[\"avg_spike_entropy\"] <= df[\"avg_spike_entropy_mean_plus_eps\"])]\n", "# # subdf = df[(df[\"avg_spike_entropy\"] >= df[\"avg_spike_entropy_mean_minus_eps\"]) & (df[\"avg_spike_entropy\"] <= df[\"avg_spike_entropy_mean_plus_eps\"])]\n", "\n", "# print(f\"Dropped {orig_len-len(subdf)} rows, new len {len(subdf)}\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# subdf.groupby(groupby_fields)['avg_spike_entropy'].describe()\n", "# df.groupby(groupby_fields)['avg_spike_entropy'].describe()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# df = subdf" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "# Perform the groupby (group rows by their hyperparameter settings)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "grouped_df = df.groupby(groupby_fields)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "print(f\"Number of rows after filtering: {len(df)}\")\n", "print(f\"Number of groups: {len(grouped_df)}\")" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "# Loop to compute \"confusion matrix\" (TPR,FPR etc.) 
at some z scores for tabulation (Table 2 & 8)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import sklearn.metrics as metrics\n", "\n", "def reject_null_hypo(z_score=None,cuttoff=None):\n", " return z_score > cuttoff\n", "\n", "records = []\n", "\n", "for group_params in tqdm(list(grouped_df.groups.keys())):\n", " sub_df = grouped_df.get_group(group_params)\n", " grp_size = len(sub_df)\n", "\n", " # baseline_z_scores = sub_df[\"baseline_z_score\"].values\n", " # w_bl_z_scores = sub_df[\"w_bl_z_score\"].values\n", " # all_scores = np.concatenate([baseline_z_scores,w_bl_z_scores])\n", "\n", " # baseline_labels = np.zeros_like(baseline_z_scores)\n", " # attacked_labels = np.ones_like(w_bl_z_scores)\n", " # all_labels = np.concatenate([baseline_labels,attacked_labels])\n", "\n", " # fpr, tpr, thresholds = metrics.roc_curve(all_labels, all_scores, pos_label=1)\n", " # roc_auc = metrics.auc(fpr, tpr)\n", " record = {k:v for k,v in zip(groupby_fields,group_params)}\n", "\n", " for thresh in [4.0,5.0]:\n", " \n", " record[\"count\"] = grp_size\n", " record[f\"baseline_fpr_at_{thresh}\"] = reject_null_hypo(z_score=sub_df[\"baseline_z_score\"].values,cuttoff=thresh).sum() / grp_size\n", " record[f\"baseline_tnr_at_{thresh}\"] = (~reject_null_hypo(z_score=sub_df[\"baseline_z_score\"],cuttoff=thresh)).sum() / grp_size\n", " record[f\"no_bl_fpr_at_{thresh}\"] = reject_null_hypo(z_score=sub_df[\"no_bl_z_score\"].values,cuttoff=thresh).sum() / grp_size\n", " record[f\"no_bl_tnr_at_{thresh}\"] = (~reject_null_hypo(z_score=sub_df[\"no_bl_z_score\"].values,cuttoff=thresh)).sum() / grp_size\n", " record[f\"w_bl_tpr_at_{thresh}\"] = reject_null_hypo(z_score=sub_df[\"w_bl_z_score\"].values,cuttoff=thresh).sum() / grp_size\n", " record[f\"w_bl_fnr_at_{thresh}\"] = (~reject_null_hypo(z_score=sub_df[\"w_bl_z_score\"].values,cuttoff=thresh)).sum() / grp_size\n", "\n", " if \"w_bl_attacked_z_score\" in sub_df.columns:\n", " record[f\"w_bl_attacked_tpr_at_{thresh}\"] = reject_null_hypo(z_score=sub_df[\"w_bl_attacked_z_score\"].values,cuttoff=thresh).sum() / grp_size\n", " record[f\"w_bl_attacked_fnr_at_{thresh}\"] = (~reject_null_hypo(z_score=sub_df[\"w_bl_attacked_z_score\"].values,cuttoff=thresh)).sum() / grp_size\n", "\n", " records.append(record)\n", "\n", " # # df[f\"baseline_fp_at_{thresh}\"] = reject_null_hypo(z_score=df[\"baseline_z_score\"].values,cuttoff=thresh)\n", " # # df[f\"baseline_tn_at_{thresh}\"] = ~reject_null_hypo(z_score=df[\"baseline_z_score\"],cuttoff=thresh)\n", " # # df[f\"no_bl_fp_at_{thresh}\"] = reject_null_hypo(z_score=df[\"no_bl_z_score\"].values,cuttoff=thresh)\n", " # # df[f\"no_bl_tn_at_{thresh}\"] = ~reject_null_hypo(z_score=df[\"no_bl_z_score\"].values,cuttoff=thresh)\n", " # # df[f\"w_bl_tp_at_{thresh}\"] = reject_null_hypo(z_score=df[\"w_bl_z_score\"].values,cuttoff=thresh)\n", " # # df[f\"w_bl_fn_at_{thresh}\"] = ~reject_null_hypo(z_score=df[\"w_bl_z_score\"].values,cuttoff=thresh)\n", "\n", "\n", "roc_df = pd.DataFrame.from_records(records)\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# thresh = 6.0\n", "# thresh = 5.0\n", "std_threshes = [4.0, 5.0] #, 6.0]\n", "# std_threshes = [4.0]\n", "\n", "# roc_df[\"params\"] = roc_df.index.to_list()\n", "\n", "# columns = [\"num_beams\", \"delta\", \"gamma\", \"count\"]\n", "# columns = [\"delta\", \"gamma\", \"count\"]\n", "columns = [\"use_sampling\",\"delta\", \"gamma\", \"count\"]\n", "# columns = [\"use_sampling\", 
\"replace_ratio\", \"count\"]\n", "\n", "for thresh in std_threshes:\n", " # columns += [f\"baseline_fpr_at_{thresh}\",f\"no_bl_fpr_at_{thresh}\",f\"w_bl_tpr_at_{thresh}\"]\n", " # columns += [f\"baseline_fpr_at_{thresh}\",f\"baseline_tnr_at_{thresh}\",f\"no_bl_fpr_at_{thresh}\",f\"no_bl_tnr_at_{thresh}\",f\"w_bl_tpr_at_{thresh}\",f\"w_bl_fn_at_{thresh}\"]\n", "\n", "\n", " # columns += [f\"baseline_fpr_at_{thresh}\",f\"baseline_tnr_at_{thresh}\",f\"w_bl_tpr_at_{thresh}\",f\"w_bl_fnr_at_{thresh}\"]\n", " \n", " if f\"w_bl_attacked_fnr_at_{thresh}\" in roc_df.columns:\n", " columns += [f\"w_bl_tpr_at_{thresh}\",f\"w_bl_fnr_at_{thresh}\"]\n", " columns += [f\"w_bl_attacked_tpr_at_{thresh}\",f\"w_bl_attacked_fnr_at_{thresh}\"] # if attack\n", " else:\n", " columns += [f\"baseline_fpr_at_{thresh}\",f\"baseline_tnr_at_{thresh}\",f\"w_bl_tpr_at_{thresh}\",f\"w_bl_fnr_at_{thresh}\"]\n", "\n", "# filter ot not\n", "sub_df = roc_df[(roc_df[\"use_sampling\"] == True) & ((roc_df[\"delta\"] == 1.0) | (roc_df[\"delta\"] == 2.0) | (roc_df[\"delta\"] == 5.0)) & ((roc_df[\"gamma\"] == 0.25) |(roc_df[\"gamma\"] == 0.5) )]\n", "# sub_df = roc_df[(roc_df[\"use_sampling\"] == False) & ((roc_df[\"delta\"] == 1.0) | (roc_df[\"delta\"] == 2.0) | (roc_df[\"delta\"] == 5.0)) & ((roc_df[\"gamma\"] == 0.25) |(roc_df[\"gamma\"] == 0.5) ) & (roc_df[\"num_beams\"] == 8)]\n", "# sub_df = roc_df[(roc_df[\"replace_ratio\"] == 0.1) | (roc_df[\"replace_ratio\"] == 0.3) | (roc_df[\"replace_ratio\"] == 0.5) | (roc_df[\"replace_ratio\"] == 0.7)]\n", "# sub_df = roc_df[(roc_df[\"num_beams\"] == 8)]\n", "# sub_df = roc_df\n", "\n", "# sub_df.sort_values(\"delta\")[columns]\n", "# sub_df.sort_values(\"num_beams\")[columns]\n", "sub_df.sort_values(by=[\"delta\",\"gamma\"],ascending=[True, False])[columns]" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "#### write tables to latex" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# print(roc_df[columns].drop([\"count\"],axis=1).sort_values(\"gamma\").round(3).to_latex(index=False))\n", "# print(roc_df[columns].drop([\"count\"],axis=1).sort_values(\"delta\").round(3).to_latex(index=False))\n", "# print(roc_df[columns].drop([\"count\"],axis=1).sort_values(\"num_beams\").round(3).to_latex(index=False))\n", "\n", "# print(sub_df.sort_values(by=[\"delta\",\"gamma\"],ascending=[True, False])[columns].round(3).to_latex(index=False))\n", "# print(sub_df.sort_values(\"num_beams\")[columns].round(3).to_latex(index=False))" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "# ROC: No Attack (figure 4)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "plt.clf()\n", "plt.figure(constrained_layout=True)\n", "plt.figure(figsize=(5, 4))\n", "\n", "import sklearn.metrics as metrics\n", "\n", "zoom = False\n", "# zoom = True\n", "\n", "beam_search = None\n", "# beam_search = 1\n", "# beam_search = 4\n", "# beam_search = 8\n", "\n", "deltas = [1.0,2.0,5.0,10.0]\n", "# gammas = [0.25, 0.5]\n", "gammas = [0.25]\n", "# gammas = [0.5]\n", "\n", "# deltas = [1.0,2.0,5.0,10.0]\n", "# gammas = [0.1,0.5]\n", "\n", "groups = []\n", "names = []\n", "for d in deltas:\n", " for g in gammas:\n", " if beam_search:\n", " groups.append((False, beam_search, d, g))\n", " else:\n", " groups.append((True, 1, d, g))\n", " names.append(f\"$\\delta:{d},\\gamma:{g}$\")\n", "groups=groups[::-1]\n", "names=names[::-1]\n", "\n", "# Make colormap\n", "import 
matplotlib.pyplot as plt\n", "viridis = plt.colormaps['viridis'].resampled(len(groups)+1) \n", "cmap = viridis.colors[:len(groups)][::-1]\n", "\n", "# plot different parameter levels\n", "for i,(group,name) in enumerate(zip(groups,names)):\n", "\n", " baseline_z_scores = grouped_df.get_group(group)[\"baseline_z_score\"].values\n", " w_bl_z_scores = grouped_df.get_group(group)[\"w_bl_z_score\"].values\n", " all_scores = np.concatenate([baseline_z_scores,w_bl_z_scores])\n", "\n", " baseline_labels = np.zeros_like(baseline_z_scores)\n", " attacked_labels = np.ones_like(w_bl_z_scores)\n", " all_labels = np.concatenate([baseline_labels,attacked_labels])\n", "\n", " fpr, tpr, thresholds = metrics.roc_curve(all_labels, all_scores, pos_label=1)\n", " roc_auc = metrics.auc(fpr, tpr)\n", "\n", " plt.plot(fpr, tpr, color=cmap[i], label = f'{name}, AUC:%0.3f, PPL:{round(grouped_df[\"w_bl_ppl\"].describe().loc[group][\"mean\"],1)}' % roc_auc, linewidth=3)\n", "\n", "if \"w_bl_attacked_ppl\" in df.columns:\n", " pass\n", "else:\n", " # # vanilla ppl value\n", " plt.scatter([-1],[-1],label=f' $\\delta=0$, PPL: {round(grouped_df[\"no_bl_ppl\"].describe().loc[groups,\"mean\"].mean(),1)}', color=\"white\")\n", "\n", "if zoom:\n", " if not \"w_bl_attacked_ppl\" in df.columns:\n", " plt.legend(loc = 'lower right', fontsize = 12)\n", " plt.xscale(\"log\")\n", " # plt.yscale(\"log\")\n", " plt.xlim([0, 1])\n", " plt.ylim([0.5, 1])\n", " plot_name = (\"roc_auc_zoom\" if not beam_search else f\"roc_auc_zoom_greedy_beams_{beam_search}\")\n", "\n", "else:\n", " if \"w_bl_attacked_ppl\" in df.columns:\n", " plt.legend(loc = 'lower right', fontsize = 12)\n", " plt.plot([0, 1], [0, 1],'r--')\n", " plt.xlim([0, 1])\n", " plt.ylim([0, 1])\n", " plot_name = (\"roc_auc\" if not beam_search else f\"roc_auc_greedy_beams_{beam_search}\")\n", "\n", "plt.ylabel('True Positive Rate', fontsize = 12)\n", "plt.xlabel('False Positive Rate', fontsize = 12)\n", "\n", "print(plot_name)\n", "\n", "# fname = f\"figs/{plot_name}.pdf\"\n", "# plt.savefig(fname, format=\"pdf\")\n", "\n", "plt.show()" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "\n", "# ROC: Attack (figure 6)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import sklearn.metrics as metrics\n", "\n", "plt.clf()\n", "plt.figure(constrained_layout=True)\n", "plt.figure(figsize=(5, 4))\n", "\n", "# attack_budgets = [0.1,0.2,0.3,0.4,0.5,0.6,0.7]\n", "attack_budgets = [0.1,0.3,0.5,0.7]\n", "groups = [(True, 1, 0.5, 2.0, budget) for budget in attack_budgets]\n", "beams = False\n", "# groups = [(False, 8, 0.5, 2.0, budget) for budget in attack_budgets]\n", "# beams = True\n", "\n", "names = [f\"$\\epsilon={eps}$\" for eps in attack_budgets]\n", "\n", "# Make colormap\n", "import matplotlib.pyplot as plt\n", "viridis = plt.colormaps['viridis'].resampled(len(groups)+1+1) # attack\n", "cmap = viridis.colors[:len(groups)+1][::-1]\n", "\n", "# plot original\n", "group = groups[0] # any will do\n", "baseline_z_scores = grouped_df.get_group(group)[\"baseline_z_score\"].values\n", "baseline_labels = np.zeros_like(baseline_z_scores)\n", "\n", "orig_watermark_z_scores = grouped_df.get_group(group)[\"w_bl_z_score\"].values\n", "watermark_labels = np.ones_like(orig_watermark_z_scores)\n", "\n", "all_scores = np.concatenate([baseline_z_scores,orig_watermark_z_scores])\n", "all_labels = np.concatenate([baseline_labels,watermark_labels])\n", "\n", "fpr, tpr, thresholds = metrics.roc_curve(all_labels, all_scores, 
pos_label=1)\n", "roc_auc = metrics.auc(fpr, tpr)\n", "\n", "plt.plot(fpr, tpr, color=cmap[0], label = f'unattacked, AUC:%0.3f, PPL:{round(grouped_df[\"w_bl_ppl\"].describe().loc[group][\"mean\"],1)}' % roc_auc, linewidth=3)\n", "\n", "# plot different attack levels\n", "for i,(group,name) in enumerate(zip(groups,names)):\n", "\n", " baseline_z_scores = grouped_df.get_group(group)[\"baseline_z_score\"].values\n", " attacked_z_scores = grouped_df.get_group(group)[\"w_bl_attacked_z_score\"].values\n", " all_scores = np.concatenate([baseline_z_scores,attacked_z_scores])\n", "\n", " baseline_labels = np.zeros_like(baseline_z_scores)\n", " attacked_labels = np.ones_like(attacked_z_scores)\n", " all_labels = np.concatenate([baseline_labels,attacked_labels])\n", "\n", " fpr, tpr, thresholds = metrics.roc_curve(all_labels, all_scores, pos_label=1)\n", " roc_auc = metrics.auc(fpr, tpr)\n", "\n", " plt.plot(fpr, tpr, color=cmap[i+1], label = f'{name}, AUC:%0.3f, PPL:{round(grouped_df[\"w_bl_attacked_ppl\"].describe().loc[group][\"mean\"],1)}' % roc_auc, linewidth=3)\n", "\n", "if \"w_bl_attacked_ppl\" in df.columns:\n", " pass\n", "else:\n", " # # vanilla ppl value\n", " plt.scatter([-1],[-1],label=f' $\\delta=0$, PPL: {round(grouped_df[\"no_bl_ppl\"].describe().loc[groups,\"mean\"].mean(),1)}', color=\"white\")\n", "\n", "zoom = False\n", "# zoom = True\n", "if zoom:\n", " if not \"w_bl_attacked_ppl\" in df.columns:\n", " plt.legend(loc = 'lower right')\n", " plt.xscale(\"log\")\n", " # plt.yscale(\"log\")\n", " plt.xlim([0, 1])\n", " plt.ylim([0.5, 1])\n", " if \"w_bl_attacked_ppl\" in df.columns:\n", " plot_name = \"roc_auc_untargeted_attack_no_beams_zoom\"\n", " # plot_name = \"roc_auc_untargeted_attack_with_beams_zoom\"\n", " else:\n", " plot_name = \"roc_auc_zoom\"\n", "else:\n", " if \"w_bl_attacked_ppl\" in df.columns:\n", " plt.legend(loc = 'lower right',fontsize = 9)\n", " plt.plot([0, 1], [0, 1],'r--')\n", " plt.xlim([0, 1])\n", " plt.ylim([0, 1])\n", " if \"w_bl_attacked_ppl\" in df.columns:\n", " if beams: plot_name = \"roc_auc_untargeted_attack_w_beams\"\n", " if not beams: plot_name = \"roc_auc_untargeted_attack_no_beams\"\n", " else:\n", " plot_name = \"roc_auc\"\n", "\n", "plt.ylabel('True Positive Rate')\n", "plt.xlabel('False Positive Rate')\n", "\n", "print(plot_name)\n", "\n", "# fname = f\"figs/{plot_name}.pdf\"\n", "# plt.savefig(fname, format=\"pdf\")\n", "\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "# Z vs T (figure 3)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "plt.clf()\n", "plt.figure(constrained_layout=True)\n", "plt.figure(figsize=(5, 4))\n", "\n", "# save_fig = True\n", "save_fig = False\n", "\n", "z_scores = True\n", "# z_scores = False\n", "\n", "beam_search = None\n", "# beam_search = 1\n", "# beam_search = 4\n", "# beam_search = 8\n", "\n", "ablate = \"delta\"\n", "delta_gammas = [\n", " # (0.5,0.25),\n", " # (1.0,0.25),\n", " # (2.0,0.25),\n", " # (5.0,0.25),\n", " # (10.0,0.25),\n", " (0.5,0.5),\n", " (1.0,0.5),\n", " (2.0,0.5),\n", " (5.0,0.5),\n", " (10.0,0.5),\n", "]\n", "# ablate = \"gamma\"\n", "# delta_gammas = [\n", "# # (5.0,0.9),\n", "# # (5.0,0.75),\n", "# # (5.0,0.5),\n", "# # (5.0,0.25),\n", "# # (5.0,0.1),\n", "# (2.0,0.9),\n", "# (2.0,0.75),\n", "# (2.0,0.5),\n", "# (2.0,0.25),\n", "# (2.0,0.1),\n", "# ]\n", "# if not z_scores: delta_gammas = 
delta_gammas[::-1]\n", "\n", "groups = []\n", "names = []\n", "\n", "for d,g in delta_gammas:\n", " if beam_search:\n", " groups.append((False, beam_search, d, g))\n", " else:\n", " groups.append((True, 1, d, g))\n", " names.append(f\"$\\delta:{d},\\gamma:{g}$\")\n", "\n", "groups=groups[::-1]\n", "names=names[::-1]\n", "\n", "\n", "axis_max_t = 200\n", "\n", "max_t = None\n", "# max_t = 200\n", "# max_t = 100\n", "# max_t = 50\n", "\n", "# Make colormap\n", "import matplotlib.pyplot as plt\n", "viridis = plt.colormaps['viridis'].resampled(len(groups)+1) \n", "cmap = viridis.colors[:len(groups)][::-1]\n", "\n", "for grp_idx,(group, name) in enumerate(zip(groups, names)):\n", "\n", " delta, gamma = group[-2],group[-1]\n", "\n", " # this is the series of bools corresponding to token at T being in whitelist\n", " w_bl_hit_list = grouped_df.get_group(group)[\"w_bl_hit_list\"].to_list()\n", "\n", " lengths = [len(l) for l in w_bl_hit_list]\n", " diff_lengths = set(lengths) \n", " counter = {}\n", " for l in lengths:\n", " if counter.get(l):\n", " counter[l] += 1\n", " else:\n", " counter[l] = 1\n", " if max_t:\n", " min_length = min(min(diff_lengths),max_t)\n", " max_t = min_length\n", " else:\n", " min_length = min(diff_lengths)\n", " w_bl_hit_list = [l[:min_length] for l in w_bl_hit_list]\n", "\n", " # wl_hit_matrix = ~np.matrix(w_bl_hit_list)\n", " wl_hit_matrix = (~torch.tensor(w_bl_hit_list, dtype=bool)).to(torch.float)\n", " # wl_hit_matrix\n", "\n", " n = wl_hit_matrix.shape[0]\n", "\n", " if max_t:\n", " t_values = torch.arange(0,max_t)\n", " indices = torch.arange(0,max_t)\n", " else:\n", " t_values = torch.arange(0,wl_hit_matrix.shape[1])\n", " indices = torch.arange(0,wl_hit_matrix.shape[1])\n", " # print(t_values[:10])\n", "\n", " avg_cumulative = list()\n", " std_cumulative = list()\n", " prc_25_cumulative = list()\n", " prc_50_cumulative = list()\n", " prc_75_cumulative = list()\n", "\n", " prc_25_seq_indices = list()\n", "\n", " for idx in indices:\n", "\n", " hits_upto_t = wl_hit_matrix[:,:idx+1]\n", " cumulative_sum_to_t = hits_upto_t.sum(axis=1)\n", " wl_frac_at_t = cumulative_sum_to_t/(t_values[idx]+1)\n", " \n", " if z_scores:\n", " wl_z_score_at_t = compute_z_score(wl_frac_at_t, t_values[idx], gamma)\n", " avg_at_t = torch.mean(wl_z_score_at_t,axis=0)\n", " std_at_t = torch.std(wl_z_score_at_t,axis=0)\n", " prc_25_at_t = torch.quantile(wl_z_score_at_t,q=0.25,axis=0)\n", " prc_50_at_t = torch.quantile(wl_z_score_at_t,q=0.50,axis=0)\n", " prc_75_at_t = torch.quantile(wl_z_score_at_t,q=0.75,axis=0)\n", "\n", " if gamma == 0.9: # and idx > 20 and idx < 90:\n", " pcen=np.quantile(wl_z_score_at_t,0.75,interpolation='nearest')\n", " i_near=abs(wl_z_score_at_t-pcen).argmin()\n", " # prc_25_seq_indices.append((i_near.item(),pcen))\n", " prc_25_seq_indices.append((i_near.item()))\n", " else:\n", " avg_at_t = torch.mean(wl_frac_at_t,axis=0)\n", " std_at_t = torch.std(wl_frac_at_t,axis=0)\n", " prc_25_at_t = torch.quantile(wl_frac_at_t,q=0.25,axis=0)\n", " prc_50_at_t = torch.quantile(wl_frac_at_t,q=0.50,axis=0)\n", " prc_75_at_t = torch.quantile(wl_frac_at_t,q=0.75,axis=0)\n", "\n", " avg_cumulative.append(avg_at_t.item())\n", " std_cumulative.append(std_at_t.item())\n", " prc_25_cumulative.append(prc_25_at_t.item())\n", " prc_50_cumulative.append(prc_50_at_t.item())\n", " prc_75_cumulative.append(prc_75_at_t.item())\n", "\n", "\n", " print(prc_25_seq_indices)\n", "\n", " avg_cumulative = np.array(avg_cumulative)\n", " std_cumulative = np.array(std_cumulative)\n", " std_err_cumulative 
= std_cumulative/np.sqrt(n)\n", " var_cumulative = std_cumulative**2\n", " \n", " plt.plot(t_values, avg_cumulative, color=cmap[grp_idx], label=name)\n", "\n", " # bounds stuff\n", "\n", " # plt.plot(t_values, prc_25_cumulative, color=cmap[grp_idx], linestyle=\"dashed\") #, label=name+',25th') \n", " # # plt.plot(t_values, prc_50_cumulative, color=cmap[grp_idx], linestyle='--', label=name+',50th') \n", " # plt.plot(t_values, prc_75_cumulative, color=cmap[grp_idx], linestyle=\"dashed\") #, label=name+',75th ') \n", " # #fill between the upper and lower bands\n", " # plt.fill_between(t_values, prc_25_cumulative, prc_75_cumulative, alpha = .1,color = cmap[grp_idx])\n", " # or just lower\n", " # plt.fill_between(t_values, prc_25_cumulative, avg_cumulative, alpha = .1,color = cmap[grp_idx])\n", "\n", " # plt.plot(t_values, avg_cumulative-std_cumulative, color=cmap[grp_idx], linestyle=\"dashed\") #, label=name+',25th') \n", " # plt.plot(t_values, avg_cumulative+std_cumulative, color=cmap[grp_idx], linestyle=\"dashed\") #, label=name+',25th') \n", " # plt.plot(t_values, avg_cumulative-std_err_cumulative, color=cmap[grp_idx], linestyle=\"dashed\") #, label=name+',25th') \n", " # plt.plot(t_values, avg_cumulative+std_err_cumulative, color=cmap[grp_idx], linestyle=\"dashed\") #, label=name+',25th') \n", " # plt.plot(t_values, avg_cumulative-var_cumulative, color=cmap[grp_idx], linestyle=\"dashed\") #, label=name+',25th') \n", " # plt.plot(t_values, avg_cumulative+var_cumulative, color=cmap[grp_idx], linestyle=\"dashed\") #, label=name+',25th') \n", " # fill between the upper and lower bands\n", " # plt.fill_between(t_values, avg_cumulative-std_cumulative, avg_cumulative+std_cumulative, alpha = .1,color = cmap[grp_idx])\n", " # plt.fill_between(t_values, avg_cumulative-std_err_cumulative, avg_cumulative+std_err_cumulative, alpha = .1,color = cmap[grp_idx])\n", " # or just lower\n", " # plt.fill_between(t_values, avg_cumulative-std_cumulative, avg_cumulative, alpha = .1,color = cmap[grp_idx])\n", " # plt.fill_between(t_values, avg_cumulative-std_err_cumulative, avg_cumulative, alpha = .1,color = cmap[grp_idx])\n", "\n", "# plt.plot([0.0],[0.0],label=f'25th Percentile', linestyle=\"dashed\", color=\"gray\")\n", "\n", "# if beam_search:\n", "# plt.title(f\"Greedy, {beam_search}-way BS\")\n", "\n", "legend_font = 11\n", "\n", "# zoom_midrange = True\n", "# zoom = True\n", "\n", "zoom = False\n", "\n", "if zoom:\n", " if z_scores:\n", " plt.legend(loc = 'upper left', fontsize=legend_font)\n", " else:\n", " plt.legend(loc = 'lower right', fontsize=legend_font)\n", " if zoom_midrange:\n", " plt.xlim([(min_length)/4, (3*(max_t if max_t else min_length)/4)+1])\n", " else:\n", " plt.xlim([0, ((max_t if max_t else min_length)/4)+1])\n", " plot_name = f\"z_vs_t_zoom_ablate_{ablate}\" if z_scores else f\"wl_vs_t_zoom_ablate_{ablate}\"\n", "else:\n", " if z_scores:\n", " plt.legend(loc = 'upper left', fontsize=legend_font)\n", " else:\n", " plt.legend(loc = 'lower right', fontsize=legend_font)\n", " \n", " plt.xlim([0, ((max_t if max_t else min_length))+1])\n", "\n", " plot_name = f\"z_vs_t_ablate_{ablate}\" if z_scores else f\"wl_vs_t_ablate_{ablate}\"\n", "\n", "axes_label_fonts = 14\n", "if z_scores:\n", " plt.ylabel('z-score',fontsize=axes_label_fonts)\n", "else:\n", " plt.ylabel('Whitelist Fraction',fontsize=axes_label_fonts)\n", "plt.xlabel('T',fontsize=axes_label_fonts)\n", "\n", "# import matplotlib.ticker as ticker\n", "# tick_spacing = 5.0\n", "# 
plt.gca().yaxis.set_major_locator(ticker.MultipleLocator(tick_spacing))\n", "\n", "axes_tick_font = 13\n", "plt.xticks(fontsize=axes_tick_font)\n", "plt.yticks(fontsize=axes_tick_font)\n", "\n", "plt.grid()\n", "plt.tight_layout()\n", "\n", "if beam_search:\n", " if ablate == \"gamma\":\n", " plot_name = f\"greedy_{beam_search}_beams_delta_{delta}\" \n", " if ablate == \"delta\":\n", " plot_name = f\"greedy_{beam_search}_beams_gamma_{gamma}\" \n", "\n", "# plot_name = \"z_vs_t_ablate_gamma_boosted_delta\"\n", "# plot_name = \"z_vs_t_ablate_delta_boosted_gamma\"\n", "\n", "print(plot_name)\n", "\n", "\n", "if save_fig:\n", " # fname = f\"figs/{plot_name}.pdf\"\n", " fname = f\"figs_new/{plot_name}.pdf\"\n", " plt.savefig(fname, format=\"pdf\")\n", "\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "# Set up data for charts (setup for figures 2&7)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "viz_df = pd.DataFrame()\n", "\n", "# aggregating\n", "\n", "# set the hparam keys, including an indiv column for each you want to ablate on\n", "viz_df[\"bl_hparams\"] = grouped_df[\"w_bl_exp_whitelist_fraction\"].describe().index.to_list()\n", "for i,key in enumerate(groupby_fields):\n", " viz_df[key] = viz_df[\"bl_hparams\"].apply(lambda tup: tup[i])\n", "\n", "# viz_df[\"delta\"] = viz_df[\"bl_logit_bias\"].values\n", "viz_df[\"gamma\"] = viz_df[\"gamma\"].values\n", "# viz_df[\"gamma\"] = np.ones_like(viz_df[\"bl_proportion\"].values) - viz_df[\"bl_proportion\"].values\n", "\n", "# aggregate each field of interest for each hparam setting (group)\n", "describe_dict = grouped_df[\"w_bl_exp_whitelist_fraction\"].describe()\n", "viz_df[\"w_bl_exp_whitelist_fraction_mean\"] = describe_dict[\"mean\"].to_list()\n", "viz_df[\"w_bl_exp_whitelist_fraction_std\"] = describe_dict[\"std\"].to_list()\n", "\n", "describe_dict = grouped_df[\"w_bl_var_whitelist_fraction\"].describe()\n", "viz_df[\"w_bl_var_whitelist_fraction_mean\"] = describe_dict[\"mean\"].to_list()\n", "viz_df[\"w_bl_var_whitelist_fraction_std\"] = describe_dict[\"std\"].to_list()\n", "\n", "describe_dict = grouped_df[\"w_bl_whitelist_fraction\"].describe()\n", "viz_df[\"w_bl_whitelist_fraction_min\"] = describe_dict[\"min\"].to_list()\n", "viz_df[\"w_bl_whitelist_fraction_25\"] = describe_dict[\"25%\"].to_list()\n", "viz_df[\"w_bl_whitelist_fraction_50\"] = describe_dict[\"50%\"].to_list()\n", "viz_df[\"w_bl_whitelist_fraction_75\"] = describe_dict[\"75%\"].to_list()\n", "viz_df[\"w_bl_whitelist_fraction_max\"] = describe_dict[\"max\"].to_list()\n", "viz_df[\"w_bl_whitelist_fraction_mean\"] = describe_dict[\"mean\"].to_list()\n", "viz_df[\"w_bl_whitelist_fraction_std\"] = describe_dict[\"std\"].to_list()\n", "\n", "describe_dict = grouped_df[\"no_bl_whitelist_fraction\"].describe()\n", "viz_df[\"no_bl_whitelist_fraction_mean\"] = describe_dict[\"mean\"].to_list()\n", "viz_df[\"no_bl_whitelist_fraction_std\"] = describe_dict[\"std\"].to_list()\n", "\n", "\n", "describe_dict = grouped_df[\"w_bl_z_score\"].describe()\n", "viz_df[\"w_bl_z_score_mean\"] = describe_dict[\"mean\"].to_list()\n", "viz_df[\"w_bl_z_score_std\"] = describe_dict[\"std\"].to_list()\n", "\n", "describe_dict = grouped_df[\"no_bl_z_score\"].describe()\n", "viz_df[\"no_bl_z_score_mean\"] = describe_dict[\"mean\"].to_list()\n", "viz_df[\"no_bl_z_score_std\"] = 
describe_dict[\"std\"].to_list()\n", "\n", "describe_dict = grouped_df[\"baseline_z_score\"].describe()\n", "viz_df[\"baseline_z_score_mean\"] = describe_dict[\"mean\"].to_list()\n", "viz_df[\"baseline_z_score_std\"] = describe_dict[\"std\"].to_list()\n", "\n", "\n", "describe_dict = grouped_df[\"w_bl_ppl\"].describe()\n", "viz_df[\"w_bl_ppl_mean\"] = describe_dict[\"mean\"].to_list()\n", "viz_df[\"w_bl_ppl_std\"] = describe_dict[\"std\"].to_list()\n", "\n", "describe_dict = grouped_df[\"no_bl_ppl\"].describe()\n", "viz_df[\"no_bl_ppl_mean\"] = describe_dict[\"mean\"].to_list()\n", "viz_df[\"no_bl_ppl_std\"] = describe_dict[\"std\"].to_list()\n", "\n", "describe_dict = grouped_df[\"baseline_ppl\"].describe()\n", "viz_df[\"baseline_ppl_mean\"] = describe_dict[\"mean\"].to_list()\n", "viz_df[\"baseline_ppl_std\"] = describe_dict[\"std\"].to_list()\n", "\n", "describe_dict = grouped_df[\"avg_spike_entropy\"].describe()\n", "viz_df[\"avg_spike_entropy_mean\"] = describe_dict[\"mean\"].to_list()\n", "viz_df[\"avg_spike_entropy_std\"] = describe_dict[\"std\"].to_list()\n", "\n", "print(f\"groupby legend: {groupby_fields}\")\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# filtering\n", "\n", "viz_df = viz_df[viz_df[\"bl_hparams\"].apply(lambda tup: (tup[0] == True))] # sampling\n", "\n", "# viz_df = viz_df[viz_df[\"bl_hparams\"].apply(lambda tup: (tup[0] == False))] # greedy\n", "\n", "\n", "# fix one of the bl params for analytic chart\n", "# viz_df = viz_df[(viz_df[\"gamma\"]==0.9) & (viz_df[\"delta\"]<=10.0)]\n", "# viz_df = viz_df[(viz_df[\"gamma\"]==0.75) & (viz_df[\"delta\"]<=10.0)]\n", "# viz_df = viz_df[(viz_df[\"gamma\"]==0.5) & (viz_df[\"delta\"]<=10.0)]\n", "# viz_df = viz_df[(viz_df[\"gamma\"]==0.25) & (viz_df[\"delta\"]<=10.0)]\n", "# viz_df = viz_df[(viz_df[\"gamma\"]==0.1) & (viz_df[\"delta\"]<=10.0)]\n", "\n", "# for the sample pareto chart\n", "viz_df = viz_df[(viz_df[\"delta\"] > 0.5) & (viz_df[\"delta\"]<=10.0)]\n", "# viz_df = viz_df[(viz_df[\"delta\"]<=2.0)] # zoom in on lower deltas\n", "# viz_df = viz_df[(viz_df[\"delta\"] >= 2.0) & (viz_df[\"delta\"]<=10.0)] # mid deltas\n", "# viz_df = viz_df[(viz_df[\"gamma\"] != 0.25) & (viz_df[\"gamma\"] != 0.75) & (viz_df[\"delta\"]<=2.0)]\n", "# viz_df = viz_df[(viz_df[\"gamma\"] != 0.1) & (viz_df[\"gamma\"] != 0.9) & (viz_df[\"delta\"]<=2.0)]\n", "\n", "# viz_df = viz_df[(viz_df[\"delta\"]==0.5) | (viz_df[\"delta\"]==2.0) | (viz_df[\"delta\"]==10.0)]\n", "\n", "# viz_df = viz_df[(viz_df[\"delta\"]!=0.1)&(viz_df[\"delta\"]!=0.5)&(viz_df[\"delta\"]!=50.0)]\n", "\n", "# for the beams pareto\n", "# viz_df = viz_df[(viz_df[\"delta\"]!=50.0)]\n", "# viz_df = viz_df[(viz_df[\"delta\"]!=50.0) & (viz_df[\"num_beams\"]!=1)]\n", "\n", "print(len(viz_df))\n", "\n", "viz_df" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# grouped_df[\"avg_spike_entropy\"]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# viz_df[[\"gamma\",\"avg_spike_entropy_mean\"]]" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "# Basic Exp vs Empirical WL fraction chart (figure 7)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "\n", "# plt.style.use(\"classic\")\n", "plt.style.use(\"default\")\n", "# plt.style.use('ggplot') \n", "# plt.style.use('seaborn')\n", "\n", "rc('font', **{'family': 'serif', 'serif': ['Computer Modern']})\n", 
"rc('text', usetex=True)\n", "\n", "\n", "plt.clf()\n", "# plt.figure(figsize=(16, 4))\n", "# plt.figure(figsize=(8, 4))\n", "plt.figure(constrained_layout=True)\n", "plt.figure(figsize=(5, 4))\n", "\n", "\n", "# x_col = 'bl_hparams'\n", "# a = viz_df[x_col].apply(str)\n", "\n", "# x_col = 'bl_logit_bias'\n", "# x_col = 'bl_proportion'\n", "x_col = \"delta\"\n", "# x_col = \"gamma\"\n", "\n", "a = viz_df[x_col]\n", "print(f\"Num configurations: {len(a)}\")\n", "\n", "y_col = 'w_bl_whitelist_fraction_mean'\n", "y_col_err = 'w_bl_whitelist_fraction_std'\n", "\n", "viridis = plt.colormaps['viridis'].resampled(4)\n", "# cmap = viridis.colors[::-1]\n", "cmap = viridis.colors\n", "\n", "plt.plot(a, viz_df[\"w_bl_whitelist_fraction_mean\"].values, color=cmap[1], marker='o', label='Mean') \n", "plt.plot(a, viz_df[\"w_bl_whitelist_fraction_25\"].values, color=cmap[1], linestyle='-.', label='25th Percentile') \n", "plt.plot(a, viz_df[\"w_bl_whitelist_fraction_75\"].values, color=cmap[1], linestyle='-.', label='75th Percentile') \n", "# plt.plot(a, viz_df[\"w_bl_whitelist_fraction_min\"].values, color=cmap[1], linestyle='-.', label='min') \n", "# plt.plot(a, viz_df[\"w_bl_whitelist_fraction_max\"].values, color=cmap[1], linestyle='-.', label='max') \n", "\n", "#fill between the upper and lower bands\n", "plt.fill_between(a, viz_df[\"w_bl_whitelist_fraction_25\"], viz_df[\"w_bl_whitelist_fraction_75\"], alpha = .1,color = cmap[1])\n", "# plt.fill_between(a, viz_df[\"w_bl_whitelist_fraction_25\"], viz_df[\"w_bl_whitelist_fraction_75\"], alpha = .1,color = 'darkorchid')\n", "# plt.fill_between(a, y1_low, y1_high, alpha = .1,color = 'goldenrod')\n", "\n", "\n", "y_col = 'w_bl_exp_whitelist_fraction_mean'\n", "# y_col_err = 'w_bl_var_whitelist_fraction_mean'\n", "# d = viz_df[x_col].apply(str)\n", "\n", "# sub_df = viz_df[viz_df[\"num_beams\"]==1]\n", "\n", "a = viz_df[x_col]\n", "e = viz_df[y_col].values\n", "# plt.plot(a, e, label=\"Predicted Lower Bound\", color=cmap[-1])\n", "plt.plot(a, e, label=\"Analytic Bound\", color=\"r\")\n", "# f = viz_df[y_col_err].values\n", "# # f = np.sqrt(viz_df[y_col_err].values)\n", "# plt.errorbar(d, e, yerr=f, fmt=\"o\")\n", "\n", "plt.legend(loc=\"lower right\",frameon=True, facecolor=\"white\")\n", "\n", "# for logit bias x axis\n", "# log_axis = True\n", "log_axis = False\n", "if log_axis:\n", " plt.xscale(\"log\")\n", "\n", "ax = plt.gca()\n", "plt.draw()\n", "\n", "\n", "\n", "plt.xlabel(f\"Green List Bias, $\\delta$\")\n", "# plt.xlabel(f\"Whitelist size := $\\gamma$\")\n", "\n", "plt.ylabel(\"Fraction in Green List\")\n", "\n", "\n", "plt.grid()\n", "\n", "plt.tight_layout()\n", "\n", "if log_axis:\n", " plot_name = \"analytic_w_sampling_log.pdf\"\n", "else:\n", " plot_name = \"analytic_w_sampling_linear.pdf\"\n", " # plot_name = f\"analytic_w_sampling_linear_gamma_{viz_df['gamma'].values[0]}.pdf\"\n", "\n", "# plot_name = \"analytic_w_sampling_linear_greenlist.pdf\"\n", "print(plot_name)\n", "\n", "# fname = f\"figs/{plot_name}\"\n", "# plt.savefig(fname, format=\"pdf\")\n", "plt.show()\n" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "# delta gamma sampling pareto plot (figure 2 left)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "rc('font', **{'family': 'serif', 'serif': ['Computer Modern']})\n", "rc('text', usetex=True)\n", "\n", "plt.clf()\n", "plt.figure(constrained_layout=True)\n", "plt.figure(figsize=(5, 4))\n", "\n", "\n", "x_col = 'w_bl_ppl_mean'\n", "y_col = 
'w_bl_z_score_mean'\n", "\n", "# markers = [\"x\", \"p\", \"*\", \"P\"]\n", "\n", "deltas = sorted(np.unique(viz_df[\"delta\"].values))\n", "gammas = sorted(np.unique(viz_df[\"gamma\"].values), reverse=True)\n", "print(deltas, gammas)\n", "gamma_labels = [(g if g > 0.1 else 0.1) for g in gammas]\n", "\n", "markers = [\"x\", \"p\", \"*\", \"P\"][:len(deltas)]\n", "\n", "num_colors = len(gammas)\n", "cmap = cmr.get_sub_cmap('viridis', 0.0, 0.66, N=num_colors)\n", "# cmap = cmr.get_sub_cmap('plasma', 0.0, 0.66, N=num_colors)\n", "colors = cmap.colors#[::-1]\n", "\n", "\n", "for i,delta in enumerate(deltas):\n", " for j,gamma in enumerate(gammas):\n", " sub_df = viz_df[(viz_df[\"delta\"] == delta) & (viz_df[\"gamma\"] == gamma)]\n", " a = sub_df[x_col].values\n", " b = sub_df[y_col].values\n", " # plt.scatter(a, b, label=f\"$\\delta={delta},\\gamma={gamma}$\", color=colors[j], marker=markers[i])\n", " plt.plot(a, b, label=f\"$\\delta={delta},\\gamma={gamma}$\", color=colors[j], marker=markers[i])\n", "\n", "\n", "x_col = 'no_bl_ppl_mean'\n", "y_col = 'no_bl_z_score_mean'\n", "# x_col = 'baseline_ppl_mean'\n", "# y_col = 'baseline_z_score_mean'\n", "\n", "\n", "for i,delta in enumerate(deltas):\n", " for j,gamma in enumerate(gammas):\n", " sub_df = viz_df[(viz_df[\"delta\"] == delta) & (viz_df[\"gamma\"] == gamma)]\n", " a = sub_df[x_col].values\n", " b = sub_df[y_col].values\n", " plt.scatter(a, b, label=f\"$\\delta={delta},\\gamma={gamma}$\", color=colors[j])\n", "\n", "# # # for manual legend\n", "plt.scatter([-1],[-1], label=\"Vanilla\", color=\"gray\", marker=\"o\")\n", "\n", "ax = plt.gca()\n", "\n", "from matplotlib.cm import ScalarMappable\n", "from matplotlib.colors import Normalize, NoNorm, ListedColormap\n", "cmap = ListedColormap(colors)\n", "cmappable = ScalarMappable(norm=NoNorm(),cmap=cmap)\n", "cbar = plt.colorbar(cmappable,ticks=[i for i in range(len(gammas))],shrink=0.6, pad = 0.03)\n", "cbar.ax.set_yticklabels(gamma_labels) \n", "cbar.set_label('$\\gamma$', rotation=0)\n", "\n", "\n", "all_x = np.concatenate([viz_df['w_bl_ppl_mean'].values,viz_df['no_bl_ppl_mean'].values])\n", "all_y = np.concatenate([viz_df['w_bl_z_score_mean'].values,viz_df['no_bl_z_score_mean'].values])\n", "# all_x = np.concatenate([viz_df['w_bl_ppl_mean'].values,viz_df['baseline_ppl_mean'].values])\n", "# all_y = np.concatenate([viz_df['w_bl_z_score_mean'].values,viz_df['baseline_z_score_mean'].values])\n", "\n", "min_x, max_x = np.min(all_x), np.max(all_x)\n", "min_y, max_y = np.min(all_y), np.max(all_y)\n", "\n", "# x_min_tick = 1.0\n", "x_min_tick = 3.0\n", "x_max_tick = np.ceil([max_x])[0]+1.0\n", "y_min_tick = 0.0\n", "y_max_tick = np.ceil([max_y])[0]+1.0\n", "\n", "x_ticks = np.arange(x_min_tick,x_max_tick,1.0)\n", "y_ticks = np.arange(y_min_tick,y_max_tick,5.0)\n", "\n", "\n", "x_lim_min = 3.0\n", "x_lim_max = x_max_tick\n", "y_lim_min = 0.45\n", "# y_lim_max = 1.09\n", "y_lim_max = 1.005\n", "\n", "\n", "# plt.xlim((x_min_tick-0.5,x_max_tick))\n", "plt.xlim((x_lim_min,x_lim_max))\n", "# plt.xlim((4.0,8.0))\n", "# plt.ylim((-1.0,20.0))\n", "# plt.ylim((y_lim_min,y_lim_max))\n", "\n", "ax.set_xticks(x_ticks)\n", "# ax.set_yticks(y_ticks)\n", "\n", "ax.invert_xaxis()\n", "\n", "# # manual legend for dual parameter visualization\n", "f = lambda m,c: plt.plot([],[],marker=m, color=c, ls=\"none\")[0]\n", "handles = [f(markers[::-1][i], \"gray\") for i in range(len(deltas))]\n", "handles += [f(\"o\", \"gray\")]\n", "labels = [f\"$\\delta={delta}$\" for delta in deltas[::-1]]+[f\"$\\delta=0.0$\"]\n", 
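"# legend uses the hand-built gray proxy handles: one marker per delta plus a circle for the unwatermarked delta=0 points (color already encodes gamma via the colorbar)\n",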
"plt.legend(handles, labels, loc=\"upper right\", framealpha=1)\n", "\n", "plt.grid()\n", "\n", "plt.xlabel(\"Oracle Model PPL (better →)\")\n", "plt.ylabel(\"z-score (better →)\")\n", "\n", "\n", "plt.tight_layout()\n", "\n", "# plot_name = \"pareto_sampling_no_beams\"\n", "# fname = f\"figs/{plot_name}.pdf\"\n", "# plt.savefig(fname, format=\"pdf\")\n", "plt.show()" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "# beams pareto plot (figure 2 right)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "num_colors = 3\n", "cmap = cmr.get_sub_cmap('viridis', 0.0, 0.66, N=num_colors)\n", "colors = cmap.colors#[::-1]\n", "\n", "# plt.style.use('ggplot')\n", "# plt.style.use('seaborn')\n", "\n", "rc('font', **{'family': 'serif', 'serif': ['Computer Modern']})\n", "rc('text', usetex=True)\n", "\n", "plt.clf()\n", "plt.figure(constrained_layout=True)\n", "plt.figure(figsize=(5, 4))\n", "\n", "\n", "x_col = 'w_bl_ppl_mean'\n", "y_col = 'w_bl_z_score_mean'\n", "\n", "markers = [\"s\",\"D\", \"x\", \"p\", \"*\", \"P\"] # <--- seems to match other pareto fig ordering\n", "\n", "deltas = sorted(np.unique(viz_df[\"delta\"].values))\n", "num_beams = sorted(np.unique(viz_df[\"num_beams\"].values))\n", "# gamma_labels = [(g if g > 0.1 else 0.1) for g in np.unique(viz_df[\"gamma\"].values)]\n", "\n", "for i,n_beams in enumerate(num_beams):\n", " for j,delta in enumerate(deltas):\n", " sub_df = viz_df[(viz_df[\"delta\"] == delta) & (viz_df[\"num_beams\"] == n_beams)]\n", " a = sub_df[x_col].values\n", " b = sub_df[y_col].values\n", " # plt.scatter(a, b, label=f\"$\\delta={delta},\\gamma={gamma}$\", color=colors[j], marker=markers[i])\n", " plt.plot(a, b, label=f\"$\\delta={delta}$\", color=colors[i], marker=markers[j])\n", "\n", "\n", "x_col = 'no_bl_ppl_mean'\n", "y_col = 'no_bl_z_score_mean'\n", "\n", "\n", "\n", "for i,n_beams in enumerate(num_beams):\n", " for j,delta in enumerate(deltas):\n", " sub_df = viz_df[(viz_df[\"delta\"] == delta) & (viz_df[\"num_beams\"] == n_beams)]\n", " a = sub_df[x_col].values\n", " b = sub_df[y_col].values\n", " plt.scatter(a, b, label=f\"$\\delta={delta}$\", color=colors[i])\n", "\n", "# # # for manual legend\n", "plt.scatter([-10],[-10], label=\"$\\delta=0$\", color=\"gray\", marker=\"o\")\n", "\n", "ax = plt.gca()\n", "\n", "from matplotlib.cm import ScalarMappable\n", "from matplotlib.colors import Normalize, NoNorm, ListedColormap\n", "cmap = ListedColormap(colors)\n", "cmappable = ScalarMappable(norm=NoNorm(),cmap=cmap)\n", "cbar = plt.colorbar(cmappable,ticks=[i for i in range(len(num_beams))],shrink=0.6, pad = 0.04)\n", "# cbar.set_ticks(num_beams)\n", "cbar.set_ticklabels(num_beams)\n", "# cbar.ax.set_yticklabels(num_beams) \n", "cbar.set_label('Num Beams', rotation=90)\n", "\n", "\n", "all_x = np.concatenate([viz_df['w_bl_ppl_mean'].values,viz_df['no_bl_ppl_mean'].values])\n", "all_y = np.concatenate([viz_df['w_bl_z_score_mean'].values,viz_df['no_bl_z_score_mean'].values])\n", "\n", "min_x, max_x = np.min(all_x), np.max(all_x)\n", "min_y, max_y = np.min(all_y), np.max(all_y)\n", "\n", "# x_max_tick = np.ceil([max_x])[0]+1.0\n", "x_max_tick = np.ceil([max_x])[0]\n", "y_max_tick = np.ceil([max_y])[0]+1.0\n", "\n", "\n", "plt.xlim((1.0,x_max_tick))\n", "plt.ylim((-1.0,y_max_tick))\n", "\n", "# x_ticks = np.arange(x_min_tick,x_max_tick,1.0)\n", "# y_ticks = np.arange(y_min_tick,y_max_tick,5.0)\n", "\n", "# ax.set_xticks(x_ticks)\n", "# ax.set_yticks(y_ticks)\n", "\n", 
"ax.invert_xaxis()\n", "\n", "# # manual legend for dual parameter visualization\n", "f = lambda m,c: plt.plot([],[],marker=m, color=c, ls=\"none\")[0]\n", "handles = [f(markers[::-1][i], \"gray\") for i in range(len(deltas))]\n", "handles += [f(\"o\", \"gray\")]\n", "labels = [f\"$\\delta={delta}$\" for delta in deltas[::-1]]+[f\"$\\delta=0.0$\"]\n", "plt.legend(handles, labels, loc=\"lower left\", framealpha=1)\n", "\n", "plt.grid()\n", "\n", "plt.xlabel(\"Oracle Model PPL (better →)\")\n", "plt.ylabel(\"z-score (better →)\")\n", "\n", "\n", "plt.tight_layout()\n", "\n", "\n", "plot_name = \"pareto_greedy_w_beams\"\n", "print(plot_name)\n", "\n", "# fname = f\"figs/{plot_name}.pdf\"\n", "# plt.savefig(fname, format=\"pdf\")\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## z vs entropy (not in paper)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "print(f\"groupby legend: {groupby_fields}\")\n", "# hist_subset = grouped_df.get_group((True,1,2.0,0.1)) # needs to match the groupby keys and order\n", "# hist_subset = grouped_df.get_group((True,1,2.0,0.25)) \n", "hist_subset = grouped_df.get_group((True,1,2.0,0.5)) \n", "# hist_subset = grouped_df.get_group((True,1,2.0,0.75)) \n", "# hist_subset = grouped_df.get_group((True,1,2.0,0.9)) " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "print(len(hist_subset))\n", "# hist_subset = hist_subset[hist_subset[\"w_bl_space_frac\"] <= 0.9]\n", "# hist_subset = hist_subset[hist_subset[\"no_bl_space_frac\"] <= 0.9]\n", "# print(len(hist_subset))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# y = hist_subset[\"w_bl_z_score\"]\n", "# y = hist_subset[\"no_bl_z_score\"]\n", "y = hist_subset[\"baseline_z_score\"]\n", "\n", "x = hist_subset[\"avg_spike_entropy\"]\n", "\n", "plt.clf()\n", "\n", "\n", "plt.scatter(x, y)\n", "\n", "\n", "plt.grid()\n", "\n", "plt.xlabel(\"Entropy\")\n", "plt.ylabel(\"z-score\")\n", "\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "cols_to_tabulate = [\n", " 'idx', \n", " 'truncated_input', \n", " 'baseline_completion',\n", " 'no_bl_output', \n", " 'w_bl_output', \n", " 'avg_spike_entropy',\n", " 'no_bl_z_score',\n", " 'w_bl_z_score',\n", " 'w_bl_whitelist_fraction',\n", " 'no_bl_whitelist_fraction',\n", " 'baseline_ppl',\n", " 'no_bl_ppl',\n", " 'w_bl_ppl'\n", "]\n", "\n", "slice_size = 10\n", "\n", "num_examples = len(hist_subset)\n", "midpt = num_examples//5\n", "lower = midpt - (slice_size//2)\n", "upper = midpt + (slice_size//2)+1\n", "\n", "high_entropy_examples = hist_subset[cols_to_tabulate].sort_values([\"avg_spike_entropy\"],ascending=True).tail(slice_size)\n", "mid_entropy_examples = hist_subset[cols_to_tabulate].sort_values([\"avg_spike_entropy\"],ascending=True).iloc[lower:upper]\n", "low_entropy_examples = hist_subset[cols_to_tabulate].sort_values([\"avg_spike_entropy\"],ascending=True).head(slice_size)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# hist_subset[cols_to_tabulate][(hist_subset[\"avg_spike_entropy\"]<0.7)&(hist_subset[\"w_bl_z_score\"]>=14.0)]\n", "hist_subset[cols_to_tabulate][(hist_subset[\"avg_spike_entropy\"]<0.7)&(hist_subset[\"baseline_z_score\"]>=7.0)]\n", "# 
hist_subset[cols_to_tabulate][(hist_subset[\"avg_spike_entropy\"]<0.7)&(hist_subset[\"w_bl_z_score\"]>=12.0)]\n", "# print(hist_subset[cols_to_tabulate][(hist_subset[\"avg_spike_entropy\"]<0.7)&(hist_subset[\"w_bl_z_score\"]>=14.0)].iloc[6][\"w_bl_output\"])\n", "# .to_csv(\"input/pile_low_S_high_z_outliers.csv\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "high_entropy_examples" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "mid_entropy_examples" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "low_entropy_examples" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "# plotting histograms of the metric for single runs (not in paper)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "print(f\"groupby legend: {groupby_fields}\")\n", "# hist_subset = grouped_df.get_group((True,1,2.0,0.1)) # needs to match the groupby keys and order\n", "# hist_subset = grouped_df.get_group((True,1,2.0,0.25)) \n", "hist_subset = grouped_df.get_group((True,1,2.0,0.5)) \n", "# hist_subset = grouped_df.get_group((True,1,2.0,0.75)) \n", "# hist_subset = grouped_df.get_group((True,1,2.0,0.9)) " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "##### old filters to smooth the histograms" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# hist_subset = hist_subset[(hist_subset[\"no_bl_num_tokens_generated\"] == hist_subset[\"max_new_tokens\"]) & (hist_subset[\"w_bl_num_tokens_generated\"] == hist_subset[\"max_new_tokens\"])]\n", "# hist_subset = hist_subset[hist_subset[\"truncated_input\"] != \"\"]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "all_no_bl_wl_fractions = hist_subset[\"no_bl_whitelist_fraction\"]\n", "all_w_bl_wl_fractions = hist_subset[\"w_bl_whitelist_fraction\"]\n", "all_baseline_wl_fractions = hist_subset[\"baseline_whitelist_fraction\"]\n", "# all_no_bl_wl_fractions = hist_subset[\"no_bl_z_score\"]\n", "# all_w_bl_wl_fractions = hist_subset[\"w_bl_z_score\"]\n", "# all_baseline_wl_fractions = hist_subset[\"baseline_z_score\"]\n", "\n", "plt.clf()\n", "\n", "all_vals = np.concatenate([all_baseline_wl_fractions, all_w_bl_wl_fractions, all_no_bl_wl_fractions])\n", "n_bins = 50\n", "bins = np.linspace(np.min(all_vals), np.max(all_vals), n_bins)\n", "# bins = np.linspace(0.0, 1.0, n_bins)\n", "\n", "# plt.hist(all_no_bl_wl_fractions, \n", "# bins=bins,\n", "# alpha=0.6,\n", "# label='no blacklisting')\n", "\n", "\n", "plt.hist(all_w_bl_wl_fractions, \n", " bins=bins,\n", " alpha=0.6,\n", " label='with blacklisting')\n", "\n", "plt.hist(all_baseline_wl_fractions,\n", " bins=bins,\n", " alpha=0.4,\n", " # label='wl')\n", " label='ground truth/real text')\n", "\n", "# plt.hist(all_baseline_bl_fractions, \n", "# bins=bins,\n", "# alpha=0.5,\n", "# label='bl')\n", "\n", "plt.legend(loc='upper right')\n", "\n", "# plt.xlim((-0.1,1.1))\n", "# plt.xticks(np.arange(0.0,1.0,0.1))\n", "plt.xlabel(\"fraction of total toks gen'd in WL\")\n", "plt.ylabel(\"freq\")\n", "\n", "# plt.title('baseline wl/bl fractions')\n", "plt.title(\"Output Whitelist Token Distribution\")\n", "\n", "# plot_name = \"wl_distro\"\n", "# fname = f\"figs/{plot_name}.png\"\n", "# plt.savefig(fname, dpi=600)" ] }, { 
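"cell_type": "markdown", "metadata": {}, "source": [ "##### added: numeric ppl summary (sanity check)\n", "\n", "The next cells plot perplexity histograms for generations with and without the blacklist (and the real-text baseline). The short cell below is an added cross-check, not part of the original figures: it tabulates mean/median/std of the same `no_bl_ppl`, `w_bl_ppl`, and `baseline_ppl` columns so the histogram shift can also be read off numerically. It assumes the `hist_subset` slice selected above." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# added sketch (not in the original notebook): numeric summary of the ppl columns\n", "# plotted in the histograms below; assumes hist_subset and the *_ppl columns exist as above\n", "ppl_summary = hist_subset[[\"no_bl_ppl\", \"w_bl_ppl\", \"baseline_ppl\"]].agg([\"mean\", \"median\", \"std\"])\n", "ppl_summary" ] }, { 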
"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "plt.clf()\n", "\n", "all_no_bl_ppls = hist_subset[\"no_bl_ppl\"]\n", "all_w_bl_ppls = hist_subset[\"w_bl_ppl\"]\n", "all_baseline_ppls = hist_subset[\"baseline_ppl\"]\n", "\n", "all_vals = list(np.concatenate([all_no_bl_ppls, all_w_bl_ppls]))\n", "all_vals = sorted(all_vals)\n", "n_bins = 50\n", "# bins = np.linspace(all_vals[0], all_vals[-1], n_bins)\n", "bins = np.linspace(all_vals[0], 20, n_bins)\n", "\n", "plt.hist(all_no_bl_ppls, \n", " bins=bins,\n", " alpha=0.6,\n", " label='no blacklisting')\n", "\n", "plt.hist(all_w_bl_ppls, \n", " bins=bins,\n", " alpha=0.6,\n", " label='with blacklisting')\n", "\n", "plt.legend(loc='upper right')\n", "\n", "# plt.xlim((0,1))\n", "plt.xlabel(\"perplexity (lower is better)\")\n", "plt.ylabel(\"freq\")\n", "\n", "plt.title('Model-based Output Quality/Fluency')\n", "\n", "# plot_name = \"ppl_no_baseline\"\n", "# fname = f\"figs/{plot_name}.png\"\n", "# plt.savefig(fname, dpi=600)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "plt.clf()\n", "\n", "all_vals = list(np.concatenate([all_no_bl_ppls, all_w_bl_ppls]))\n", "all_vals = sorted(all_vals)\n", "n_bins = 50\n", "# bins = np.linspace(all_vals[0], all_vals[-1], n_bins)\n", "bins = np.linspace(all_vals[0], 20, n_bins)\n", "\n", "plt.hist(all_no_bl_ppls, \n", " bins=bins,\n", " alpha=0.6,\n", " label='no blacklisting')\n", "\n", "plt.hist(all_w_bl_ppls, \n", " bins=bins,\n", " alpha=0.6,\n", " label='with blacklisting')\n", "\n", "plt.hist(all_baseline_ppls, \n", " bins=bins, \n", " alpha=0.4,\n", " label='ground truth/real text')\n", "\n", "plt.legend(loc='upper right')\n", "\n", "# plt.xlim((0,1))\n", "plt.xlabel(\"perplexity (lower is better)\")\n", "plt.ylabel(\"freq\")\n", "\n", "plt.title('Model-based Output Quality/Fluency')\n", "\n", "# plot_name = \"ppl_w_baseline\"\n", "# fname = f\"figs/{plot_name}.png\"\n", "# plt.savefig(fname, dpi=600)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" }, "vscode": { "interpreter": { "hash": "365524a309ad80022da286f2ec5d2060ce5cb229abb6076cf68d9a1ab14bd8fe" } } }, "nbformat": 4, "nbformat_minor": 4 }