{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Watermark Analysis\n", "\n", "Notebook for performing analysis and visualization of the effects of watermarking schemes" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from datasets import load_from_disk" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "### Load the processed dataset/frame" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "save_name = \"analysis_ds_1-19_realnews_1-3_v1\" # in figure\n", "# save_name = \"analysis_ds_1-21_greedy_redo\" \n", "# save_name = \"analysis_ds_1-21_greedy_redo_truncated\" # in figure\n", "\n", "# save_name = \"analysis_ds_1-20_more_attack\" # in figure\n", "\n", "save_dir = f\"input/{save_name}\"" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "raw_data = load_from_disk(save_dir)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### convert to pandas df" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "df = raw_data.to_pandas()\n", "\n", "retok_problematic_rows = df[(df['w_bl_whitelist_fraction'] != -1.0) & (df['w_bl_whitelist_fraction'] != 1.0) & (df['bl_type'] == 'hard')]\n", "print(f\"Num rows that are hard-blacklisted, and measureable, but still have a non-100% WL fraction: {len(retok_problematic_rows)} out of {len(df[df['bl_type'] == 'hard'])}\")\n", "\n", "orig_len = len(df)\n", "\n", "df = df[df[\"no_bl_whitelist_fraction\"] != -1.0]\n", "df = df[df[\"w_bl_whitelist_fraction\"] != -1.0]\n", "\n", "print(f\"Dropped {orig_len-len(df)} rows, new len {len(df)}\")\n", "\n", "orig_len = len(df)\n", "# df = df[df[\"no_bl_ppl\"].isna()]\n", "# df = df[df[\"w_bl_ppl\"].isna()]\n", "df = df[~(df[\"no_bl_ppl\"].isna() | df[\"w_bl_ppl\"].isna())]\n", "print(f\"Dropped {orig_len-len(df)} rows, new len {len(df)}\")\n", "\n", "orig_len = len(df)\n", "\n", "df = df[df[\"bl_logit_bias\"] <= 100.0]\n", "\n", "print(f\"Dropped {orig_len-len(df)} rows, new len {len(df)}\")\n", "\n", "\n", "orig_len = len(df)\n", "\n", "# df = df[df[\"bl_hparams\"].apply(lambda tup: (tup[0] == False and tup[2] != 1) or (tup[0] == True and tup[2] == 1) or (tup[0] == False))]\n", "df = df[((df[\"use_sampling\"]==True) & (df[\"num_beams\"] == 1)) | (df[\"use_sampling\"]==False)]\n", "\n", "print(f\"Dropped {orig_len-len(df)} rows, new len {len(df)}\")\n", "\n", "\n", "df.loc[df[\"use_sampling\"]==False,\"sampling_temp\"] = df.loc[df[\"use_sampling\"]==False,\"sampling_temp\"].fillna(0.0)\n", "df.loc[df[\"use_sampling\"]==True,\"sampling_temp\"] = df.loc[df[\"use_sampling\"]==True,\"sampling_temp\"].fillna(1.0)\n", "\n", "\n", "df.loc[df[\"bl_type\"]==\"hard\",\"bl_logit_bias\"] = np.inf\n", "# df.loc[df[\"bl_type\"]==\"hard\",\"bl_logit_bias\"] = 10000 # crosscheck with whats hardcoded in the bl processor\n", "\n", "\n", "df[\"delta\"] = df[\"bl_logit_bias\"].values\n", "df[\"gamma\"] = 1 - df[\"bl_proportion\"].values\n", "df[\"gamma\"] = df[\"gamma\"].round(3)\n", "\n", "df[\"no_bl_act_num_wl_tokens\"] = np.round(df[\"no_bl_whitelist_fraction\"].values*df[\"no_bl_num_tokens_generated\"],1) # round to 1 for sanity\n", "df[\"w_bl_act_num_wl_tokens\"] = np.round(df[\"w_bl_whitelist_fraction\"].values*df[\"w_bl_num_tokens_generated\"],1) # round to 1 for sanity\n", "\n", "df[\"w_bl_std_num_wl_tokens\"] = np.sqrt(df[\"w_bl_var_num_wl_tokens\"].values)\n", "\n", "if 
\"real_completion_length\":\n", " df[\"baseline_num_tokens_generated\"] = df[\"real_completion_length\"].values\n", "\n", "if \"actual_attacked_ratio\" in df.columns:\n", " df[\"actual_attacked_fraction\"] = df[\"actual_attacked_ratio\"].values*df[\"replace_ratio\"].values\n", "\n", "\n", "\n", "df[\"baseline_hit_list_length\"] = df[\"baseline_hit_list\"].apply(len)\n", "df[\"no_bl_hit_list_length\"] = df[\"no_bl_hit_list\"].apply(len)\n", "df[\"w_bl_hit_list_length\"] = df[\"w_bl_hit_list\"].apply(len)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Filter for the generation lengths we want to look at" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "orig_len = len(df)\n", "\n", "upper_T = 205\n", "lower_T = 195\n", "df = df[(df[\"baseline_hit_list_length\"] >= lower_T) & (df[\"no_bl_hit_list_length\"] >= lower_T) & (df[\"w_bl_hit_list_length\"] >= lower_T)] # now also applies to the truncated version\n", "df = df[(df[\"baseline_hit_list_length\"] <= upper_T) & (df[\"no_bl_hit_list_length\"] <= upper_T) & (df[\"w_bl_hit_list_length\"] <= upper_T)] # now also applies to the truncated version\n", "\n", "\n", "print(f\"Dropped {orig_len-len(df)} rows, new len {len(df)}\")" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "#### Add z-scores" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from math import sqrt\n", "import scipy.stats\n", "def compute_z_score(observed_wl_frac, T, gamma):\n", " numer = observed_wl_frac - gamma\n", " denom = sqrt(gamma*(1-gamma)/T)\n", " z = numer/denom\n", " return z\n", "\n", "def compute_wl_for_z(z, T, gamma):\n", " denom = sqrt(gamma*(1-gamma)/T)\n", " numer = ((z*denom)+gamma)*T\n", " return numer\n", "\n", "def compute_p_value(z):\n", " p_value = scipy.stats.norm.sf(abs(z))\n", " return p_value\n", "\n", "df[\"baseline_z_score\"] = df[[\"baseline_whitelist_fraction\", \"baseline_num_tokens_generated\", \"gamma\"]].apply(lambda tup: compute_z_score(*tup), axis=1)\n", "df[\"no_bl_z_score\"] = df[[\"no_bl_whitelist_fraction\", \"no_bl_num_tokens_generated\", \"gamma\"]].apply(lambda tup: compute_z_score(*tup), axis=1)\n", "df[\"w_bl_z_score\"] = df[[\"w_bl_whitelist_fraction\", \"w_bl_num_tokens_generated\", \"gamma\"]].apply(lambda tup: compute_z_score(*tup), axis=1)\n", "\n", "if \"w_bl_attacked_whitelist_fraction\" in df.columns:\n", " df[\"w_bl_attacked_z_score\"] = df[[\"w_bl_attacked_whitelist_fraction\", \"w_bl_attacked_num_tokens_generated\", \"gamma\"]].apply(lambda tup: compute_z_score(*tup), axis=1)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# # if attacked in df\n", "if \"w_bl_attacked_whitelist_fraction\" in df.columns:\n", " df[\"w_bl_attacked_act_num_wl_tokens\"] = np.round(df[\"w_bl_attacked_whitelist_fraction\"].values*df[\"w_bl_attacked_num_tokens_generated\"],1) # round to 1 for sanity\n", "\n", " df[\"w_bl_attacked_z_score\"] = df[[\"w_bl_attacked_whitelist_fraction\", \"w_bl_attacked_num_tokens_generated\", \"gamma\"]].apply(lambda tup: compute_z_score(*tup), axis=1)\n", "\n", " df[[\"bl_proportion\",\"w_bl_attacked_whitelist_fraction\", \"w_bl_attacked_num_tokens_generated\",\"w_bl_attacked_act_num_wl_tokens\", \"w_bl_attacked_z_score\"]]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Groupby" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "if 
\"w_bl_attacked_whitelist_fraction\" in df.columns: \n", " groupby_fields = ['use_sampling','num_beams','gamma','delta', 'replace_ratio'] # attack grouping\n", "else:\n", " groupby_fields = ['use_sampling','num_beams','delta','gamma'] # regular grouping\n", " # groupby_fields = ['use_sampling','delta','gamma'] # regular grouping, but no beam variation\n", " # groupby_fields = ['delta','gamma'] # regular grouping, but no beam variation, and all sampling" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "#### Main groupby" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "grouped_df = df.groupby(groupby_fields)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "print(f\"Number of rows after filtering: {len(df)}\")\n", "print(f\"Number of groups: {len(grouped_df)}\")" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "### Loop to compute confusion matrix at some z scores for tabulation" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import sklearn.metrics as metrics\n", "\n", "def reject_null_hypo(z_score=None,cuttoff=None):\n", " return z_score > cuttoff\n", "\n", "records = []\n", "\n", "for group_params in tqdm(list(grouped_df.groups.keys())):\n", " sub_df = grouped_df.get_group(group_params)\n", " grp_size = len(sub_df)\n", "\n", " # baseline_z_scores = sub_df[\"baseline_z_score\"].values\n", " # w_bl_z_scores = sub_df[\"w_bl_z_score\"].values\n", " # all_scores = np.concatenate([baseline_z_scores,w_bl_z_scores])\n", "\n", " # baseline_labels = np.zeros_like(baseline_z_scores)\n", " # attacked_labels = np.ones_like(w_bl_z_scores)\n", " # all_labels = np.concatenate([baseline_labels,attacked_labels])\n", "\n", " # fpr, tpr, thresholds = metrics.roc_curve(all_labels, all_scores, pos_label=1)\n", " # roc_auc = metrics.auc(fpr, tpr)\n", " record = {k:v for k,v in zip(groupby_fields,group_params)}\n", "\n", " for thresh in [4.0,5.0]:\n", " \n", " record[\"count\"] = grp_size\n", " record[f\"baseline_fpr_at_{thresh}\"] = reject_null_hypo(z_score=sub_df[\"baseline_z_score\"].values,cuttoff=thresh).sum() / grp_size\n", " record[f\"baseline_tnr_at_{thresh}\"] = (~reject_null_hypo(z_score=sub_df[\"baseline_z_score\"],cuttoff=thresh)).sum() / grp_size\n", " record[f\"no_bl_fpr_at_{thresh}\"] = reject_null_hypo(z_score=sub_df[\"no_bl_z_score\"].values,cuttoff=thresh).sum() / grp_size\n", " record[f\"no_bl_tnr_at_{thresh}\"] = (~reject_null_hypo(z_score=sub_df[\"no_bl_z_score\"].values,cuttoff=thresh)).sum() / grp_size\n", " record[f\"w_bl_tpr_at_{thresh}\"] = reject_null_hypo(z_score=sub_df[\"w_bl_z_score\"].values,cuttoff=thresh).sum() / grp_size\n", " record[f\"w_bl_fnr_at_{thresh}\"] = (~reject_null_hypo(z_score=sub_df[\"w_bl_z_score\"].values,cuttoff=thresh)).sum() / grp_size\n", "\n", " if \"w_bl_attacked_z_score\" in sub_df.columns:\n", " record[f\"w_bl_attacked_tpr_at_{thresh}\"] = reject_null_hypo(z_score=sub_df[\"w_bl_attacked_z_score\"].values,cuttoff=thresh).sum() / grp_size\n", " record[f\"w_bl_attacked_fnr_at_{thresh}\"] = (~reject_null_hypo(z_score=sub_df[\"w_bl_attacked_z_score\"].values,cuttoff=thresh)).sum() / grp_size\n", "\n", " records.append(record)\n", "\n", " # # df[f\"baseline_fp_at_{thresh}\"] = reject_null_hypo(z_score=df[\"baseline_z_score\"].values,cuttoff=thresh)\n", " # # df[f\"baseline_tn_at_{thresh}\"] = 
~reject_null_hypo(z_score=df[\"baseline_z_score\"],cuttoff=thresh)\n", " # # df[f\"no_bl_fp_at_{thresh}\"] = reject_null_hypo(z_score=df[\"no_bl_z_score\"].values,cuttoff=thresh)\n", " # # df[f\"no_bl_tn_at_{thresh}\"] = ~reject_null_hypo(z_score=df[\"no_bl_z_score\"].values,cuttoff=thresh)\n", " # # df[f\"w_bl_tp_at_{thresh}\"] = reject_null_hypo(z_score=df[\"w_bl_z_score\"].values,cuttoff=thresh)\n", " # # df[f\"w_bl_fn_at_{thresh}\"] = ~reject_null_hypo(z_score=df[\"w_bl_z_score\"].values,cuttoff=thresh)\n", "\n", "\n", "roc_df = pd.DataFrame.from_records(records)\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# thresh = 6.0\n", "# thresh = 5.0\n", "std_threshes = [4.0, 5.0] #, 6.0]\n", "# std_threshes = [4.0]\n", "\n", "# roc_df[\"params\"] = roc_df.index.to_list()\n", "\n", "columns = [\"delta\", \"gamma\", \"count\"]\n", "# columns = [\"use_sampling\", \"replace_ratio\", \"count\"]\n", "\n", "for thresh in std_threshes:\n", " # columns += [f\"baseline_fpr_at_{thresh}\",f\"no_bl_fpr_at_{thresh}\",f\"w_bl_tpr_at_{thresh}\"]\n", " # columns += [f\"baseline_fpr_at_{thresh}\",f\"baseline_tnr_at_{thresh}\",f\"no_bl_fpr_at_{thresh}\",f\"no_bl_tnr_at_{thresh}\",f\"w_bl_tpr_at_{thresh}\",f\"w_bl_fn_at_{thresh}\"]\n", "\n", "\n", " # columns += [f\"baseline_fpr_at_{thresh}\",f\"baseline_tnr_at_{thresh}\",f\"w_bl_tpr_at_{thresh}\",f\"w_bl_fnr_at_{thresh}\"]\n", " \n", " if f\"w_bl_attacked_fnr_at_{thresh}\" in roc_df.columns:\n", " columns += [f\"w_bl_tpr_at_{thresh}\",f\"w_bl_fnr_at_{thresh}\"]\n", " columns += [f\"w_bl_attacked_tpr_at_{thresh}\",f\"w_bl_attacked_fnr_at_{thresh}\"] # if attack\n", " else:\n", " columns += [f\"baseline_fpr_at_{thresh}\",f\"baseline_tnr_at_{thresh}\",f\"w_bl_tpr_at_{thresh}\",f\"w_bl_fnr_at_{thresh}\"]\n", "\n", "# filter ot not\n", "sub_df = roc_df[(roc_df[\"use_sampling\"] == True) & ((roc_df[\"delta\"] == 1.0) | (roc_df[\"delta\"] == 2.0) | (roc_df[\"delta\"] == 10.0)) & ((roc_df[\"gamma\"] == 0.1) | (roc_df[\"gamma\"] == 0.25) |(roc_df[\"gamma\"] == 0.5) )]\n", "# sub_df = roc_df[(roc_df[\"replace_ratio\"] == 0.1) | (roc_df[\"replace_ratio\"] == 0.3) | (roc_df[\"replace_ratio\"] == 0.5) | (roc_df[\"replace_ratio\"] == 0.7)]\n", "# sub_df = roc_df\n", "\n", "sub_df.sort_values(\"delta\")[columns]\n", "# sub_df.sort_values(\"num_beams\")[columns]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# print(roc_df[columns].drop([\"count\"],axis=1).sort_values(\"gamma\").round(3).to_latex(index=False))\n", "# print(roc_df[columns].drop([\"count\"],axis=1).sort_values(\"delta\").round(3).to_latex(index=False))\n", "# print(roc_df[columns].drop([\"count\"],axis=1).sort_values(\"num_beams\").round(3).to_latex(index=False))\n", "\n", "print(sub_df.sort_values(\"delta\")[columns].round(3).to_latex(index=False))\n", "# print(sub_df.sort_values(\"num_beams\")[columns].round(3).to_latex(index=False))" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "### write to csv maybe" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# cols_to_drop = ['no_bl_gen_time',\n", "# 'w_bl_gen_time', 'spike_entropies', \n", "# 'no_bl_sec_per_tok', 'no_bl_tok_per_sec', 'w_bl_sec_per_tok',\n", "# 'w_bl_tok_per_sec', 'baseline_loss','no_bl_loss',\n", "# 'w_bl_loss', 'model_name', 'dataset_name',\n", "# 'dataset_config_name', 'shuffle_dataset', 'shuffle_seed',\n", "# 'shuffle_buffer_size', 'max_new_tokens', 
{ "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "### Write to CSV (optional)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# cols_to_drop = ['no_bl_gen_time',\n", "# 'w_bl_gen_time', 'spike_entropies', \n", "# 'no_bl_sec_per_tok', 'no_bl_tok_per_sec', 'w_bl_sec_per_tok',\n", "# 'w_bl_tok_per_sec', 'baseline_loss','no_bl_loss',\n", "# 'w_bl_loss', 'model_name', 'dataset_name',\n", "# 'dataset_config_name', 'shuffle_dataset', 'shuffle_seed',\n", "# 'shuffle_buffer_size', 'max_new_tokens', 'min_prompt_tokens',\n", "# 'limit_indices', 'input_truncation_strategy',\n", "# 'input_filtering_strategy', 'output_filtering_strategy', 'initial_seed',\n", "# 'dynamic_seed','no_repeat_ngram_size', 'early_stopping',\n", "# 'oracle_model_name', 'no_wandb', 'wandb_project', 'wandb_entity', 'output_dir', 'load_prev_generations', 'store_bl_ids',\n", "# 'store_spike_ents', 'generate_only',\n", "# 'SLURM_JOB_ID', 'SLURM_ARRAY_JOB_ID', 'SLURM_ARRAY_TASK_ID',\n", "# 'gen_table_already_existed', 'baseline_num_toks_gend_eq_0',\n", "# 'baseline_hit_list', 'no_bl_num_toks_gend_eq_0',\n", "# 'no_bl_hit_list', 'w_bl_num_toks_gend_eq_0', 'w_bl_hit_list']\n", "# df.drop(cols_to_drop,axis=1).to_csv(\"input/for_poking.csv\")\n", "# df" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "df.columns" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "# Extract examples (actual text) for tabulation based on entropy and z scores (tables 1,3,4,5,6)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "print(f\"groupby legend: {groupby_fields}\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "groups = [\n", " (True, 1, 2.0, 0.5),\n", " # (True, 1, 10.0, 0.5),\n", " # (False, 8, 2.0, 0.5),\n", " # (False, 8, 10.0, 0.5),\n", "]\n", "group_dfs = []\n", "for group in groups:\n", " sub_df = grouped_df.get_group(group)\n", " group_dfs.append(sub_df)\n", "\n", "subset_df = pd.concat(group_dfs,axis=0)\n", "\n", "print(len(subset_df))\n", "# subset_df\n", "\n", "# cols_to_tabulate = groupby_fields + [\n", "cols_to_tabulate = [\n", " 'idx', \n", " 'truncated_input', \n", " # 'prompt_length',\n", " 'baseline_completion',\n", " 'no_bl_output', \n", " 'w_bl_output', \n", " # 'real_completion_length',\n", " # 'no_bl_num_tokens_generated',\n", " # 'w_bl_num_tokens_generated',\n", " 'avg_spike_entropy',\n", " # 'baseline_whitelist_fraction',\n", " # 'no_bl_whitelist_fraction',\n", " # 'w_bl_whitelist_fraction',\n", " # 'baseline_z_score',\n", " 'no_bl_z_score',\n", " 'w_bl_z_score',\n", " # 'baseline_ppl',\n", " 'no_bl_ppl',\n", " 'w_bl_ppl'\n", "]\n", "\n", "# subset_df[cols_to_tabulate][\"idx\"].value_counts()\n", "\n", "for idx,occurrences in subset_df[\"idx\"].value_counts().to_dict().items():\n", " subset_df.loc[(subset_df[\"idx\"]==idx),\"occurrences\"] = occurrences\n", "\n", "subset_df[\"occurrences\"] = subset_df[\"occurrences\"].apply(int)\n", "\n", "# cols_to_tabulate = [\"occurrences\"] + cols_to_tabulate" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# subset_df[cols_to_tabulate].sort_values([\"occurrences\", \"idx\"],ascending=False)\n", "# subset_df[cols_to_tabulate].sort_values([\"avg_spike_entropy\"],ascending=False)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "max_prompt_chars = 200\n", "max_output_chars = 200\n", "# subset_df[\"truncated_input\"] = subset_df[\"truncated_input\"].apply(lambda s: f\"[...]{s[-max_prompt_chars:]}\")\n", "# subset_df[\"baseline_completion\"] = subset_df[\"baseline_completion\"].apply(lambda s: f\"{s[:max_output_chars]}[...truncated]\")\n", "# subset_df[\"no_bl_output\"] = subset_df[\"no_bl_output\"].apply(lambda s: f\"{s[:max_output_chars]}[...truncated]\")\n", "# subset_df[\"w_bl_output\"] = 
subset_df[\"w_bl_output\"].apply(lambda s: f\"{s[:max_output_chars]}[...truncated]\")\n", "\n", "# if you dont have the indexx you cant start with brackets\n", "subset_df[\"truncated_input\"] = subset_df[\"truncated_input\"].apply(lambda s: f\"(...){s[-max_prompt_chars:]}\")\n", "subset_df[\"baseline_completion\"] = subset_df[\"baseline_completion\"].apply(lambda s: f\"{s[:max_output_chars]}[...continues]\")\n", "subset_df[\"no_bl_output\"] = subset_df[\"no_bl_output\"].apply(lambda s: f\"{s[:max_output_chars]}[...continues]\")\n", "subset_df[\"w_bl_output\"] = subset_df[\"w_bl_output\"].apply(lambda s: f\"{s[:max_output_chars]}[...continues]\")\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "slice_size = 2\n", "\n", "# subset_df[cols_to_tabulate][\"avg_spike_entropy\"].describe()[]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "num_examples = len(subset_df)\n", "midpt = num_examples//5\n", "lower = midpt - (slice_size//2)\n", "upper = midpt + (slice_size//2)+1\n", "\n", "high_entropy_examples = subset_df[cols_to_tabulate].sort_values([\"avg_spike_entropy\"],ascending=True).tail(slice_size)\n", "mid_entropy_examples = subset_df[cols_to_tabulate].sort_values([\"avg_spike_entropy\"],ascending=True).iloc[lower:upper]\n", "low_entropy_examples = subset_df[cols_to_tabulate].sort_values([\"avg_spike_entropy\"],ascending=True).head(slice_size)\n", "\n", "num_examples = len(subset_df)\n", "midpt = num_examples//65\n", "lower = midpt - (slice_size//2)\n", "upper = midpt + (slice_size//2)+1\n", "\n", "high_z_examples = subset_df[cols_to_tabulate].sort_values([\"w_bl_z_score\"],ascending=True).tail(slice_size)\n", "mid_z_examples = subset_df[cols_to_tabulate].sort_values([\"w_bl_z_score\"],ascending=True).iloc[lower:upper]\n", "low_z_examples = subset_df[cols_to_tabulate].sort_values([\"w_bl_z_score\"],ascending=True).head(slice_size)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# high_entropy_examples.head()\n", "high_z_examples.head()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# mid_entropy_examples.head()\n", "mid_z_examples.head()\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# low_entropy_examples.head()\n", "low_z_examples.head()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# slices_set_df = pd.concat([high_entropy_examples,low_entropy_examples],axis=0)\n", "slices_set_df = pd.concat([high_z_examples,low_z_examples],axis=0).sort_values(\"w_bl_z_score\",ascending=False)\n", "slices_set_df" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# slices_set_df.T.iloc[:,0:2]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# print(slices_set_df.to_latex(index=False))\n", "# print(low_entropy_examples.to_latex(index=False))\n", "# print(mid_entropy_examples.to_latex(index=False))\n", "# print(high_entropy_examples.to_latex(index=False))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# for c,t in zip(low_entropy_examples.columns,low_entropy_examples.dtypes):\n", "# if t==object:\n", "# low_entropy_examples[c] = low_entropy_examples[c].apply(lambda s: f\"{s[:100]}[...truncated]\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, 
"outputs": [], "source": [ "# low_entropy_examples.T.to_latex(buf=open(\"figs/low_ent_examples.txt\", \"w\"),index=False)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# df_to_write = high_entropy_examples\n", "# df_to_write = mid_entropy_examples\n", "# df_to_write = low_entropy_examples\n", "# df_to_write = high_z_examples\n", "# df_to_write = mid_z_examples\n", "# df_to_write = low_z_examples\n", "\n", "cols_to_drop = [\"idx\", \"avg_spike_entropy\", \"no_bl_z_score\"] #, \"no_bl_ppl\", \"w_bl_ppl\"]\n", "df_to_write = slices_set_df.drop(cols_to_drop,axis=1)\n", "\n", "\n", "with pd.option_context(\"max_colwidth\", 1000):\n", " column_format=\"\".join([(r'p{3cm}|' if t==object else r'p{0.4cm}|') for c,t in zip(df_to_write.columns,df_to_write.dtypes)])[:-1]\n", " # low_entropy_examples.round(2).to_latex(buf=open(\"figs/low_ent_examples.txt\", \"w\"),column_format=column_format,index=False)\n", " latex_str = df_to_write.round(2).to_latex(column_format=column_format,index=False)\n", "\n", "print(latex_str)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# column_format=\"\".join([r'p{2cm}|' for c in low_entropy_examples.columns])\n", "# column_format" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# low_entropy_examples.dtypes" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "with pd.option_context(\"max_colwidth\", 1000):\n", " print(grouped_df.get_group((True, 1, 2.0, 0.9)).head(10)[\"w_bl_output\"])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "### Set up data for charts" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# viz_df = pd.DataFrame()\n", "\n", "# # set the hparam keys, including an indiv column for each you want to ablate on\n", "# viz_df[\"bl_hparams\"] = grouped_df[\"w_bl_exp_whitelist_fraction\"].describe().index.to_list()\n", "# for i,key in enumerate(groupby_fields):\n", "# viz_df[key] = viz_df[\"bl_hparams\"].apply(lambda tup: tup[i])\n", "\n", "# describe_dict = grouped_df[\"w_bl_whitelist_fraction\"].describe()\n", "# viz_df[\"w_bl_whitelist_fraction_mean\"] = describe_dict[\"mean\"].to_list()\n", "# viz_df[\"w_bl_whitelist_fraction_std\"] = describe_dict[\"std\"].to_list()\n", "\n", "# describe_dict = grouped_df[\"no_bl_whitelist_fraction\"].describe()\n", "# viz_df[\"no_bl_whitelist_fraction_mean\"] = describe_dict[\"mean\"].to_list()\n", "# viz_df[\"no_bl_whitelist_fraction_std\"] = describe_dict[\"std\"].to_list()\n", "\n", "\n", "# describe_dict = grouped_df[\"w_bl_z_score\"].describe()\n", "# viz_df[\"w_bl_z_score_mean\"] = describe_dict[\"mean\"].to_list()\n", "# viz_df[\"w_bl_z_score_std\"] = describe_dict[\"std\"].to_list()\n", "\n", "# describe_dict = grouped_df[\"no_bl_z_score\"].describe()\n", "# viz_df[\"no_bl_z_score_mean\"] = describe_dict[\"mean\"].to_list()\n", "# viz_df[\"no_bl_z_score_std\"] = describe_dict[\"std\"].to_list()\n", "\n", "\n", "# describe_dict = 
grouped_df[\"w_bl_ppl\"].describe()\n", "# viz_df[\"w_bl_ppl_mean\"] = describe_dict[\"mean\"].to_list()\n", "# viz_df[\"w_bl_ppl_std\"] = describe_dict[\"std\"].to_list()\n", "\n", "# describe_dict = grouped_df[\"no_bl_ppl\"].describe()\n", "# viz_df[\"no_bl_ppl_mean\"] = describe_dict[\"mean\"].to_list()\n", "# viz_df[\"no_bl_ppl_std\"] = describe_dict[\"std\"].to_list()\n", "\n", "# describe_dict = grouped_df[\"avg_spike_entropy\"].describe()\n", "# viz_df[\"avg_spike_entropy_mean\"] = describe_dict[\"mean\"].to_list()\n", "# viz_df[\"avg_spike_entropy_std\"] = describe_dict[\"std\"].to_list()\n", "\n", "# print(f\"groupby legend: {groupby_fields}\")\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# # filtering\n", "\n", "# viz_df = viz_df[viz_df[\"bl_hparams\"].apply(lambda tup: (tup[0] == True))] # sampling\n", "\n", "# # viz_df = viz_df[viz_df[\"bl_hparams\"].apply(lambda tup: (tup[0] == False))] # greedy\n", "\n", "\n", "# # fix one of the bl params for analytic chart\n", "# viz_df = viz_df[(viz_df[\"gamma\"]==0.5) & (viz_df[\"delta\"]<=10.0)]\n", "\n", "# # viz_df = viz_df[(viz_df[\"delta\"] > 0.5) & (viz_df[\"delta\"]<=10.0)]\n", "\n", "# # viz_df = viz_df[(viz_df[\"delta\"]==0.5) | (viz_df[\"delta\"]==2.0) | (viz_df[\"delta\"]==10.0)]\n", "\n", "# # viz_df = viz_df[(viz_df[\"delta\"]!=0.1)&(viz_df[\"delta\"]!=0.5)&(viz_df[\"delta\"]!=50.0)]\n", "\n", "# # viz_df = viz_df[(viz_df[\"delta\"]!=50.0)]\n", "# # viz_df = viz_df[(viz_df[\"delta\"]!=50.0) & (viz_df[\"num_beams\"]!=1)]\n", "\n", "# print(len(viz_df))\n", "\n", "# viz_df" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Visualize the WL/BL hits via highlighting in html" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# idx = 75\n", "# # idx = 62\n", "\n", "# # debug\n", "# # idx = 7\n", "# # idx = 18\n", "# # idx = 231\n", "\n", "# print(gen_table_w_bl_stats[idx])\n", "# print(f\"\\nPrompt:\",gen_table_w_bl_stats[idx][\"truncated_input\"])\n", "# print(f\"\\nBaseline (real text):{gen_table_w_bl_stats[idx]['baseline_completion']}\")\n", "# print(f\"\\nNo Blacklist:{gen_table_w_bl_stats[idx]['no_bl_output']}\")\n", "# print(f\"\\nw/ Blacklist:{gen_table_w_bl_stats[idx]['w_bl_output']}\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# from ipymarkup import show_span_box_markup, get_span_box_markup\n", "# from ipymarkup.palette import palette, RED, GREEN, BLUE\n", "\n", "# from IPython.display import display, HTML\n", "\n", "# from transformers import GPT2TokenizerFast\n", "# # fast_tokenizer = GPT2TokenizerFast.from_pretrained(\"gpt2\")\n", "# fast_tokenizer = GPT2TokenizerFast.from_pretrained(\"facebook/opt-2.7b\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# %autoreload\n", "\n", "# vis_bl = partial(\n", "# compute_bl_metrics,\n", "# tokenizer=fast_tokenizer,\n", "# hf_model_name=gen_table_meta[\"model_name\"],\n", "# initial_seed=gen_table_meta[\"initial_seed\"],\n", "# dynamic_seed=gen_table_meta[\"dynamic_seed\"],\n", "# bl_proportion=gen_table_meta[\"bl_proportion\"],\n", "# record_hits = True,\n", "# use_cuda=True, # this is obvi critical to match the pseudorandomness\n", "# )" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# stats = vis_bl(gen_table_w_bl_stats[idx], 0)\n", "\n", "# baseline_hit_list = stats[\"baseline_hit_list\"]\n", "# 
no_bl_hit_list = stats[\"no_bl_hit_list\"]\n", "# w_bl_hit_list = stats[\"w_bl_hit_list\"]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# text = stats[\"truncated_input\"]\n", "# fast_encoded = fast_tokenizer(text, truncation=True, max_length=2048)\n", "# hit_list = baseline_hit_list\n", "\n", "# charspans = [fast_encoded.token_to_chars(i) for i in range(len(fast_encoded[\"input_ids\"]))]\n", "# charspans = [cs for cs in charspans if cs is not None]\n", "# # spans = [(cs.start,cs.end, \"PR\") for i,cs in enumerate(charspans)]\n", "# spans = []\n", "\n", "# html = get_span_box_markup(text, spans, palette=palette(PR=BLUE), background='white', text_color=\"black\")\n", "\n", "\n", "# with open(\"figs/prompt_html.html\", \"w\") as f:\n", "# f.write(HTML(html).data)\n", "\n", "# HTML(html)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# text = stats[\"baseline_completion\"]\n", "# fast_encoded = fast_tokenizer(text, truncation=True, max_length=2048)\n", "# hit_list = baseline_hit_list\n", "\n", "# charspans = [fast_encoded.token_to_chars(i) for i in range(len(fast_encoded[\"input_ids\"]))]\n", "# charspans = [cs for cs in charspans if cs is not None]\n", "# spans = [(cs.start,cs.end, \"BL\") if hit_list[i] else (cs.start,cs.end, \"WL\") for i,cs in enumerate(charspans)]\n", "\n", "# html = get_span_box_markup(text, spans, palette=palette(BL=RED, WL=GREEN), background='white', text_color=\"black\")\n", "\n", "\n", "# with open(\"figs/baseline_html.html\", \"w\") as f:\n", "# f.write(HTML(html).data)\n", "\n", "# HTML(html)\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "\n", "# text = stats[\"no_bl_output\"]\n", "# fast_encoded = fast_tokenizer(text, truncation=True, max_length=2048)\n", "# hit_list = no_bl_hit_list\n", "\n", "# charspans = [fast_encoded.token_to_chars(i) for i in range(len(fast_encoded[\"input_ids\"]))]\n", "# charspans = [cs for cs in charspans if cs is not None]\n", "# spans = [(cs.start,cs.end, \"BL\") if hit_list[i] else (cs.start,cs.end, \"WL\") for i,cs in enumerate(charspans)]\n", "\n", "# html = get_span_box_markup(text, spans, palette=palette(BL=RED, WL=GREEN), background='white', text_color=\"black\")\n", "\n", "\n", "# with open(\"figs/no_bl_html.html\", \"w\") as f:\n", "# f.write(HTML(html).data)\n", "\n", "# HTML(html)\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "\n", "# text = stats[\"w_bl_output\"]\n", "# fast_encoded = fast_tokenizer(text, truncation=True, max_length=2048)\n", "# hit_list = w_bl_hit_list\n", "\n", "# charspans = [fast_encoded.token_to_chars(i) for i in range(len(fast_encoded[\"input_ids\"]))]\n", "# charspans = [cs for cs in charspans if cs is not None]\n", "# spans = [(cs.start,cs.end, \"BL\") if hit_list[i] else (cs.start,cs.end, \"WL\") for i,cs in enumerate(charspans)]\n", "\n", "# html = get_span_box_markup(text, spans, palette=palette(BL=RED, WL=GREEN), background='white', text_color=\"black\")\n", "\n", "\n", "# with open(\"figs/w_bl_html.html\", \"w\") as f:\n", "# f.write(HTML(html).data)\n", "\n", "# HTML(html)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", 
"version": "3.10.6" }, "vscode": { "interpreter": { "hash": "365524a309ad80022da286f2ec5d2060ce5cb229abb6076cf68d9a1ab14bd8fe" } } }, "nbformat": 4, "nbformat_minor": 4 }