import argparse
from operator import add

import numpy as np
import torch
import torch.nn.functional as F
from tqdm import trange

from gpt2tunediscrim import ClassificationHead
from pytorch_pretrained_bert import GPT2LMHeadModel, GPT2Tokenizer
from style_utils import to_var, top_k_logits

SmallConst = 1e-15

enc = GPT2Tokenizer.from_pretrained('gpt-2_pt_models/345M/')

def perturb_past(past, model, prev, args, classifier, good_index=None, stepsize=0.01, vocab_size=50257,
                 original_probs=None, accumulated_hidden=None, true_past=None, grad_norms=None):
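    """Perturb the model's past key/value cache so that upcoming tokens move
    toward the target attribute (bag of words and/or discriminator class),
    while a KL term keeps the output close to the unperturbed model."""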
    window_length = args.window_length
    gm_scale, kl_scale = args.fusion_gm_scale, args.fusion_kl_scale

    # Build a one-hot matrix per bag of words, keeping only single-token words.
    one_hot_vectors = []
    for good_list in good_index:
        good_list = list(filter(lambda x: len(x) <= 1, good_list))
        good_list = torch.tensor(good_list).cuda()
        num_good = good_list.shape[0]
        one_hot_good = torch.zeros(num_good, vocab_size).cuda()
        one_hot_good.scatter_(1, good_list, 1)
        one_hot_vectors.append(one_hot_good)

    # Accumulated perturbation, initialized to zero (uniform over [0.0, 0.0]).
    past_perturb_orig = [(np.random.uniform(0.0, 0.0, p.shape).astype('float32'))
                         for p in past]

    if accumulated_hidden is None:
        accumulated_hidden = 0
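
    # Optionally restrict the perturbation to the most recent `window_length`
    # positions of the past, with a linear decay over that window.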
    if args.decay:
        decay_mask = torch.arange(0., 1.0 + SmallConst, 1.0 / window_length)[1:]
    else:
        decay_mask = 1.0

    # Each layer's past has shape (2, batch, heads, sequence_length, head_dim).
    _, _, _, current_length, _ = past[0].shape

    if current_length > window_length and window_length > 0:
        ones_key_val_shape = tuple(past[0].shape[:-2]) + tuple([window_length]) + tuple(
            past[0].shape[-1:])

        zeros_key_val_shape = tuple(past[0].shape[:-2]) + tuple([current_length - window_length]) + tuple(
            past[0].shape[-1:])

        ones_mask = torch.ones(ones_key_val_shape)
        ones_mask = decay_mask * ones_mask.permute(0, 1, 2, 4, 3)
        ones_mask = ones_mask.permute(0, 1, 2, 4, 3)

        window_mask = torch.cat((ones_mask, torch.zeros(zeros_key_val_shape)), dim=-2).cuda()
    else:
        window_mask = torch.ones_like(past[0]).cuda()
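
    # Take several gradient steps on the perturbation, re-running the forward
    # pass from the perturbed past at each iteration.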
    loss_per_iter = []
    for i in range(args.num_iterations):
        past_perturb = [torch.from_numpy(p_) for p_ in past_perturb_orig]
        past_perturb = [to_var(p_, requires_grad=True) for p_ in past_perturb]

        perturbed_past = list(map(add, past, past_perturb))

        _, _, _, current_length, _ = past_perturb[0].shape

        _, future_past = model(prev, past=perturbed_past)
        hidden = model.hidden_states
        new_accumulated_hidden = accumulated_hidden + torch.sum(hidden, dim=1).detach()

        logits = model.forward_hidden(hidden)
        logits = logits[:, -1, :]
        probabs = F.softmax(logits, dim=-1)

        loss = 0.0
        loss_list = []
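
        # Loss types 1 and 3: negative log of the total probability mass the
        # model assigns to the bag-of-words tokens.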
        if args.loss_type == 1 or args.loss_type == 3:
            for one_hot_good in one_hot_vectors:
                good_logits = torch.mm(probabs, torch.t(one_hot_good))
                loss_word = good_logits
                loss_word = torch.sum(loss_word)
                loss_word = -torch.log(loss_word)

                loss += loss_word
                loss_list.append(loss_word)
            print('words', loss.data.cpu().numpy())
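
        # Loss types 2 and 3: cross-entropy of the discriminator on the mean
        # hidden state, rolled out `horizon_length` steps into the future.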
        if args.loss_type == 2 or args.loss_type == 3:
            ce_loss = torch.nn.CrossEntropyLoss()
            new_true_past = true_past
            for _ in range(args.horizon_length):
                future_probabs = F.softmax(logits, dim=-1)
                future_probabs = torch.unsqueeze(future_probabs, dim=1)

                # The patched model accepts a soft distribution over tokens here.
                _, new_true_past = model(future_probabs, past=new_true_past)
                future_hidden = model.hidden_states
                new_accumulated_hidden = new_accumulated_hidden + torch.sum(future_hidden, dim=1)

            predicted_sentiment = classifier(new_accumulated_hidden / (current_length + 1 + args.horizon_length))

            label = torch.tensor([args.label_class], device='cuda', dtype=torch.long)
            discrim_loss = ce_loss(predicted_sentiment, label)
            print('discrim', discrim_loss.data.cpu().numpy())
            loss += discrim_loss
            loss_list.append(discrim_loss)
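
        # KL regularizer: keep the perturbed next-token distribution close to
        # the unperturbed one (the small constants avoid log(0)).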
        kl_loss = 0.0
        if kl_scale > 0.0:
            p = F.softmax(original_probs[:, -1, :], dim=-1)
            p = p + SmallConst * (p <= SmallConst).type(torch.FloatTensor).cuda().detach()
            correction = SmallConst * (probabs <= SmallConst).type(torch.FloatTensor).cuda().detach()
            corrected_probabs = probabs + correction.detach()
            kl_loss = kl_scale * ((corrected_probabs * (corrected_probabs / p).log()).sum())

        loss += kl_loss

        print('attribute loss', (loss - kl_loss).data.cpu().numpy())

        loss_per_iter.append(loss.data.cpu().numpy())
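
        # Normalize each gradient within the window and take one step of size
        # `stepsize` against the combined loss.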
        loss.backward()
        if grad_norms is not None and args.loss_type == 1:
            grad_norms = [torch.max(grad_norms[index], torch.norm(p_.grad * window_mask)) for index, p_ in
                          enumerate(past_perturb)]
        else:
            grad_norms = [(torch.norm(p_.grad * window_mask) + SmallConst) for index, p_ in enumerate(past_perturb)]

        grad = [
            -stepsize * (p_.grad * window_mask / grad_norms[index] ** args.gamma).data.cpu().numpy()
            for index, p_ in enumerate(past_perturb)]
        past_perturb_orig = list(map(add, grad, past_perturb_orig))

        for p_ in past_perturb:
            p_.grad.data.zero_()

        # Detach the past so the next iteration starts from a clean graph.
        new_past = []
        for p in past:
            new_past.append(p.detach())
        past = new_past

    # Apply the final accumulated perturbation to the detached past.
    past_perturb = [torch.from_numpy(p_) for p_ in past_perturb_orig]
    past_perturb = [to_var(p_, requires_grad=True) for p_ in past_perturb]
    perturbed_past = list(map(add, past, past_perturb))

    return perturbed_past, new_accumulated_hidden, grad_norms, loss_per_iter


def latent_perturb(model, args, context=None, sample=True, device='cuda'):
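    """Generate one unperturbed sample and `args.num_samples` perturbed samples
    from the same context, loading the requested attribute discriminator."""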
    if args.discrim == 'clickbait':
        classifier = ClassificationHead(class_size=2, embed_size=1024).to(device)
        classifier.load_state_dict(torch.load("discrim_models/clickbait_classifierhead.pt"))
        classifier.eval()
        args.label_class = 1
    elif args.discrim == 'sentiment':
        classifier = ClassificationHead(class_size=5, embed_size=1024).to(device)
        classifier.load_state_dict(torch.load("discrim_models/sentiment_classifierhead.pt"))
        classifier.eval()
        if args.label_class < 0:
            raise Exception('Wrong class for sentiment; use --label-class 2 for *very positive* or 3 for *very negative*')
    elif args.discrim == 'toxicity':
        classifier = ClassificationHead(class_size=2, embed_size=1024).to(device)
        classifier.load_state_dict(torch.load("discrim_models/toxicity_classifierhead.pt"))
        classifier.eval()
        args.label_class = 0
    else:
        classifier = None

    def list_tokens(word_list):
        token_list = []
        for word in word_list:
            token_list.append(enc.encode(" " + word))
        return token_list

    # Read each bag-of-words file (one word per line) into lists of token ids.
    good_index = []
    if args.bag_of_words:
        bags_of_words = args.bag_of_words.split(";")
        for wordlist in bags_of_words:
            with open(wordlist, "r") as f:
                words = f.read().strip().split('\n')
            good_index.append(list_tokens(words))

    if args.bag_of_words and classifier:
        print('Both PPLM-BoW and PPLM-Discrim are on. This is not optimized.')
        args.loss_type = 3
    elif args.bag_of_words:
        args.loss_type = 1
        print('Using PPLM-BoW')
    elif classifier is not None:
        args.loss_type = 2
        print('Using PPLM-Discrim')
    else:
        raise Exception('Supply either --bag-of-words (-B) or --discrim (-D)')

    # Generate an unperturbed sample first for comparison.
    original, _, _ = sample_from_hidden(model=model, args=args, context=context, device=device,
                                        perturb=False, good_index=good_index, classifier=classifier)
    torch.cuda.empty_cache()

    perturbed_list = []
    discrim_loss_list = []
    loss_in_time_list = []

    for i in range(args.num_samples):
        perturbed, discrim_loss, loss_in_time = sample_from_hidden(model=model, args=args, context=context,
                                                                   device=device, perturb=True, good_index=good_index,
                                                                   classifier=classifier)
        perturbed_list.append(perturbed)
        if classifier is not None:
            discrim_loss_list.append(discrim_loss.data.cpu().numpy())
        loss_in_time_list.append(loss_in_time)

    torch.cuda.empty_cache()

    return original, perturbed_list, discrim_loss_list, loss_in_time_list


def sample_from_hidden(model, args, classifier, context=None, past=None, device='cuda',
                       sample=True, perturb=True, good_index=None):
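    """Autoregressively sample `args.length` tokens, perturbing the past at
    each step (when `perturb` is True) and fusing the perturbed and
    unperturbed next-token distributions."""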
    output = torch.tensor(context, device=device, dtype=torch.long).unsqueeze(0) if context else None

    grad_norms = None
    loss_in_time = []
    for i in trange(args.length, ascii=True):
        # Run the model on the current prefix to get the unperturbed past,
        # logits, and hidden states.
        if past is None and output is not None:
            prev = output[:, -1:]
            _, past = model(output[:, :-1])
            original_probs, true_past = model(output)
            true_hidden = model.hidden_states
        else:
            original_probs, true_past = model(output)
            true_hidden = model.hidden_states

        # Stop perturbing once `grad_length` tokens have been generated.
        if i >= args.grad_length:
            current_stepsize = args.stepsize * 0
        else:
            current_stepsize = args.stepsize
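
        # Perturb the cached keys/values toward the attribute, unless this is
        # the unperturbed baseline pass.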
        if not perturb or args.num_iterations == 0:
            perturbed_past = past
        else:
            accumulated_hidden = model.hidden_states[:, :-1, :]
            accumulated_hidden = torch.sum(accumulated_hidden, dim=1)

            perturbed_past, _, grad_norms, loss_per_iter = perturb_past(past, model, prev, args,
                                                                        good_index=good_index,
                                                                        stepsize=current_stepsize,
                                                                        original_probs=original_probs,
                                                                        true_past=true_past,
                                                                        accumulated_hidden=accumulated_hidden,
                                                                        classifier=classifier,
                                                                        grad_norms=grad_norms)
            loss_in_time.append(loss_per_iter)

        test_logits, past = model(prev, past=perturbed_past)

        # Score the unperturbed hidden states under the discriminator (logging only).
        if classifier is not None:
            ce_loss = torch.nn.CrossEntropyLoss()
            predicted_sentiment = classifier(torch.mean(true_hidden, dim=1))
            label = torch.tensor([args.label_class], device='cuda', dtype=torch.long)
            true_discrim_loss = ce_loss(predicted_sentiment, label)
            print("true discrim loss", true_discrim_loss.data.cpu().numpy())
        else:
            true_discrim_loss = 0

        hidden = model.hidden_states
        logits = model.forward_hidden(hidden)
        logits = logits[:, -1, :] / args.temperature

        log_probs = F.softmax(logits, dim=-1)  # note: probabilities, not log-probabilities
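
        # Post-norm fusion: geometric mean of the perturbed and unperturbed
        # distributions, then top-k filtering and renormalization.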
        if perturb:
            original_probs = F.softmax(original_probs[:, -1, :], dim=-1)

            gm_scale = args.fusion_gm_scale
            log_probs = (log_probs ** gm_scale) * (original_probs ** (1 - gm_scale))

            log_probs = top_k_logits(log_probs, k=args.top_k, probs=True)

            if torch.sum(log_probs) <= 1:
                log_probs = log_probs / torch.sum(log_probs)
        else:
            logits = top_k_logits(logits, k=args.top_k)
            log_probs = F.softmax(logits, dim=-1)

        if sample:
            prev = torch.multinomial(log_probs, num_samples=1)
        else:
            _, prev = torch.topk(log_probs, k=1, dim=-1)

        output = prev if output is None else torch.cat((output, prev), dim=1)
        print(enc.decode(output.tolist()[0]))

    return output, true_discrim_loss, loss_in_time


def run_model():
    parser = argparse.ArgumentParser()
    parser.add_argument('--model-path', '-M', type=str, default='gpt-2_pt_models/345M/',
                        help='pretrained model name or path to a local checkpoint')
    parser.add_argument('--bag-of-words', '-B', type=str, default=None,
                        help='Bags of words used for PPLM-BoW; multiple BoW files separated by ";"')
    parser.add_argument('--discrim', '-D', type=str, default=None,
                        choices=('clickbait', 'sentiment', 'toxicity'),
                        help='Discriminator to use for loss type 2')
    parser.add_argument('--label-class', type=int, default=-1, help='Class label used by the discriminator')
    parser.add_argument('--stepsize', type=float, default=0.02)
    parser.add_argument('--length', type=int, default=100)
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--temperature', type=float, default=1.0)
    parser.add_argument('--top-k', type=int, default=10)
    parser.add_argument('--fusion-gm-scale', type=float, default=0.9)
    parser.add_argument('--fusion-kl-scale', type=float, default=0.01)
    parser.add_argument('--nocuda', action='store_true', help='run on CPU')
    parser.add_argument('--uncond', action='store_true', help='generate from the end-of-text token as prefix')
    parser.add_argument('--cond-text', type=str, default='The lake', help='prefix text to condition on')
    parser.add_argument('--num-iterations', type=int, default=3)
    parser.add_argument('--grad-length', type=int, default=10000)
    parser.add_argument('--num-samples', type=int, default=1,
                        help='number of samples to generate from the modified latents')
    parser.add_argument('--horizon-length', type=int, default=1, help='length of future to optimize over')
    parser.add_argument('--window-length', type=int, default=0,
                        help='length of past which is being optimized; 0 corresponds to an infinite window length')
    parser.add_argument('--decay', action='store_true', help='whether to decay the perturbation over the window')
    parser.add_argument('--gamma', type=float, default=1.5)

    args = parser.parse_args()

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    device = 'cpu' if args.nocuda else 'cuda'

    model = GPT2LMHeadModel.from_pretrained(args.model_path)
    model.to(device)
    model.eval()

    # Freeze all model weights; only the latent perturbations receive gradients.
    for param in model.parameters():
        param.requires_grad = False

    if args.uncond:
        seq = [[50256, 50256]]
    else:
        raw_text = args.cond_text
        while not raw_text:
            print('Did you forget to add `--cond-text`?')
            raw_text = input('Model prompt >>> ')
        seq = [[50256] + enc.encode(raw_text)]

    collect_gen = dict()
    current_index = 0
    for out in seq:
        text = enc.decode(out)
        print("=" * 40 + " Prefix of sentence " + "=" * 40)
        print(text)
        print("=" * 80)

        out1, out_perturb, discrim_loss_list, loss_in_time_list = latent_perturb(model=model, args=args, context=out,
                                                                                 device=device)

        text_whole = enc.decode(out1.tolist()[0])

        print("=" * 80)
        print("=" * 40 + " Whole sentence (Original) " + "=" * 40)
        print(text_whole)
        print("=" * 80)

        out_perturb_copy = out_perturb

        for out_perturb in out_perturb_copy:
            try:
                print("=" * 40 + " Whole sentence (Perturbed) " + "=" * 40)
                text_whole = enc.decode(out_perturb.tolist()[0])
                print(text_whole)
                print("=" * 80)
            except Exception:
                pass
            collect_gen[current_index] = [out, out_perturb, out1]

            current_index = current_index + 1

    return


if __name__ == '__main__':
    run_model()
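
# Example invocation (a sketch; the script name and BoW file path are
# assumptions, so substitute your local file names):
#   python pplm_sample.py -B wordlists/space.txt --cond-text "The lake" \
#       --length 60 --num-iterations 3 --window-length 5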