# Watermark Analysis

Notebook for analyzing and visualizing the effects of watermarking schemes on generated text.

In [None]:
import numpy as np
import pandas as pd
from datasets import load_from_disk

### Load the processed dataset/frame

In [None]:
# Name of the processed dataset directory to analyze. Only one `save_name`
# assignment is active; the commented-out lines are alternative runs.
save_name = "analysis_ds_1-19_realnews_1-3_v1" # in figure
# save_name = "analysis_ds_1-21_greedy_redo"
# save_name = "analysis_ds_1-21_greedy_redo_truncated" # in figure

# save_name = "analysis_ds_1-20_more_attack" # in figure

# Path that the following cell passes to `load_from_disk`.
save_dir = f"input/{save_name}"

In [None]:
# Load the dataset stored on disk under `save_dir` (via `datasets.load_from_disk`).
raw_data = load_from_disk(save_dir)

#### Convert to a pandas DataFrame, filter rows, and derive analysis columns

In [None]:
# Materialize the loaded dataset as a pandas DataFrame; the cells below all operate on `df`.
df = raw_data.to_pandas()

# Diagnostic only (nothing is dropped here): count hard-blacklist rows whose
# with-blacklist whitelist fraction is neither -1.0 nor exactly 1.0.
# NOTE(review): -1.0 is presumably a "could not be measured" sentinel — confirm against the producer.
retok_problematic_rows = df[(df['w_bl_whitelist_fraction'] != -1.0) & (df['w_bl_whitelist_fraction'] != 1.0) & (df['bl_type'] == 'hard')]
print(f"Num rows that are hard-blacklisted, and measureable, but still have a non-100% WL fraction: {len(retok_problematic_rows)} out of {len(df[df['bl_type'] == 'hard'])}")

orig_len = len(df)

# Drop rows where either whitelist fraction equals -1.0.
df = df[df["no_bl_whitelist_fraction"] != -1.0]
df = df[df["w_bl_whitelist_fraction"] != -1.0]

print(f"Dropped {orig_len-len(df)} rows, new len {len(df)}")

orig_len = len(df)
# df = df[df["no_bl_ppl"].isna()]
# df = df[df["w_bl_ppl"].isna()]
# Keep only rows where both perplexity columns are non-NaN.
df = df[~(df["no_bl_ppl"].isna() | df["w_bl_ppl"].isna())]
print(f"Dropped {orig_len-len(df)} rows, new len {len(df)}")

orig_len = len(df)

# Keep rows whose logit bias is at most 100.0 (rows with a NaN bias fail the comparison and are dropped too).
df = df[df["bl_logit_bias"] <= 100.0]

print(f"Dropped {orig_len-len(df)} rows, new len {len(df)}")


orig_len = len(df)

# Keep every non-sampling row, but sampling rows only when num_beams == 1.
# df = df[df["bl_hparams"].apply(lambda tup: (tup[0] == False and tup[2] != 1) or (tup[0] == True and tup[2] == 1) or (tup[0] == False))]
df = df[((df["use_sampling"]==True) & (df["num_beams"] == 1)) | (df["use_sampling"]==False)]

print(f"Dropped {orig_len-len(df)} rows, new len {len(df)}")


# Fill missing sampling temperatures: 0.0 for non-sampling rows, 1.0 for sampling rows.
df.loc[df["use_sampling"]==False,"sampling_temp"] = df.loc[df["use_sampling"]==False,"sampling_temp"].fillna(0.0)
df.loc[df["use_sampling"]==True,"sampling_temp"] = df.loc[df["use_sampling"]==True,"sampling_temp"].fillna(1.0)


# Hard-blacklist rows get an infinite logit bias, which becomes their `delta` below.
df.loc[df["bl_type"]=="hard","bl_logit_bias"] = np.inf
# df.loc[df["bl_type"]=="hard","bl_logit_bias"] = 10000 # crosscheck with whats hardcoded in the bl processor


# delta is the logit bias; gamma is the complement of the blacklist proportion,
# rounded to 3 decimals so that equality filters and group keys compare cleanly.
df["delta"] = df["bl_logit_bias"].values
df["gamma"] = 1 - df["bl_proportion"].values
df["gamma"] = df["gamma"].round(3)

# Whitelist token counts recovered as fraction * number of generated tokens.
df["no_bl_act_num_wl_tokens"] = np.round(df["no_bl_whitelist_fraction"].values*df["no_bl_num_tokens_generated"],1) # round to 1 for sanity
df["w_bl_act_num_wl_tokens"] = np.round(df["w_bl_whitelist_fraction"].values*df["w_bl_num_tokens_generated"],1) # round to 1 for sanity

# Standard deviation derived from the stored variance column.
df["w_bl_std_num_wl_tokens"] = np.sqrt(df["w_bl_var_num_wl_tokens"].values)

# Expose the real completion length under the baseline_* column name that the
# z-score computation reads. The guard has to test membership in `df.columns`:
# a bare non-empty string literal is always truthy, so it guarded nothing.
if "real_completion_length" in df.columns:
    df["baseline_num_tokens_generated"] = df["real_completion_length"].values

# Attack datasets only: scale the measured attacked ratio by the requested replace ratio.
if "actual_attacked_ratio" in df.columns:
    attacked_ratio = df["actual_attacked_ratio"].values
    df["actual_attacked_fraction"] = attacked_ratio * df["replace_ratio"].values

# Length of each per-token hit list, used below to filter on generation length.
hit_list_length_cols = {
    "baseline_hit_list_length": "baseline_hit_list",
    "no_bl_hit_list_length": "no_bl_hit_list",
    "w_bl_hit_list_length": "w_bl_hit_list",
}
for length_col, hit_list_col in hit_list_length_cols.items():
    df[length_col] = df[hit_list_col].apply(len)

## Filter for the generation lengths we want to look at

In [None]:
orig_len = len(df)

# Inclusive window on the hit-list lengths; a row is kept only when the
# baseline, no-blacklist and with-blacklist lengths all fall inside it
# (this also applies to the truncated version of the dataset).
lower_T = 195
upper_T = 205
length_cols = ["baseline_hit_list_length", "no_bl_hit_list_length", "w_bl_hit_list_length"]

long_enough = (df[length_cols] >= lower_T).all(axis=1)
short_enough = (df[length_cols] <= upper_T).all(axis=1)
df = df[long_enough & short_enough]


print(f"Dropped {orig_len-len(df)} rows, new len {len(df)}")

#### Add z-scores

In [None]:
from math import sqrt
import scipy.stats
def compute_z_score(observed_wl_frac, T, gamma):
    """Return the z-score of an observed whitelist fraction.

    The observed fraction is centered on ``gamma`` and divided by the
    standard error ``sqrt(gamma * (1 - gamma) / T)``, where ``T`` is the
    number of tokens the fraction was measured over.
    """
    std_err = sqrt(gamma * (1 - gamma) / T)
    return (observed_wl_frac - gamma) / std_err

def compute_wl_for_z(z, T, gamma):
    """Return the whitelist token count out of ``T`` that corresponds to z-score ``z``.

    Uses the same standard error as ``compute_z_score``: the whitelist
    fraction is ``z * sqrt(gamma * (1 - gamma) / T) + gamma``, and the
    result is that fraction multiplied by ``T``.
    """
    std_err = sqrt(gamma * (1 - gamma) / T)
    wl_fraction = (z * std_err) + gamma
    return wl_fraction * T

def compute_p_value(z):
    """Return the standard-normal upper-tail probability of ``abs(z)``.

    Because the absolute value is taken first, ``z`` and ``-z`` yield the
    same result, and the result never exceeds 0.5.
    """
    magnitude = abs(z)
    return scipy.stats.norm.sf(magnitude)

# Each z-score column is computed row-wise from three input columns given in
# the positional order of compute_z_score(observed_wl_frac, T, gamma).
z_score_inputs = {
    "baseline_z_score": ["baseline_whitelist_fraction", "baseline_num_tokens_generated", "gamma"],
    "no_bl_z_score": ["no_bl_whitelist_fraction", "no_bl_num_tokens_generated", "gamma"],
    "w_bl_z_score": ["w_bl_whitelist_fraction", "w_bl_num_tokens_generated", "gamma"],
}

# Attacked outputs are scored only when the dataset contains them.
if "w_bl_attacked_whitelist_fraction" in df.columns:
    z_score_inputs["w_bl_attacked_z_score"] = ["w_bl_attacked_whitelist_fraction", "w_bl_attacked_num_tokens_generated", "gamma"]

for z_col, input_cols in z_score_inputs.items():
    df[z_col] = df[input_cols].apply(lambda row: compute_z_score(*row), axis=1)

In [None]:
# # if attacked in df
# Derived columns for attacked outputs; skipped when the dataset has no attack columns.
if "w_bl_attacked_whitelist_fraction" in df.columns:
 # Whitelist token count on the attacked text: fraction * number of tokens, rounded to 1 decimal.
 df["w_bl_attacked_act_num_wl_tokens"] = np.round(df["w_bl_attacked_whitelist_fraction"].values*df["w_bl_attacked_num_tokens_generated"],1) # round to 1 for sanity

 # NOTE(review): this repeats the w_bl_attacked_z_score assignment made at the end of the z-score cell.
 df["w_bl_attacked_z_score"] = df[["w_bl_attacked_whitelist_fraction", "w_bl_attacked_num_tokens_generated", "gamma"]].apply(lambda tup: compute_z_score(*tup), axis=1)

 # NOTE(review): bare expression nested inside the `if`; its value is discarded, so it is presumably a leftover preview of these columns.
 df[["bl_proportion","w_bl_attacked_whitelist_fraction", "w_bl_attacked_num_tokens_generated","w_bl_attacked_act_num_wl_tokens", "w_bl_attacked_z_score"]]

#### Groupby

In [None]:
# Columns that identify one experimental configuration. Every grouping starts
# with the decoding settings; the remainder depends on whether attack columns exist.
# NOTE: gamma/delta appear in a different order in the two variants, so tuples
# passed to `get_group` must follow the order of whichever variant is active.
groupby_fields = ['use_sampling', 'num_beams']
if "w_bl_attacked_whitelist_fraction" in df.columns:
    groupby_fields += ['gamma', 'delta', 'replace_ratio']  # attack grouping
else:
    groupby_fields += ['delta', 'gamma']  # regular grouping
    # groupby_fields = ['use_sampling','delta','gamma'] # regular grouping, but no beam variation
    # groupby_fields = ['delta','gamma'] # regular grouping, but no beam variation, and all sampling

#### Main groupby

In [None]:
# Group the filtered rows by the fields chosen above; later cells look groups up by key tuple.
grouped_df = df.groupby(groupby_fields)

In [None]:
# Sanity check: how much data survived filtering and how many distinct groups it forms.
print(f"Number of rows after filtering: {len(df)}")
print(f"Number of groups: {len(grouped_df)}")

### Loop to compute confusion matrix at some z scores for tabulation

In [None]:
import sklearn.metrics as metrics

def reject_null_hypo(z_score=None,cuttoff=None):
    """Return whether ``z_score`` is strictly greater than ``cuttoff``.

    The comparison is applied as-is, so array-like ``z_score`` inputs give
    an elementwise boolean result.
    """
    exceeds_cutoff = z_score > cuttoff
    return exceeds_cutoff

# One record per hyperparameter group: the group's key values, its row count,
# and for each z-score cutoff the fraction of rows above / not above it.
records = []

# (z-score column, key for the fraction above the cutoff, key for the fraction not above it)
rate_specs = [
    ("baseline_z_score", "baseline_fpr_at_{}", "baseline_tnr_at_{}"),
    ("no_bl_z_score", "no_bl_fpr_at_{}", "no_bl_tnr_at_{}"),
    ("w_bl_z_score", "w_bl_tpr_at_{}", "w_bl_fnr_at_{}"),
]
attacked_spec = ("w_bl_attacked_z_score", "w_bl_attacked_tpr_at_{}", "w_bl_attacked_fnr_at_{}")

for group_params in tqdm(list(grouped_df.groups.keys())):
    sub_df = grouped_df.get_group(group_params)
    grp_size = len(sub_df)

    # Start from the group key, e.g. {"use_sampling": ..., "num_beams": ..., ...}.
    record = dict(zip(groupby_fields, group_params))
    record["count"] = grp_size

    # Attacked rates are only reported when the attacked z-score exists.
    active_specs = list(rate_specs)
    if "w_bl_attacked_z_score" in sub_df.columns:
        active_specs.append(attacked_spec)

    for thresh in [4.0, 5.0]:
        for score_col, rejected_key, kept_key in active_specs:
            rejected = reject_null_hypo(z_score=sub_df[score_col].values, cuttoff=thresh)
            record[rejected_key.format(thresh)] = rejected.sum() / grp_size
            record[kept_key.format(thresh)] = (~rejected).sum() / grp_size

    records.append(record)

# (An ROC/AUC computation over the baseline vs. watermarked z-scores was
# previously sketched in this loop but is not active.)

roc_df = pd.DataFrame.from_records(records)


In [None]:
# thresh = 6.0
# thresh = 5.0
# z-score cutoffs whose rates appear in the table
std_threshes = [4.0, 5.0] #, 6.0]
# std_threshes = [4.0]

# roc_df["params"] = roc_df.index.to_list()

# Identifier columns first; the per-cutoff rate columns are appended below.
columns = ["delta", "gamma", "count"]
# columns = ["use_sampling", "replace_ratio", "count"]

for thresh in std_threshes:
    watermarked_rates = [f"w_bl_tpr_at_{thresh}", f"w_bl_fnr_at_{thresh}"]
    if f"w_bl_attacked_fnr_at_{thresh}" in roc_df.columns:
        # attack data: watermarked rates next to their attacked counterparts
        columns += watermarked_rates + [f"w_bl_attacked_tpr_at_{thresh}", f"w_bl_attacked_fnr_at_{thresh}"]
    else:
        # regular data: baseline rates next to the watermarked rates
        columns += [f"baseline_fpr_at_{thresh}", f"baseline_tnr_at_{thresh}"] + watermarked_rates

# filter or not: keep sampling runs at a few selected (delta, gamma) settings
row_filter = (
    (roc_df["use_sampling"] == True)
    & roc_df["delta"].isin([1.0, 2.0, 10.0])
    & roc_df["gamma"].isin([0.1, 0.25, 0.5])
)
sub_df = roc_df[row_filter]
# sub_df = roc_df[(roc_df["replace_ratio"] == 0.1) | (roc_df["replace_ratio"] == 0.3) | (roc_df["replace_ratio"] == 0.5) | (roc_df["replace_ratio"] == 0.7)]
# sub_df = roc_df

sub_df.sort_values("delta")[columns]
# sub_df.sort_values("num_beams")[columns]
# sub_df.sort_values("num_beams")[columns]

In [None]:
# Alternative tables over the unfiltered `roc_df`, sorted by a different key:
# print(roc_df[columns].drop(["count"],axis=1).sort_values("gamma").round(3).to_latex(index=False))
# print(roc_df[columns].drop(["count"],axis=1).sort_values("delta").round(3).to_latex(index=False))
# print(roc_df[columns].drop(["count"],axis=1).sort_values("num_beams").round(3).to_latex(index=False))

# Print the filtered table as LaTeX: sorted by delta, rounded to 3 decimals, index column omitted.
print(sub_df.sort_values("delta")[columns].round(3).to_latex(index=False))
# print(sub_df.sort_values("num_beams")[columns].round(3).to_latex(index=False))

### Optionally write the frame to CSV

In [None]:
# cols_to_drop = ['no_bl_gen_time',
# 'w_bl_gen_time', 'spike_entropies', 
# 'no_bl_sec_per_tok', 'no_bl_tok_per_sec', 'w_bl_sec_per_tok',
# 'w_bl_tok_per_sec', 'baseline_loss','no_bl_loss',
# 'w_bl_loss', 'model_name', 'dataset_name',
# 'dataset_config_name', 'shuffle_dataset', 'shuffle_seed',
# 'shuffle_buffer_size', 'max_new_tokens', 'min_prompt_tokens',
# 'limit_indices', 'input_truncation_strategy',
# 'input_filtering_strategy', 'output_filtering_strategy', 'initial_seed',
# 'dynamic_seed','no_repeat_ngram_size', 'early_stopping',
# 'oracle_model_name', 'no_wandb', 'wandb_project', 'wandb_entity', 'output_dir', 'load_prev_generations', 'store_bl_ids',
# 'store_spike_ents', 'generate_only',
# 'SLURM_JOB_ID', 'SLURM_ARRAY_JOB_ID', 'SLURM_ARRAY_TASK_ID',
# 'gen_table_already_existed', 'baseline_num_toks_gend_eq_0',
# 'baseline_hit_list', 'no_bl_num_toks_gend_eq_0',
# 'no_bl_hit_list', 'w_bl_num_toks_gend_eq_0', 'w_bl_hit_list']
# df.drop(cols_to_drop,axis=1).to_csv("input/for_poking.csv")
# df

In [None]:
df.columns

# Extract examples (actual text) for tabulation based on entropy and z scores (tables 1,3,4,5,6)

In [None]:
print(f"groupby legend: {groupby_fields}")

In [None]:
# Group keys to pull example text from. Each tuple must follow the order of
# `groupby_fields` (printed by the previous cell).
groups = [
    (True, 1, 2.0, 0.5),
    # (True, 1, 10.0, 0.5),
    # (False, 8, 2.0, 0.5),
    # (False, 8, 10.0, 0.5),
]

# Built with a comprehension rather than a loop that assigns to `sub_df`:
# that name also holds the filtered ROC table from the tabulation cell, and
# reusing it here would silently overwrite that table.
group_dfs = [grouped_df.get_group(group) for group in groups]

subset_df = pd.concat(group_dfs,axis=0)

print(len(subset_df))
# subset_df
# subset_df

# Columns carried into the example tables. Commented-out entries are optional
# extras; later cells sort on 'avg_spike_entropy' and 'w_bl_z_score', so those must stay.
# cols_to_tabulate = groupby_fields + [
cols_to_tabulate = [
 'idx', 
 'truncated_input', 
 # 'prompt_length',
 'baseline_completion',
 'no_bl_output', 
 'w_bl_output', 
 # 'real_completion_length',
 # 'no_bl_num_tokens_generated',
 # 'w_bl_num_tokens_generated',
 'avg_spike_entropy',
 # 'baseline_whitelist_fraction',
 # 'no_bl_whitelist_fraction',
 # 'w_bl_whitelist_fraction',
 # 'baseline_z_score',
 'no_bl_z_score',
 'w_bl_z_score',
 # 'baseline_ppl',
 'no_bl_ppl',
 'w_bl_ppl'
]

# subset_df[cols_to_tabulate]["idx"].value_counts()

# Record on every row how many rows of the subset share its `idx`.
# A single vectorized lookup replaces one boolean-mask `.loc` assignment per
# distinct idx (rows x distinct-idx work), and yields integer counts directly.
# The column name keeps its existing spelling ("occurences").
idx_counts = subset_df["idx"].value_counts()
subset_df["occurences"] = subset_df["idx"].map(idx_counts).astype(int)

# cols_to_tabulate = ["occurences"] + cols_to_tabulate

In [None]:
# subset_df[cols_to_tabulate].sort_values(["occurences", "idx"],ascending=False)
# subset_df[cols_to_tabulate].sort_values(["avg_spike_entropy"],ascending=False)

In [None]:
max_prompt_chars = 200
max_output_chars = 200

# Prompts keep their last `max_prompt_chars` characters; outputs keep their
# first `max_output_chars`. The markers are added unconditionally, so running
# this cell twice stacks them.
# (An earlier variant used "[...]" / "[...truncated]" markers instead.)

# if you dont have the index you cant start with brackets
# NOTE(review): presumably a LaTeX table-row issue, hence "(...)" rather than "[...]" — confirm.
subset_df["truncated_input"] = subset_df["truncated_input"].apply(lambda s: f"(...){s[-max_prompt_chars:]}")

for text_col in ("baseline_completion", "no_bl_output", "w_bl_output"):
    subset_df[text_col] = subset_df[text_col].apply(lambda s: f"{s[:max_output_chars]}[...continues]")


In [None]:
slice_size = 2

# subset_df[cols_to_tabulate]["avg_spike_entropy"].describe()[]

In [None]:
num_examples = len(subset_df)
tabulated = subset_df[cols_to_tabulate]
half_width = slice_size // 2

# --- slices by average spike entropy (ascending sort) ---
entropy_sorted = tabulated.sort_values(["avg_spike_entropy"], ascending=True)

# NOTE(review): despite the name, `midpt` is num_examples // 5 here, not the middle row — confirm intent.
midpt = num_examples // 5
lower = midpt - half_width
upper = midpt + half_width + 1

high_entropy_examples = entropy_sorted.tail(slice_size)
mid_entropy_examples = entropy_sorted.iloc[lower:upper]
low_entropy_examples = entropy_sorted.head(slice_size)

# --- slices by watermarked z-score (ascending sort) ---
z_sorted = tabulated.sort_values(["w_bl_z_score"], ascending=True)

# NOTE(review): likewise num_examples // 65 for the z-score window — confirm intent.
midpt = num_examples // 65
lower = midpt - half_width
upper = midpt + half_width + 1

high_z_examples = z_sorted.tail(slice_size)
mid_z_examples = z_sorted.iloc[lower:upper]
low_z_examples = z_sorted.head(slice_size)

In [None]:
# high_entropy_examples.head()
high_z_examples.head()

In [None]:
# mid_entropy_examples.head()
mid_z_examples.head()


In [None]:
# low_entropy_examples.head()
low_z_examples.head()

In [None]:
# slices_set_df = pd.concat([high_entropy_examples,low_entropy_examples],axis=0)
# Stack the highest- and lowest-z examples, ordered from highest to lowest w_bl_z_score.
slices_set_df = pd.concat([high_z_examples,low_z_examples],axis=0).sort_values("w_bl_z_score",ascending=False)
slices_set_df

In [None]:
# slices_set_df.T.iloc[:,0:2]

In [None]:
# print(slices_set_df.to_latex(index=False))
# print(low_entropy_examples.to_latex(index=False))
# print(mid_entropy_examples.to_latex(index=False))
# print(high_entropy_examples.to_latex(index=False))

In [None]:
# for c,t in zip(low_entropy_examples.columns,low_entropy_examples.dtypes):
# if t==object:
# low_entropy_examples[c] = low_entropy_examples[c].apply(lambda s: f"{s[:100]}[...truncated]")

In [None]:
# low_entropy_examples.T.to_latex(buf=open("figs/low_ent_examples.txt", "w"),index=False)

In [None]:
# df_to_write = high_entropy_examples
# df_to_write = mid_entropy_examples
# df_to_write = low_entropy_examples
# df_to_write = high_z_examples
# df_to_write = mid_z_examples
# df_to_write = low_z_examples

# Columns left out of the printed examples table.
cols_to_drop = ["idx", "avg_spike_entropy", "no_bl_z_score"] #, "no_bl_ppl", "w_bl_ppl"]
df_to_write = slices_set_df.drop(columns=cols_to_drop)


# Raise the column-width limit so long text cells are not cut off in the LaTeX output.
with pd.option_context("max_colwidth", 1000):
    # Wide paragraph columns for text (object dtype), narrow ones for numbers;
    # the trailing "|" separator is removed after joining.
    col_specs = "".join(
        r'p{3cm}|' if dtype == object else r'p{0.4cm}|'
        for dtype in df_to_write.dtypes
    )
    column_format = col_specs[:-1]
    latex_str = df_to_write.round(2).to_latex(column_format=column_format, index=False)

print(latex_str)

In [None]:
# column_format="".join([r'p{2cm}|' for c in low_entropy_examples.columns])
# column_format

In [None]:
# low_entropy_examples.dtypes

In [None]:
# Print the first 10 `w_bl_output` texts of one group with a raised column-width limit.
# NOTE(review): the key tuple must follow the order of `groupby_fields`; under the
# regular grouping that is (use_sampling, num_beams, delta, gamma) — confirm 0.9 is the intended gamma.
with pd.option_context("max_colwidth", 1000):
 print(grouped_df.get_group((True, 1, 2.0, 0.9)).head(10)["w_bl_output"])

### Set up data for charts

In [None]:
# viz_df = pd.DataFrame()

# # set the hparam keys, including an indiv column for each you want to ablate on
# viz_df["bl_hparams"] = grouped_df["w_bl_exp_whitelist_fraction"].describe().index.to_list()
# for i,key in enumerate(groupby_fields):
# viz_df[key] = viz_df["bl_hparams"].apply(lambda tup: tup[i])

# describe_dict = grouped_df["w_bl_whitelist_fraction"].describe()
# viz_df["w_bl_whitelist_fraction_mean"] = describe_dict["mean"].to_list()
# viz_df["w_bl_whitelist_fraction_std"] = describe_dict["std"].to_list()

# describe_dict = grouped_df["no_bl_whitelist_fraction"].describe()
# viz_df["no_bl_whitelist_fraction_mean"] = describe_dict["mean"].to_list()
# viz_df["no_bl_whitelist_fraction_std"] = describe_dict["std"].to_list()


# describe_dict = grouped_df["w_bl_z_score"].describe()
# viz_df["w_bl_z_score_mean"] = describe_dict["mean"].to_list()
# viz_df["w_bl_z_score_std"] = describe_dict["std"].to_list()

# describe_dict = grouped_df["no_bl_z_score"].describe()
# viz_df["no_bl_z_score_mean"] = describe_dict["mean"].to_list()
# viz_df["no_bl_z_score_std"] = describe_dict["std"].to_list()


# describe_dict = grouped_df["w_bl_ppl"].describe()
# viz_df["w_bl_ppl_mean"] = describe_dict["mean"].to_list()
# viz_df["w_bl_ppl_std"] = describe_dict["std"].to_list()

# describe_dict = grouped_df["no_bl_ppl"].describe()
# viz_df["no_bl_ppl_mean"] = describe_dict["mean"].to_list()
# viz_df["no_bl_ppl_std"] = describe_dict["std"].to_list()

# describe_dict = grouped_df["avg_spike_entropy"].describe()
# viz_df["avg_spike_entropy_mean"] = describe_dict["mean"].to_list()
# viz_df["avg_spike_entropy_std"] = describe_dict["std"].to_list()

# print(f"groupby legend: {groupby_fields}")


In [None]:
# # filtering

# viz_df = viz_df[viz_df["bl_hparams"].apply(lambda tup: (tup[0] == True))] # sampling

# # viz_df = viz_df[viz_df["bl_hparams"].apply(lambda tup: (tup[0] == False))] # greedy


# # fix one of the bl params for analytic chart
# viz_df = viz_df[(viz_df["gamma"]==0.5) & (viz_df["delta"]<=10.0)]

# # viz_df = viz_df[(viz_df["delta"] > 0.5) & (viz_df["delta"]<=10.0)]

# # viz_df = viz_df[(viz_df["delta"]==0.5) | (viz_df["delta"]==2.0) | (viz_df["delta"]==10.0)]

# # viz_df = viz_df[(viz_df["delta"]!=0.1)&(viz_df["delta"]!=0.5)&(viz_df["delta"]!=50.0)]

# # viz_df = viz_df[(viz_df["delta"]!=50.0)]
# # viz_df = viz_df[(viz_df["delta"]!=50.0) & (viz_df["num_beams"]!=1)]

# print(len(viz_df))

# viz_df

# Visualize the WL/BL hits via highlighting in html

In [None]:
# idx = 75
# # idx = 62

# # debug
# # idx = 7
# # idx = 18
# # idx = 231

# print(gen_table_w_bl_stats[idx])
# print(f"\nPrompt:",gen_table_w_bl_stats[idx]["truncated_input"])
# print(f"\nBaseline (real text):{gen_table_w_bl_stats[idx]['baseline_completion']}")
# print(f"\nNo Blacklist:{gen_table_w_bl_stats[idx]['no_bl_output']}")
# print(f"\nw/ Blacklist:{gen_table_w_bl_stats[idx]['w_bl_output']}")

In [None]:
# from ipymarkup import show_span_box_markup, get_span_box_markup
# from ipymarkup.palette import palette, RED, GREEN, BLUE

# from IPython.display import display, HTML

# from transformers import GPT2TokenizerFast
# # fast_tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
# fast_tokenizer = GPT2TokenizerFast.from_pretrained("facebook/opt-2.7b")

In [None]:
# %autoreload

# vis_bl = partial(
# compute_bl_metrics,
# tokenizer=fast_tokenizer,
# hf_model_name=gen_table_meta["model_name"],
# initial_seed=gen_table_meta["initial_seed"],
# dynamic_seed=gen_table_meta["dynamic_seed"],
# bl_proportion=gen_table_meta["bl_proportion"],
# record_hits = True,
# use_cuda=True, # this is obvi critical to match the pseudorandomness
# )

In [None]:
# stats = vis_bl(gen_table_w_bl_stats[idx], 0)

# baseline_hit_list = stats["baseline_hit_list"]
# no_bl_hit_list = stats["no_bl_hit_list"]
# w_bl_hit_list = stats["w_bl_hit_list"]

In [None]:
# text = stats["truncated_input"]
# fast_encoded = fast_tokenizer(text, truncation=True, max_length=2048)
# hit_list = baseline_hit_list

# charspans = [fast_encoded.token_to_chars(i) for i in range(len(fast_encoded["input_ids"]))]
# charspans = [cs for cs in charspans if cs is not None]
# # spans = [(cs.start,cs.end, "PR") for i,cs in enumerate(charspans)]
# spans = []

# html = get_span_box_markup(text, spans, palette=palette(PR=BLUE), background='white', text_color="black")


# with open("figs/prompt_html.html", "w") as f:
# f.write(HTML(html).data)

# HTML(html)

In [None]:
# text = stats["baseline_completion"]
# fast_encoded = fast_tokenizer(text, truncation=True, max_length=2048)
# hit_list = baseline_hit_list

# charspans = [fast_encoded.token_to_chars(i) for i in range(len(fast_encoded["input_ids"]))]
# charspans = [cs for cs in charspans if cs is not None]
# spans = [(cs.start,cs.end, "BL") if hit_list[i] else (cs.start,cs.end, "WL") for i,cs in enumerate(charspans)]

# html = get_span_box_markup(text, spans, palette=palette(BL=RED, WL=GREEN), background='white', text_color="black")


# with open("figs/baseline_html.html", "w") as f:
# f.write(HTML(html).data)

# HTML(html)


In [None]:

# text = stats["no_bl_output"]
# fast_encoded = fast_tokenizer(text, truncation=True, max_length=2048)
# hit_list = no_bl_hit_list

# charspans = [fast_encoded.token_to_chars(i) for i in range(len(fast_encoded["input_ids"]))]
# charspans = [cs for cs in charspans if cs is not None]
# spans = [(cs.start,cs.end, "BL") if hit_list[i] else (cs.start,cs.end, "WL") for i,cs in enumerate(charspans)]

# html = get_span_box_markup(text, spans, palette=palette(BL=RED, WL=GREEN), background='white', text_color="black")


# with open("figs/no_bl_html.html", "w") as f:
# f.write(HTML(html).data)

# HTML(html)


In [None]:

# text = stats["w_bl_output"]
# fast_encoded = fast_tokenizer(text, truncation=True, max_length=2048)
# hit_list = w_bl_hit_list

# charspans = [fast_encoded.token_to_chars(i) for i in range(len(fast_encoded["input_ids"]))]
# charspans = [cs for cs in charspans if cs is not None]
# spans = [(cs.start,cs.end, "BL") if hit_list[i] else (cs.start,cs.end, "WL") for i,cs in enumerate(charspans)]

# html = get_span_box_markup(text, spans, palette=palette(BL=RED, WL=GREEN), background='white', text_color="black")


# with open("figs/w_bl_html.html", "w") as f:
# f.write(HTML(html).data)

# HTML(html)