hvlgo committed on
Commit 788bc4b · verified
1 Parent(s): 79303e2

Update ts_generation_mixin.py

Files changed (1)
  1. ts_generation_mixin.py +250 -251
ts_generation_mixin.py CHANGED
@@ -1,251 +1,250 @@
- import warnings
- from typing import Any, Dict, List, Optional, Union
- import torch
- from transformers import GenerationMixin, LogitsProcessorList, StoppingCriteriaList
- from transformers.generation import validate_stopping_criteria, EosTokenCriteria
- from transformers.generation.utils import GenerateNonBeamOutput, GenerateEncoderDecoderOutput, GenerateDecoderOnlyOutput
- from transformers.utils import ModelOutput
-
-
- class TSGenerationMixin(GenerationMixin):
-
-     def _greedy_search(
-         self,
-         input_ids: torch.Tensor,
-         logits_processor: Optional[LogitsProcessorList] = None,
-         stopping_criteria: Optional[StoppingCriteriaList] = None,
-         max_length: Optional[int] = None,
-         pad_token_id: Optional[int] = None,
-         eos_token_id: Optional[Union[int, List[int]]] = None,
-         output_attentions: Optional[bool] = None,
-         output_hidden_states: Optional[bool] = None,
-         output_scores: Optional[bool] = None,
-         output_logits: Optional[bool] = None,
-         return_dict_in_generate: Optional[bool] = None,
-         synced_gpus: bool = False,
-         streamer: Optional["BaseStreamer"] = None,
-         **model_kwargs,
-     ) -> Union[GenerateNonBeamOutput, torch.Tensor]:
-         input_ids_origin_device = input_ids.device
-         input_ids = input_ids.to(self.device)
-         if len(input_ids.shape) == 2:
-             batch_size, cur_len = input_ids.shape
-         else:
-             raise ValueError('Input shape must be: [batch_size, seq_len]')
-         # init values
-         logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
-         stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
-         if max_length is not None:
-             warnings.warn(
-                 "`max_length` is deprecated in this function, use"
-                 " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.",
-                 UserWarning,
-             )
-             stopping_criteria = validate_stopping_criteria(
-                 stopping_criteria, max_length)
-         pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
-         if eos_token_id is not None:
-             stopping_criteria.append(
-                 EosTokenCriteria(eos_token_id=eos_token_id))
-         else:
-             # remove when the method is totally private
-             # need to get `eos_token_id` and add stopping criteria, so that generation does not go forever
-             eos_token_id = [
-                 criteria.eos_token_id.tolist() for criteria in stopping_criteria if hasattr(criteria, "eos_token_id")
-             ]
-             eos_token_id = eos_token_id[0] if eos_token_id else None
-             if eos_token_id is None and self.generation_config.eos_token_id is not None:
-                 eos_token_id = self.generation_config.eos_token_id
-                 stopping_criteria.append(
-                     EosTokenCriteria(eos_token_id=eos_token_id))
-
-         if isinstance(eos_token_id, int):
-             eos_token_id = [eos_token_id]
-         output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
-         output_attentions = (
-             output_attentions if output_attentions is not None else self.generation_config.output_attentions
-         )
-         output_hidden_states = (
-             output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states
-         )
-         return_dict_in_generate = (
-             return_dict_in_generate
-             if return_dict_in_generate is not None
-             else self.generation_config.return_dict_in_generate
-         )
-
-         # init attention / hidden states / scores tuples
-         raw_logits = () if (return_dict_in_generate and output_logits) else None
-         scores = () if (return_dict_in_generate and output_scores) else None
-         decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
-         cross_attentions = () if (return_dict_in_generate and output_attentions) else None
-         decoder_hidden_states = () if (
-             return_dict_in_generate and output_hidden_states) else None
-
-         # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
-         if return_dict_in_generate and self.config.is_encoder_decoder:
-             encoder_attentions = model_kwargs["encoder_outputs"].get(
-                 "attentions") if output_attentions else None
-             encoder_hidden_states = (
-                 model_kwargs["encoder_outputs"].get(
-                     "hidden_states") if output_hidden_states else None
-             )
-
-         # keep track of which sequences are already finished
-         if "inputs_embeds" in model_kwargs:
-             cur_len = model_kwargs["inputs_embeds"].shape[1]
-         this_peer_finished = False
-         unfinished_sequences = torch.ones(
-             batch_size, dtype=torch.long, device=input_ids.device)
-         model_kwargs["cache_position"] = torch.arange(
-             cur_len, device=input_ids.device)
-         true_seq_len = input_ids.shape[1] // self.config.input_token_len
-         model_kwargs["attention_mask"] = model_kwargs["attention_mask"][:, -true_seq_len:]
-
-         max_length = stopping_criteria.max_length
-         while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
-             # prepare model inputs
-             model_inputs = self.prepare_inputs_for_generation(
-                 input_ids, **model_kwargs)
-
-             input_length = input_ids.shape[1]
-
-             # forward pass to get next token
-             outputs = self(
-                 **model_inputs,
-                 return_dict=True,
-                 output_attentions=output_attentions,
-                 output_hidden_states=output_hidden_states,
-                 max_output_length=max_length - input_length,
-             )
-
-             if synced_gpus and this_peer_finished:
-                 continue  # don't waste resources running the code we don't need
-
-             next_token_logits = outputs.logits[:, -1, :]
-
-             # pre-process distribution
-             next_tokens_scores = logits_processor(input_ids, next_token_logits)
-
-             # Store scores, attentions and hidden_states when required
-             if return_dict_in_generate:
-                 if output_scores:
-                     scores += (next_tokens_scores,)
-                 if output_logits:
-                     raw_logits += (next_token_logits,)
-                 if output_attentions:
-                     decoder_attentions += (
-                         (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (
-                             outputs.attentions,)
-                     )
-                     if self.config.is_encoder_decoder:
-                         cross_attentions += (outputs.cross_attentions,)
-
-                 if output_hidden_states:
-                     decoder_hidden_states += (
-                         (outputs.decoder_hidden_states,)
-                         if self.config.is_encoder_decoder
-                         else (outputs.hidden_states,)
-                     )
-
-             # argmax
-             # next_tokens = torch.argmax(next_tokens_scores, dim=-1)
-             next_tokens = next_tokens_scores
-
-             # finished sentences should have their next token be a padding token
-             if eos_token_id is not None:
-                 if pad_token_id is None:
-                     raise ValueError(
-                         "If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
-                 next_tokens = next_tokens * unfinished_sequences + \
-                     pad_token_id * (1 - unfinished_sequences)
-
-             # update generated ids, model inputs, and length for next step
-             horizon_length = next_tokens.shape[1] // self.config.input_token_len
-
-             input_ids = torch.cat([input_ids, next_tokens], dim=-1)
-             if streamer is not None:
-                 streamer.put(next_tokens.cpu())
-             model_kwargs = self._update_model_kwargs_for_generation(
-                 outputs,
-                 model_kwargs,
-                 horizon_length=horizon_length,
-                 is_encoder_decoder=self.config.is_encoder_decoder,
-             )
-             unfinished_sequences = unfinished_sequences & ~stopping_criteria(
-                 input_ids, scores)
-             this_peer_finished = unfinished_sequences.max() == 0
-
-         if input_ids.shape[1] > max_length:
-             input_ids = input_ids[:, :max_length]
-
-         if streamer is not None:
-             streamer.end()
-
-         if return_dict_in_generate:
-             if self.config.is_encoder_decoder:
-                 return GenerateEncoderDecoderOutput(
-                     sequences=input_ids,
-                     scores=scores,
-                     logits=raw_logits,
-                     encoder_attentions=encoder_attentions,
-                     encoder_hidden_states=encoder_hidden_states,
-                     decoder_attentions=decoder_attentions,
-                     cross_attentions=cross_attentions,
-                     decoder_hidden_states=decoder_hidden_states,
-                     past_key_values=model_kwargs.get("past_key_values"),
-                 )
-             else:
-                 return GenerateDecoderOnlyOutput(
-                     sequences=input_ids,
-                     scores=scores,
-                     logits=raw_logits,
-                     attentions=decoder_attentions,
-                     hidden_states=decoder_hidden_states,
-                     past_key_values=model_kwargs.get("past_key_values"),
-                 )
-         else:
-             return input_ids
-
-     def _update_model_kwargs_for_generation(
-         self,
-         outputs: ModelOutput,
-         model_kwargs: Dict[str, Any],
-         horizon_length: int = 1,
-         is_encoder_decoder: bool = False,
-         standardize_cache_format: bool = False,
-     ) -> Dict[str, Any]:
-         # update past_key_values
-         model_kwargs["past_key_values"] = self._extract_past_from_model_output(
-             outputs, standardize_cache_format=standardize_cache_format
-         )
-         if getattr(outputs, "state", None) is not None:
-             model_kwargs["state"] = outputs.state
-
-         # update token_type_ids with last value
-         if "token_type_ids" in model_kwargs:
-             token_type_ids = model_kwargs["token_type_ids"]
-             model_kwargs["token_type_ids"] = torch.cat(
-                 [token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1)
-
-         if not is_encoder_decoder:
-             # update attention mask
-             if "attention_mask" in model_kwargs:
-                 attention_mask = model_kwargs["attention_mask"]
-                 model_kwargs["attention_mask"] = torch.cat(
-                     [attention_mask, attention_mask.new_ones((attention_mask.shape[0], horizon_length))], dim=-1
-                 )
-         else:
-             # update decoder attention mask
-             if "decoder_attention_mask" in model_kwargs:
-                 decoder_attention_mask = model_kwargs["decoder_attention_mask"]
-                 model_kwargs["decoder_attention_mask"] = torch.cat(
-                     [decoder_attention_mask, decoder_attention_mask.new_ones(
-                         (decoder_attention_mask.shape[0], horizon_length))],
-                     dim=-1,
-                 )
-
-         if "cache_position" in model_kwargs and model_kwargs["cache_position"] is not None:
-             model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + horizon_length
-
-         return model_kwargs
 
+ class TSGenerationMixin(GenerationMixin):
+
+     def _greedy_search(
+         self,
+         input_ids: torch.Tensor,
+         logits_processor: Optional[LogitsProcessorList] = None,
+         stopping_criteria: Optional[StoppingCriteriaList] = None,
+         max_length: Optional[int] = None,
+         pad_token_id: Optional[int] = None,
+         eos_token_id: Optional[Union[int, List[int]]] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         output_scores: Optional[bool] = None,
+         output_logits: Optional[bool] = None,
+         return_dict_in_generate: Optional[bool] = None,
+         synced_gpus: bool = False,
+         streamer: Optional["BaseStreamer"] = None,
+         **model_kwargs,
+     ) -> Union[GenerateNonBeamOutput, torch.Tensor]:
+         input_ids = input_ids.to(self.device)
+         if len(input_ids.shape) == 2:
+             batch_size, cur_len = input_ids.shape
+             if cur_len < self.config.input_token_len:
+                 raise ValueError(
+                     f"Input length must be at least {self.config.input_token_len}")
+             elif cur_len % self.config.input_token_len != 0:
+                 new_len = (cur_len // self.config.input_token_len) * \
+                     self.config.input_token_len
+                 input_ids = input_ids[:, -new_len:]
+         else:
+             raise ValueError('Input shape must be: [batch_size, seq_len]')
+         initial_input_length = input_ids.shape[1]
+
+         # init values
+         logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
+         stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
+         if max_length is not None:
+             warnings.warn(
+                 "`max_length` is deprecated in this function, use"
+                 " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.",
+                 UserWarning,
+             )
+             stopping_criteria = validate_stopping_criteria(
+                 stopping_criteria, max_length)
+         pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
+         if eos_token_id is not None:
+             stopping_criteria.append(
+                 EosTokenCriteria(eos_token_id=eos_token_id))
+         else:
+             # remove when the method is totally private
+             # need to get `eos_token_id` and add stopping criteria, so that generation does not go forever
+             eos_token_id = [
+                 criteria.eos_token_id.tolist() for criteria in stopping_criteria if hasattr(criteria, "eos_token_id")
+             ]
+             eos_token_id = eos_token_id[0] if eos_token_id else None
+             if eos_token_id is None and self.generation_config.eos_token_id is not None:
+                 eos_token_id = self.generation_config.eos_token_id
+                 stopping_criteria.append(
+                     EosTokenCriteria(eos_token_id=eos_token_id))
+
+         if isinstance(eos_token_id, int):
+             eos_token_id = [eos_token_id]
+         output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
+         output_attentions = (
+             output_attentions if output_attentions is not None else self.generation_config.output_attentions
+         )
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states
+         )
+         return_dict_in_generate = (
+             return_dict_in_generate
+             if return_dict_in_generate is not None
+             else self.generation_config.return_dict_in_generate
+         )
+
+         # init attention / hidden states / scores tuples
+         raw_logits = () if (return_dict_in_generate and output_logits) else None
+         scores = () if (return_dict_in_generate and output_scores) else None
+         decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
+         cross_attentions = () if (return_dict_in_generate and output_attentions) else None
+         decoder_hidden_states = () if (
+             return_dict_in_generate and output_hidden_states) else None
+
+         # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
+         if return_dict_in_generate and self.config.is_encoder_decoder:
+             encoder_attentions = model_kwargs["encoder_outputs"].get(
+                 "attentions") if output_attentions else None
+             encoder_hidden_states = (
+                 model_kwargs["encoder_outputs"].get(
+                     "hidden_states") if output_hidden_states else None
+             )
+
+         # keep track of which sequences are already finished
+         if "inputs_embeds" in model_kwargs:
+             cur_len = model_kwargs["inputs_embeds"].shape[1]
+         this_peer_finished = False
+         unfinished_sequences = torch.ones(
+             batch_size, dtype=torch.long, device=input_ids.device)
+         model_kwargs["cache_position"] = torch.arange(
+             cur_len, device=input_ids.device)
+         true_seq_len = input_ids.shape[1] // self.config.input_token_len
+         model_kwargs["attention_mask"] = model_kwargs["attention_mask"][:, -true_seq_len:]
+
+         max_length = stopping_criteria.max_length
+         while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
+             # prepare model inputs
+             model_inputs = self.prepare_inputs_for_generation(
+                 input_ids, **model_kwargs)
+
+             input_length = input_ids.shape[1]
+
+             # forward pass to get next token
+             outputs = self(
+                 **model_inputs,
+                 return_dict=True,
+                 output_attentions=output_attentions,
+                 output_hidden_states=output_hidden_states,
+                 max_output_length=max_length - input_length,
+             )
+
+             if synced_gpus and this_peer_finished:
+                 continue  # don't waste resources running the code we don't need
+
+             next_token_logits = outputs.logits[:, -1, :]
+
+             # pre-process distribution
+             next_tokens_scores = logits_processor(input_ids, next_token_logits)
+
+             # Store scores, attentions and hidden_states when required
+             if return_dict_in_generate:
+                 if output_scores:
+                     scores += (next_tokens_scores,)
+                 if output_logits:
+                     raw_logits += (next_token_logits,)
+                 if output_attentions:
+                     decoder_attentions += (
+                         (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (
+                             outputs.attentions,)
+                     )
+                     if self.config.is_encoder_decoder:
+                         cross_attentions += (outputs.cross_attentions,)
+
+                 if output_hidden_states:
+                     decoder_hidden_states += (
+                         (outputs.decoder_hidden_states,)
+                         if self.config.is_encoder_decoder
+                         else (outputs.hidden_states,)
+                     )
+
+             # argmax
+             # next_tokens = torch.argmax(next_tokens_scores, dim=-1)
+             next_tokens = next_tokens_scores
+
+             # finished sentences should have their next token be a padding token
+             if eos_token_id is not None:
+                 if pad_token_id is None:
+                     raise ValueError(
+                         "If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
+                 next_tokens = next_tokens * unfinished_sequences + \
+                     pad_token_id * (1 - unfinished_sequences)
+
+             # update generated ids, model inputs, and length for next step
+             horizon_length = next_tokens.shape[1] // self.config.input_token_len
+
+             input_ids = torch.cat([input_ids, next_tokens], dim=-1)
+             if streamer is not None:
+                 streamer.put(next_tokens.cpu())
+             model_kwargs = self._update_model_kwargs_for_generation(
+                 outputs,
+                 model_kwargs,
+                 horizon_length=horizon_length,
+                 is_encoder_decoder=self.config.is_encoder_decoder,
+             )
+             unfinished_sequences = unfinished_sequences & ~stopping_criteria(
+                 input_ids, scores)
+             this_peer_finished = unfinished_sequences.max() == 0
+
+         if input_ids.shape[1] > max_length:
+             input_ids = input_ids[:, :max_length]
+
+         if streamer is not None:
+             streamer.end()
+
+         if return_dict_in_generate:
+             if self.config.is_encoder_decoder:
+                 return GenerateEncoderDecoderOutput(
+                     sequences=input_ids,
+                     scores=scores,
+                     logits=raw_logits,
+                     encoder_attentions=encoder_attentions,
+                     encoder_hidden_states=encoder_hidden_states,
+                     decoder_attentions=decoder_attentions,
+                     cross_attentions=cross_attentions,
+                     decoder_hidden_states=decoder_hidden_states,
+                     past_key_values=model_kwargs.get("past_key_values"),
+                 )
+             else:
+                 return GenerateDecoderOnlyOutput(
+                     sequences=input_ids,
+                     scores=scores,
+                     logits=raw_logits,
+                     attentions=decoder_attentions,
+                     hidden_states=decoder_hidden_states,
+                     past_key_values=model_kwargs.get("past_key_values"),
+                 )
+         else:
+             return input_ids[:, -(max_length - initial_input_length):]
+
+     def _update_model_kwargs_for_generation(
+         self,
+         outputs: ModelOutput,
+         model_kwargs: Dict[str, Any],
+         horizon_length: int = 1,
+         is_encoder_decoder: bool = False,
+         standardize_cache_format: bool = False,
+     ) -> Dict[str, Any]:
+         # update past_key_values
+         model_kwargs["past_key_values"] = self._extract_past_from_model_output(
+             outputs, standardize_cache_format=standardize_cache_format
+         )
+         if getattr(outputs, "state", None) is not None:
+             model_kwargs["state"] = outputs.state
+
+         # update token_type_ids with last value
+         if "token_type_ids" in model_kwargs:
+             token_type_ids = model_kwargs["token_type_ids"]
+             model_kwargs["token_type_ids"] = torch.cat(
+                 [token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1)
+
+         if not is_encoder_decoder:
+             # update attention mask
+             if "attention_mask" in model_kwargs:
+                 attention_mask = model_kwargs["attention_mask"]
+                 model_kwargs["attention_mask"] = torch.cat(
+                     [attention_mask, attention_mask.new_ones((attention_mask.shape[0], horizon_length))], dim=-1
+                 )
+         else:
+             # update decoder attention mask
+             if "decoder_attention_mask" in model_kwargs:
+                 decoder_attention_mask = model_kwargs["decoder_attention_mask"]
+                 model_kwargs["decoder_attention_mask"] = torch.cat(
+                     [decoder_attention_mask, decoder_attention_mask.new_ones(
+                         (decoder_attention_mask.shape[0], horizon_length))],
+                     dim=-1,
+                 )
+
+         if "cache_position" in model_kwargs and model_kwargs["cache_position"] is not None:
+             model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + horizon_length
+
+         return model_kwargs
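
For reference, the change to `_greedy_search` in this commit does two things: the context is first truncated to the most recent whole multiple of `config.input_token_len` (and rejected if it is shorter than one token), and the tensor returned to the caller is now only the newly generated horizon, `input_ids[:, -(max_length - initial_input_length):]`, rather than the trimmed context concatenated with the forecast. A minimal usage sketch of the new behavior, assuming a checkpoint that ships this mixin via `trust_remote_code` and is loadable with `AutoModelForCausalLM`; the checkpoint name, the context length of 289, and the prediction length of 96 below are illustrative assumptions, not taken from this commit:

import torch
from transformers import AutoModelForCausalLM

# Illustrative checkpoint; any model whose remote code includes this
# TSGenerationMixin is expected to behave the same way.
model = AutoModelForCausalLM.from_pretrained(
    "thuml/timer-base-84m", trust_remote_code=True)

token_len = model.config.input_token_len   # time-series points per "token"
context = torch.randn(1, 289)              # deliberately not a multiple of token_len

# With this commit, _greedy_search keeps only the most recent
# (289 // token_len) * token_len points and raises if 289 < token_len,
# instead of requiring the caller to align the context length manually.
forecast = model.generate(context, max_new_tokens=96)

# The result now holds only the generated points; before this commit the
# returned tensor also included the (trimmed) input context.
print(forecast.shape)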