import os
import sys
import argparse
from tqdm import trange

import torch
import torch.optim
import torch.nn.functional as F
import numpy as np
from torch.autograd import Variable

lab_root = os.path.join(os.path.abspath(os.path.dirname(__file__)), '..', '..')
sys.path.insert(1, lab_root)
from pytorch_pretrained_bert import GPT2LMHeadModel, GPT2Tokenizer

from IPython import embed
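
# Example invocations (illustrative only: the script name is an assumption;
# the checkpoint path matches the --model_path default defined below):
#   python sample.py -M gpt-2_pt_models/774M/ --cond-text "The meaning of life is" --top_k 40 --length 100
#   python sample.py -M gpt-2_pt_models/774M/ --unconditional --nsamples 2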


def top_k_logits(logits, k, probs=False):
    """
    Masks everything but the top k entries as -infinity (-1e10), so that after
    a softmax the masked entries contribute e^{-inf} -> 0 to the denominator.
    If `probs` is True, the input is treated as probabilities and the masked
    entries are set to 0 instead.
    """
    if k == 0:
        return logits
    else:
        values = torch.topk(logits, k)[0]
        batch_mins = values[:, -1].view(-1, 1).expand_as(logits)
        if probs:
            return torch.where(logits < batch_mins, torch.ones_like(logits) * 0.0, logits)
        return torch.where(logits < batch_mins, torch.ones_like(logits) * -1e10, logits)


def sample_sequence(model, length, start_token=None, batch_size=None, context=None, temperature=1,
                    top_k=0, device='cuda', sample=True, return_past=False):
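    """
    Autoregressively sample `length` tokens from `model`.

    Exactly one of `start_token` and `context` must be given: `context` is a
    list of token ids used as the conditioning prefix, while `start_token`
    seeds every sequence in the batch with a single token. At each step the
    logits are divided by `temperature` and optionally truncated to the
    `top_k` most likely entries, then a token is sampled (or taken greedily
    when `sample` is False). Returns a (batch_size, prefix_length + length)
    tensor of token ids, together with the cached `past` activations when
    `return_past` is True.
    """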
    if start_token is None:
        assert context is not None, 'Specify exactly one of start_token and context!'
        context = torch.tensor(context, device=device, dtype=torch.long).unsqueeze(0).repeat(batch_size, 1)
    else:
        assert context is None, 'Specify exactly one of start_token and context!'
        context = torch.full((batch_size, 1), start_token, device=device, dtype=torch.long)

    prev = context
    output = context
    past = None
    with torch.no_grad():
        for i in trange(length, ascii=True):
            logits, past = model(prev, past=past)
            logits = logits[:, -1, :] / temperature
            logits = top_k_logits(logits, k=top_k)
            log_probs = F.softmax(logits, dim=-1)
            if sample:
                prev = torch.multinomial(log_probs, num_samples=1)
            else:
                _, prev = torch.topk(log_probs, k=1, dim=-1)

            output = torch.cat((output, prev), dim=1)

    if return_past:
        return output, past
    else:
        return output


def sample_from_hidden(model, length, hidden, context=None, past=None, temperature=1,
                       top_k=0, device='cuda', sample=True, noise_level=1e-1):
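    """
    Sample `length` tokens while perturbing the model's hidden states.

    Each step computes logits from `hidden` via `model.forward_hidden`,
    samples (or greedily picks) the next token, refreshes `hidden` from
    `model.hidden_states` after feeding the token back through the model, and
    perturbs it with `modify_hidden(hidden, noise_level)`. Note that
    `forward_hidden` and `hidden_states` are hooks assumed to be provided by
    this repo's modified GPT2LMHeadModel; the stock pytorch_pretrained_bert
    model does not expose them.
    """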
    output = torch.tensor(context, device=device, dtype=torch.long).unsqueeze(0) if context else None
    with torch.no_grad():
        for i in trange(length, ascii=True):
            logits = model.forward_hidden(hidden)
            logits = logits[:, -1, :] / temperature
            logits = top_k_logits(logits, k=top_k)
            log_probs = F.softmax(logits, dim=-1)
            if sample:
                prev = torch.multinomial(log_probs, num_samples=1)
            else:
                _, prev = torch.topk(log_probs, k=1, dim=-1)

            output = prev if output is None else torch.cat((output, prev), dim=1)
            if i == 0:
                _, past = model(output, past=None)
            else:
                _, past = model(prev, past=past)
            hidden = model.hidden_states

            hidden = modify_hidden(hidden, noise_level)
    return output


def modify_hidden(input_tensor, noise_level=1e-1):
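    """Perturb `input_tensor` by adding uniform [0, 1) noise (one value per position in the last dimension) scaled by `noise_level`."""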
    length = input_tensor.shape[-1]
    # Draw the noise on the same device as the hidden states instead of assuming CUDA.
    ret = input_tensor + torch.rand(length, device=input_tensor.device) * noise_level
    return ret


def compute_log_likelihood(model, phrase, tokenizer, device):
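    """
    Print and return the log-likelihood of `phrase` under `model`.

    The phrase is tokenized, run through the model in a single forward pass,
    and the log-probability of each token given its preceding tokens is
    printed (the first token has no prediction and is skipped). Returns the
    sum of the per-token log-likelihoods.
    """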
    token_ids = tokenizer.encode(phrase)
    batch_size = 1
    context = torch.tensor(token_ids, device=device, dtype=torch.long).unsqueeze(0).repeat(batch_size, 1)
    print("Computing LL of phrase \"{}\"".format(phrase))
    print("After encoding, number of tokens {}".format(len(token_ids)))
    with torch.no_grad():
        logits, past = model(context, past=None)

    _idxs = range(len(token_ids) - 1)
    token_ids = token_ids[1:]
    logits = logits[0, :-1]

    probs = F.softmax(logits, dim=-1)
    likelihoods = probs[_idxs, token_ids]
    assert len(list(likelihoods.shape)) == 1

    log_likelihoods = torch.log(likelihoods)
    ll_list = [ls.item() for ls in log_likelihoods]

    for token, llh in zip(token_ids, log_likelihoods):
        print("LL of token {} ('{}') ==> {:.4f}".format(token, tokenizer.decode([token]), llh))

    print("LL of the phrase (sum of the above): {}".format(np.sum(ll_list)))
    return np.sum(ll_list)


def get_embedding_grad(model, enc, context=None, target=40, device='cuda', ll_only=False, opt_embed=False):
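    """
    Inspect, and optionally optimize, the input embeddings with respect to a target token.

    Prints the log-likelihood assigned to the `target` token id(s) at the
    next position alongside the model's top-1 prediction. With `ll_only` the
    function stops there. With `opt_embed`, the cached input embeddings
    (`model.transformer.i_embeds` and `model.forward_embed`, hooks assumed to
    be provided by this repo's modified GPT2LMHeadModel) are optimized with
    Adam to raise that likelihood, each optimized vector is snapped to its
    nearest vocabulary embedding, and the resulting token ids are returned
    appended to the original context ids.
    """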
    assert context is not None, 'Input text is needed'

    # `target` may be a single token id or a list/tuple of ids; normalize to a
    # sequence so the len() checks below are well defined.
    if not isinstance(target, (list, tuple)):
        target = [target]

    context = torch.tensor(context, device=device, dtype=torch.float).unsqueeze(0)

    model.zero_grad()
    logits, past = model(context, past=None)

    logits = logits[:, -1, :]
    log_probs = F.softmax(logits, dim=-1)

    if len(target) > 1:
        nll = sum([-torch.log(log_probs[:, tar]) for tar in target])
    else:
        nll = -torch.log(log_probs[:, target])

    with torch.no_grad():
        log_probs = F.softmax(logits, dim=-1)
        top1, top1ind = torch.topk(log_probs, k=1, dim=-1)

        print('LL of target : {}'.format(-nll.data.squeeze().cpu().numpy()))
        print('LL of top 1 : {}'.format(torch.log(top1).data.squeeze().cpu().numpy()))

    if ll_only:
        return

    if opt_embed:
        orig_embed = model.transformer.i_embeds.clone()
        embed_vars = Variable(model.transformer.i_embeds, requires_grad=True)

        optimizer = torch.optim.Adam([embed_vars], lr=0.01)

        for ss in range(50):
            # Reset gradients each step so they do not accumulate across iterations.
            optimizer.zero_grad()
            nll.backward()
            optimizer.step()

            logits, past = model.forward_embed(embed_vars, past=None)
            logits = logits[:, -1, :]
            log_probs = F.softmax(logits, dim=-1)

            if len(target) > 1:
                nll = sum([-torch.log(log_probs[:, tar]) for tar in target])
            else:
                nll = -torch.log(log_probs[:, target])

            print('LL of target (step {}): {}'.format(ss, -nll.data.squeeze().cpu().numpy()))

        # Snap each optimized embedding back to its nearest vocabulary embedding (L1 distance).
        output_ids = torch.empty_like(context.long())
        with torch.no_grad():
            all_embeds = model.transformer.wte.weight
            embed_vars_unbind = torch.unbind(embed_vars, dim=1)
            orig_embed_unbind = torch.unbind(orig_embed, dim=1)

            cc = 0
            for ie_new, ie_orig, orig_id in zip(embed_vars_unbind, orig_embed_unbind, context.squeeze(0)):
                new_id = (all_embeds - ie_new).abs().sum(1).argmin()

                print('emb {}: {} (`{}`) to {} (`{}`)'.format(cc, orig_id.tolist(), enc.decode([orig_id.tolist()]),
                                                              new_id.tolist(), enc.decode([new_id.tolist()])))

                output_ids[0, cc] = new_id
                cc += 1

        output_ids = torch.cat((context.long(), output_ids), dim=1)
        return output_ids


def run_model():
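    """Command-line entry point: load a GPT-2 checkpoint and print conditional or unconditional samples."""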
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_path', '-M', type=str, default='gpt-2_pt_models/774M/',
                        help='pretrained model name or path to local checkpoint')
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--nsamples", type=int, default=1)
    parser.add_argument("--batch_size", type=int, default=-1)
    parser.add_argument("--length", type=int, default=-1)
    parser.add_argument("--temperature", type=float, default=1.0)
    parser.add_argument("--top_k", type=int, default=0)
    parser.add_argument('--unconditional', action='store_true', help='If true, unconditional generation.')
    parser.add_argument('--nocuda', action='store_true', help='no cuda')
    parser.add_argument('--opt_ll', action='store_true', help='nll optimize')
    parser.add_argument('--get_ll', action='store_true', help='compute log likelihood of sentence')
    parser.add_argument('--hidden_playground', action='store_true', help='play around in the hidden representation')
    parser.add_argument("--noise_level", type=float, default=1e-1)
    parser.add_argument("--cond-text", type=str, default='', help='Prefix text to condition on')
    parser.add_argument('--output', type=str, default=os.environ.get('GIT_RESULTS_MANAGER_DIR', None),
                        help='output directory')
    args = parser.parse_args()
    print(args)

    if args.batch_size == -1:
        args.batch_size = 1
    assert args.nsamples % args.batch_size == 0

    np.random.seed(args.seed)
    torch.random.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if args.nocuda:
        device = torch.device("cpu")

    print('device is {}'.format(device))

    enc = GPT2Tokenizer.from_pretrained(args.model_path)
    model = GPT2LMHeadModel.from_pretrained(args.model_path)
    model.to(device)
    model.eval()

    if args.length == -1:
        args.length = model.config.n_ctx // 2
    elif args.length > model.config.n_ctx:
        raise ValueError("Can't get samples longer than window size: %s" % model.config.n_ctx)

    generated = 0
    for _ in range(10):
        context_tokens = []
        if not args.unconditional:
            raw_text = args.cond_text
            while not raw_text:
                print('Prompt should not be empty!')
                raw_text = input("Model prompt >>> ")
            context_tokens = enc.encode(raw_text)
            for _ in range(args.nsamples // args.batch_size):
                out = sample_sequence(
                    model=model, length=args.length,
                    context=context_tokens,
                    start_token=None,
                    batch_size=args.batch_size,
                    temperature=args.temperature, top_k=args.top_k, device=device
                )

                # Keep the full sequence, including the prompt tokens.
                out = out[:, 0:].tolist()
                for i in range(args.batch_size):
                    generated += 1
                    text = enc.decode(out[i])
                    print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40)
                    print(text)
                    if args.output:
                        filepath = os.path.join(args.output, "generated_{}.txt".format(generated))
                        with open(filepath, "w") as f:
                            f.write(text)

        if args.unconditional:
            generated = 0
            for _ in range(args.nsamples // args.batch_size):
                out = sample_sequence(
                    model=model, length=args.length,
                    context=None,
                    start_token=enc.encoder['<|endoftext|>'],
                    batch_size=args.batch_size,
                    temperature=args.temperature, top_k=args.top_k, device=device
                )
                # Drop the leading <|endoftext|> start token before decoding.
                out = out[:, 1:].tolist()
                for i in range(args.batch_size):
                    generated += 1
                    text = enc.decode(out[i])
                    print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40)
                    print(text)

        if args.unconditional:
            break


if __name__ == '__main__':
    run_model()