hajekad committed
Commit
ce0d2c6
1 Parent(s): b5c41b5

Upload utils.py

Files changed (1)
  1. utils.py +181 -0
utils.py ADDED
@@ -0,0 +1,181 @@
# helpful functions for summary generation
# the code is a customized version of:
# https://github.com/SKRohit/Generating_Text_Summary_With_GPT2/blob/master/utils.py

import json
import math
import torch
import torch.nn.functional as F
from transformers import GPT2TokenizerFast
# from tqdm import tnrange
import nltk
# nltk.download('punkt')
from nltk.tokenize import sent_tokenize

def print_summary(context, gen_summary, gold_summary):
    print('input_text', end='\n\n')
    print(context, end='\n\n')
    print("generated_summary", end='\n\n')
    print(gen_summary, end='\n\n')
    print('golden_summary', end='\n\n')
    print(gold_summary, end='\n\n')

def add_special_tokens(tokenizer_path):
    """ Returns GPT2 tokenizer after adding separator and padding tokens """
    tokenizer = GPT2TokenizerFast.from_pretrained(tokenizer_path, pad_token='<|endoftext|>')
    special_tokens = {'sep_token': '<|sep|>'}
    tokenizer.add_special_tokens(special_tokens)
    return tokenizer

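# Editorial usage sketch, not part of the original file: how add_special_tokens()
# is typically paired with a GPT-2 model. The "gpt2" checkpoint name is only a
# placeholder assumption, not the model this repo actually uses.
def _demo_add_special_tokens():
    from transformers import GPT2LMHeadModel
    tokenizer = add_special_tokens("gpt2")
    model = GPT2LMHeadModel.from_pretrained("gpt2")
    # the vocabulary grew by one token (<|sep|>), so the embedding matrix must grow too
    model.resize_token_embeddings(len(tokenizer))
    return tokenizer, model
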
def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
        Args:
            logits: logits distribution shape (vocabulary size)
            top_k > 0: keep only top k tokens with highest probability (top-k filtering).
            top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
                Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
        From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    assert logits.dim() == 1  # batch size 1 for now - could be updated for more but the code would be less clear
    top_k = min(top_k, logits.size(-1))  # Safety check
    if top_k > 0:
        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        logits[indices_to_remove] = filter_value
    return logits

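# Editorial sketch, not part of the original file: top_k_top_p_filtering() on a
# tiny hand-made distribution, just to show the masking behaviour.
def _demo_top_k_filtering():
    logits = torch.tensor([2.0, 1.0, 0.5, -1.0, -3.0])
    filtered = top_k_top_p_filtering(logits.clone(), top_k=2)
    # only the two largest logits survive; the rest are set to -inf,
    # so softmax gives roughly [0.73, 0.27, 0.0, 0.0, 0.0]
    return F.softmax(filtered, dim=-1)
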
def sample_seq_fast(model, context, length, num_sentences, device, temperature=1, top_k=0, top_p=0.0, eos_stopping=False):
    """ Generates a sequence of tokens
        Args:
            model: gpt/gpt2 model
            context: tokenized text using gpt/gpt2 tokenizer
            length: length of generated sequence.
            device: torch.device object.
            temperature > 0: used to control the randomness of predictions by scaling the logits before applying softmax.
            top_k > 0: keep only top k tokens with highest probability (top-k filtering).
            top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
    """

    # generates one sentence more than requested:
    # it watches each generated token and counts sentences on the fly;
    # after n+1 occurrences of ".!?" the first n sentences found by sent_tokenize are kept
    sent_to_gen = num_sentences + 1

    context = torch.tensor(context, dtype=torch.long, device=device)
    context = context.unsqueeze(0)
    generated = context
    with torch.no_grad():
        for _ in range(length):
            inputs = {'input_ids': generated}
            assert inputs["input_ids"].size(1) <= 1024  # the sequence must fit into GPT-2's 1024-token window
            outputs = model(**inputs)  # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet (cached hidden-states)
            next_token_logits = outputs[0][0, -1, :] / temperature
            filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
            next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
            # with eos_stopping, generation stops as soon as token id 0 is sampled
            if not next_token and eos_stopping:
                break
            generated = torch.cat((generated, next_token.unsqueeze(0)), dim=1)
            # token ids 1, 14 and 31 are counted as sentence-ending punctuation here
            if not eos_stopping and next_token in [1, 14, 31]:
                sent_to_gen -= 1
                if not sent_to_gen:
                    break
    return generated

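# Editorial sketch, not part of the original file: calling sample_seq_fast()
# directly on CPU. `model` and `tokenizer` are assumed to be a GPT-2 LM head
# model and the tokenizer returned by add_special_tokens(); the input text is
# made up.
def _demo_sample_seq_fast(model, tokenizer):
    context = tokenizer.encode("A short article to summarize.") + [tokenizer.sep_token_id]
    generated = sample_seq_fast(model, context, length=60, num_sentences=2,
                                device=torch.device('cpu'), top_k=10, top_p=0.5)
    # drop the context part and decode only the newly generated tokens
    return tokenizer.decode(generated[0, len(context):].tolist(), skip_special_tokens=True)
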
def generate_summary_fast(context_enc, sep_idx, tokenizer, model, num_sentences, temperature=1, top_k=50, top_p=0.5,
                          device=torch.device('cuda'), eos_stopping=False):

    # generates one sentence more than requested:
    # it watches each generated token and counts sentences on the fly;
    # after n+1 occurrences of ".!?" the first n sentences found by sent_tokenize are kept

    generated_text = sample_seq_fast(model, context_enc, 1024-sep_idx, num_sentences, device, temperature, top_k, top_p, eos_stopping=eos_stopping)
    generated_text = generated_text[0, len(context_enc):].tolist()
    gen_summary = tokenizer.convert_ids_to_tokens(generated_text, skip_special_tokens=True)
    gen_summary = tokenizer.convert_tokens_to_string(gen_summary)

    # extract <num_sentences> sentences
    if not eos_stopping:
        gen_summary = gen_summary.replace("...", ".")
        try:
            gen_summary = " ".join(nltk.sent_tokenize(gen_summary)[:num_sentences])
        except:
            # if sent_tokenize fails (e.g. the punkt data is missing), keep the untrimmed summary
            pass
    return gen_summary

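# Editorial sketch, not part of the original file: the sentence-trimming step of
# generate_summary_fast() in isolation (requires the nltk 'punkt' data). With
# num_sentences=2, a three-sentence generation is cut back to its first two.
def _demo_trim_sentences():
    text = "First sentence. Second one follows! A third, surplus sentence?"
    return " ".join(nltk.sent_tokenize(text)[:2])
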
def generate_eval_file(data, data_type, tokenizer, model, save_dir, field, num_sentences=5,
                       max_summaries=0, temperature=1, top_k=50, top_p=0.5,
                       device=torch.device('cuda'), eval_step=True, eos_stopping=False, skip=0):
    print(data_type)
    max_summaries = math.inf if max_summaries == "full" else max_summaries
    len_data = min(max_summaries, len(data))
    disp_len = "full" if max_summaries == math.inf else len_data
    if eos_stopping:
        save_file = save_dir + f"/{data_type}_{disp_len}_sent{num_sentences}_eos_topk{top_k}_topp{top_p}.jsonl"
    else:
        save_file = save_dir + f"/{data_type}_{disp_len}_sent{num_sentences}_topk{top_k}_topp{top_p}.jsonl"
    print(f"saving to: {save_file}")

    # append when resuming after <skip> already generated summaries, otherwise start a new file
    how_open = "a" if skip else "w+"
    with open(save_file, how_open) as output:
        for s in range(skip, len_data):
            if s % 100 == 0:
                print(s)
            sample = data[s]
            sep_idx = sample['sum_idx']
            context = sample['input_ids'][:sep_idx].tolist()
            gold_summary = sample['input_ids'][sep_idx+1:][:100].tolist()
            # generating with the new faster and better method
            gen_summary = generate_summary_fast(context, sep_idx, tokenizer, model, num_sentences,
                                                temperature=temperature, top_k=top_k, top_p=top_p,
                                                device=device, eos_stopping=eos_stopping)

            if not eval_step:
                print_summary(tokenizer.decode(context), gen_summary, tokenizer.decode(gold_summary))
            else:
                new_doc = {field: gen_summary}
                json.dump(new_doc, output, ensure_ascii=False)
                output.write("\n")

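# Editorial sketch, not part of the original file: generate_eval_file() expects
# each dataset item to look like the dict built below - token ids with the
# article before the separator (whose position is stored under 'sum_idx') and
# the gold summary after it. The texts are made up.
def _demo_eval_sample(tokenizer):
    ids = (tokenizer.encode("Article text.")
           + [tokenizer.sep_token_id]
           + tokenizer.encode("Gold summary."))
    return {'input_ids': torch.tensor(ids), 'sum_idx': ids.index(tokenizer.sep_token_id)}
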
def generate_one_summary_fast(input_text, tokenizer, model, num_sentences=3,
                              temperature=1, top_k=50, top_p=0.5,
                              device=torch.device('cuda'), eos_stopping=False, sep_tok=True):

    context = tokenizer.encode(input_text)
    context += [tokenizer.sep_token_id]

    gen_summary = generate_summary_fast(context, len(context), tokenizer, model, num_sentences,
                                        temperature=temperature, top_k=top_k, top_p=top_p, device=device,
                                        eos_stopping=eos_stopping)

    print_summary(tokenizer.decode(context), gen_summary, "Not Given")

    return gen_summary

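# Editorial end-to-end sketch, not part of the original file: tying the helpers
# together on CPU. "gpt2" is a placeholder checkpoint; a real run would load the
# fine-tuned summarization model that belongs to this repo.
if __name__ == "__main__":
    from transformers import GPT2LMHeadModel
    tok = add_special_tokens("gpt2")
    mdl = GPT2LMHeadModel.from_pretrained("gpt2")
    mdl.resize_token_embeddings(len(tok))
    mdl.eval()
    generate_one_summary_fast("Some input text to summarize.", tok, mdl,
                              num_sentences=2, device=torch.device('cpu'))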