Upload utils.py
utils.py
ADDED
@@ -0,0 +1,181 @@
# helpful functions for summary generation
# the code is a customized version of:
# https://github.com/SKRohit/Generating_Text_Summary_With_GPT2/blob/master/utils.py

import json
import math
import torch
import torch.nn.functional as F
from transformers import GPT2TokenizerFast
# from tqdm import tnrange
import nltk
# nltk.download('punkt')
from nltk.tokenize import sent_tokenize


def print_summary(context, gen_summary, gold_summary):
    """ Prints the input text, the generated summary and the gold summary. """
    print('input_text', end='\n\n')
    print(context, end='\n\n')
    print('generated_summary', end='\n\n')
    print(gen_summary, end='\n\n')
    print('golden_summary', end='\n\n')
    print(gold_summary, end='\n\n')


def add_special_tokens(tokenizer_path):
    """ Returns GPT2 tokenizer after adding separator and padding tokens """
    tokenizer = GPT2TokenizerFast.from_pretrained(tokenizer_path, pad_token='<|endoftext|>')
    special_tokens = {'sep_token': '<|sep|>'}
    tokenizer.add_special_tokens(special_tokens)
    return tokenizer

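# Illustrative sketch only (not part of the original pipeline): add_special_tokens is
# typically paired with resizing the model's embeddings so the new <|sep|> id gets a row.
# "gpt2" is used here as a stand-in for whatever checkpoint path the project passes in.
def _demo_add_special_tokens():
    from transformers import GPT2LMHeadModel
    tokenizer = add_special_tokens("gpt2")
    model = GPT2LMHeadModel.from_pretrained("gpt2")
    model.resize_token_embeddings(len(tokenizer))
    print(tokenizer.sep_token, tokenizer.sep_token_id)  # "<|sep|>", 50257
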
def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
        Args:
            logits: logits distribution shape (vocabulary size)
            top_k > 0: keep only top k tokens with highest probability (top-k filtering).
            top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
                Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
        From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    assert logits.dim() == 1  # batch size 1 for now - could be updated for more, but the code would be less clear
    top_k = min(top_k, logits.size(-1))  # safety check
    if top_k > 0:
        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        logits[indices_to_remove] = filter_value
    return logits

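# Minimal sanity-check sketch for top_k_top_p_filtering (not part of the original file):
# the toy logits below are made up for illustration. With top_k=2, only the two highest
# logits survive; everything else is pushed to -inf before sampling.
def _demo_top_k_top_p_filtering():
    toy_logits = torch.tensor([2.0, 1.0, 0.5, 0.1, -1.0])
    filtered = top_k_top_p_filtering(toy_logits.clone(), top_k=2, top_p=0.0)
    print(filtered)  # tensor([2., 1., -inf, -inf, -inf])
    probs = F.softmax(filtered, dim=-1)
    print(torch.multinomial(probs, num_samples=1))  # samples index 0 or 1
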
def sample_seq_fast(model, context, length, num_sentences, device, temperature=1, top_k=0, top_p=0.0, eos_stopping=False):
    """ Generates a sequence of tokens
        Args:
            model: gpt/gpt2 model
            context: tokenized text using gpt/gpt2 tokenizer
            length: maximum length of the generated sequence.
            num_sentences: number of sentences to generate before stopping early.
            device: torch.device object.
            temperature > 0: used to control the randomness of predictions by scaling the logits before applying softmax.
            top_k > 0: keep only top k tokens with highest probability (top-k filtering).
            top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
    """
    # Generates one sentence more than requested: it watches the sampled tokens and counts
    # sentence-ending punctuation (".", "!", "?") on the fly; after n+1 occurrences it stops,
    # and the caller keeps the first n sentences via sent_tokenize.
    sent_to_gen = num_sentences + 1

    context = torch.tensor(context, dtype=torch.long, device=device)
    context = context.unsqueeze(0)
    generated = context
    with torch.no_grad():
        for _ in range(length):
            inputs = {'input_ids': generated}
            # GPT-2's context window is 1024 tokens; check the sequence length, not the batch size.
            assert inputs['input_ids'].size(1) <= 1024
            outputs = model(**inputs)  # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet (cached hidden-states)
            next_token_logits = outputs[0][0, -1, :] / temperature
            filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
            next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
            if eos_stopping and next_token.item() == 50256:  # GPT-2's <|endoftext|> id
                break
            generated = torch.cat((generated, next_token.unsqueeze(0)), dim=1)
            if not eos_stopping and next_token.item() in [0, 13, 30]:  # GPT-2 ids of "!", "." and "?"
                sent_to_gen -= 1
                if not sent_to_gen:
                    break
    return generated


def generate_summary_fast(context_enc, sep_idx, tokenizer, model, num_sentences, temperature=1, top_k=50, top_p=0.5,
                          device=torch.device('cuda'), eos_stopping=False):
    """ Generates a summary for an already-encoded context and decodes it back to text. """

    # Generates one sentence more than requested, counts sentence-ending tokens on the go,
    # then keeps the first <num_sentences> sentences found by sent_tokenize.
    generated_text = sample_seq_fast(model, context_enc, 1024 - sep_idx, num_sentences, device, temperature, top_k, top_p, eos_stopping=eos_stopping)
    generated_text = generated_text[0, len(context_enc):].tolist()
    gen_summary = tokenizer.convert_ids_to_tokens(generated_text, skip_special_tokens=True)
    gen_summary = tokenizer.convert_tokens_to_string(gen_summary)

    # extract <num_sentences> sentences
    if not eos_stopping:
        gen_summary = gen_summary.replace("...", ".")
        try:
            gen_summary = " ".join(nltk.sent_tokenize(gen_summary)[:num_sentences])
        except Exception:
            pass
    return gen_summary


def generate_eval_file(data, data_type, tokenizer, model, save_dir, field, num_sentences=5,
                       max_summaries=0, temperature=1, top_k=50, top_p=0.5,
                       device=torch.device('cuda'), eval_step=True, eos_stopping=False, skip=0):
    """ Generates summaries for a dataset and writes them to a JSONL file, one {field: summary} object per line.
        Each sample in `data` is expected to provide 'input_ids' (article tokens, the separator token,
        then the gold summary tokens) and 'sum_idx' (the position of the separator token).
    """
    print(data_type)
    max_summaries = math.inf if max_summaries == "full" else max_summaries
    len_data = min(max_summaries, len(data))
    disp_len = "full" if max_summaries == math.inf else len_data
    if eos_stopping:
        save_file = save_dir + f"/{data_type}_{disp_len}_sent{num_sentences}_eos_topk{top_k}_topp{top_p}.jsonl"
    else:
        save_file = save_dir + f"/{data_type}_{disp_len}_sent{num_sentences}_topk{top_k}_topp{top_p}.jsonl"
    print(f"saving to: {save_file}")

    # Append when resuming from `skip`, otherwise start a fresh file.
    how_open = "a" if skip else "w+"
    with open(save_file, how_open) as output:
        for s in range(skip, len_data):
            if s % 100 == 0:
                print(s)
            sample = data[s]
            sep_idx = sample['sum_idx']
            context = sample['input_ids'][:sep_idx].tolist()
            gold_summary = sample['input_ids'][sep_idx + 1:][:100].tolist()
            # generating with the new, faster method
            gen_summary = generate_summary_fast(context, sep_idx, tokenizer, model, num_sentences,
                                                temperature=temperature, top_k=top_k, top_p=top_p,
                                                device=device, eos_stopping=eos_stopping)

            if not eval_step:
                print_summary(tokenizer.decode(context), gen_summary, tokenizer.decode(gold_summary))
            else:
                new_doc = {field: gen_summary}
                json.dump(new_doc, output, ensure_ascii=False)
                output.write("\n")


def generate_one_summary_fast(input_text, tokenizer, model, num_sentences=3,
                              temperature=1, top_k=50, top_p=0.5,
                              device=torch.device('cuda'), eos_stopping=False, sep_tok=True):
    """ Generates and prints a summary for a single raw input text. """
    context = tokenizer.encode(input_text)
    if sep_tok:
        context += [tokenizer.sep_token_id]

    gen_summary = generate_summary_fast(context, len(context), tokenizer, model, num_sentences,
                                        temperature=temperature, top_k=top_k, top_p=top_p, device=device,
                                        eos_stopping=eos_stopping)

    print_summary(tokenizer.decode(context), gen_summary, "Not Given")

    return gen_summary

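# Minimal usage sketch (not part of the original file): load a fine-tuned checkpoint and
# summarize one text. The checkpoint path and the example text below are placeholders,
# not values from this repository; the loading steps are assumptions about how these
# helpers are meant to be driven.
if __name__ == "__main__":
    from transformers import GPT2LMHeadModel

    checkpoint_dir = "path/to/finetuned-gpt2-summarizer"  # placeholder path
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    tokenizer = add_special_tokens(checkpoint_dir)
    model = GPT2LMHeadModel.from_pretrained(checkpoint_dir)
    model.resize_token_embeddings(len(tokenizer))  # account for the added <|sep|> token
    model.to(device)
    model.eval()

    summary = generate_one_summary_fast("Some article text to summarize ...", tokenizer, model,
                                        num_sentences=3, top_k=50, top_p=0.5, device=device)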