nguyenvulebinh commited on
Commit
b7153a2
1 Parent(s): 2cf20d8

Upload 2 files

Browse files
Files changed (2) hide show
  1. configuration_deltalm.py +99 -0
  2. modeling_deltalm.py +1551 -0
configuration_deltalm.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+
4
+ """ deltalm model configuration"""
5
+
6
+ import warnings
7
+ from transformers.configuration_utils import PretrainedConfig
8
+ from transformers.utils import logging
9
+ logger = logging.get_logger(__name__)
10
+
11
+ class DeltalmConfig(PretrainedConfig):
12
+
13
+ model_type = "Deltalm"
14
+ keys_to_ignore_at_inference = ["past_key_values"]
15
+ attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"}
16
+
17
+ def __init__(
18
+ self,
19
+ vocab_size=250001,
20
+ max_position_embeddings=1024,
21
+ encoder_layers=12,
22
+ encoder_ffn_dim=3072,
23
+ encoder_attention_heads=12,
24
+ decoder_layers=6,
25
+ decoder_ffn_dim=3072,
26
+ decoder_attention_heads=12,
27
+ encoder_layerdrop=0.0,
28
+ decoder_layerdrop=0.0,
29
+ activation_function="gelu",
30
+ d_model=1024,
31
+ dropout=0.1,
32
+ attention_dropout=0.0,
33
+ activation_dropout=0.0,
34
+ init_std=0.02,
35
+ classifier_dropout=0.0,
36
+ scale_embedding=False,
37
+ use_cache=True,
38
+ num_labels=3,
39
+ pad_token_id=1,
40
+ bos_token_id=0,
41
+ eos_token_id=2,
42
+ is_encoder_decoder=True,
43
+ decoder_start_token_id=0,
44
+ forced_eos_token_id=2,
45
+ label_smoothing=0.1,
46
+ length_penalty=1.0,
47
+ encoder_normalize_before=False,
48
+ **kwargs
49
+ ):
50
+ self.vocab_size = vocab_size
51
+ self.max_position_embeddings = max_position_embeddings
52
+ self.d_model = d_model
53
+ self.encoder_ffn_dim = encoder_ffn_dim
54
+ self.encoder_layers = encoder_layers
55
+ self.encoder_attention_heads = encoder_attention_heads
56
+ self.decoder_ffn_dim = decoder_ffn_dim
57
+ self.decoder_layers = decoder_layers
58
+ self.decoder_attention_heads = decoder_attention_heads
59
+ self.dropout = dropout
60
+ self.attention_dropout = attention_dropout
61
+ self.activation_dropout = activation_dropout
62
+ self.activation_function = activation_function
63
+ self.init_std = init_std
64
+ self.encoder_layerdrop = encoder_layerdrop
65
+ self.decoder_layerdrop = decoder_layerdrop
66
+ self.classifier_dropout = classifier_dropout
67
+ self.use_cache = use_cache
68
+ self.num_hidden_layers = encoder_layers
69
+ self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
70
+ self.label_smoothing = label_smoothing
71
+ self.encoder_normalize_before = encoder_normalize_before
72
+
73
+ super().__init__(
74
+ num_labels=num_labels,
75
+ pad_token_id=pad_token_id,
76
+ bos_token_id=bos_token_id,
77
+ eos_token_id=eos_token_id,
78
+ is_encoder_decoder=is_encoder_decoder,
79
+ decoder_start_token_id=decoder_start_token_id,
80
+ forced_eos_token_id=forced_eos_token_id,
81
+ length_penalty=length_penalty,
82
+ **kwargs,
83
+ )
84
+
85
+ # ensure backward compatibility for BART CNN models
86
+ if self.forced_bos_token_id is None and kwargs.get("force_bos_token_to_be_generated", False):
87
+ self.forced_bos_token_id = self.bos_token_id
88
+ warnings.warn(
89
+ f"Please make sure the config includes `forced_bos_token_id={self.bos_token_id}` in future versions. "
90
+ "The config can simply be saved and uploaded again to be fixed."
91
+ )
92
+
93
+ @property
94
+ def num_attention_heads(self) -> int:
95
+ return self.encoder_attention_heads
96
+
97
+ @property
98
+ def hidden_size(self) -> int:
99
+ return self.d_model
modeling_deltalm.py ADDED
@@ -0,0 +1,1551 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+
4
+
5
+ import copy
6
+ import math
7
+ import random
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.utils.checkpoint
11
+ from torch.nn import CrossEntropyLoss
12
+ from typing import List, Optional, Tuple, Union
13
+
14
+ from transformers.modeling_utils import PreTrainedModel
15
+ from transformers.activations import ACT2FN
16
+ from transformers.modeling_outputs import (
17
+ BaseModelOutput,
18
+ BaseModelOutputWithPastAndCrossAttentions,
19
+ CausalLMOutputWithCrossAttentions,
20
+ Seq2SeqModelOutput,
21
+ Seq2SeqLMOutput,
22
+ )
23
+ from transformers.file_utils import (
24
+ add_end_docstrings,
25
+ add_start_docstrings,
26
+ add_start_docstrings_to_model_forward,
27
+ replace_return_docstrings,
28
+ )
29
+
30
+ import logging
31
+ from .configuration_deltalm import DeltalmConfig
32
+ logger = logging.getLogger(__name__)
33
+
34
+ _CHECKPOINT_FOR_DOC = "IDEA-CCNL/Randeng-Deltalm-362M-En-Zn"
35
+ _CONFIG_FOR_DOC = "DeltalmConfig"
36
+ _TOKENIZER_FOR_DOC = "DeltalmTokenizer"
37
+
38
+ # Base model docstring
39
+ _EXPECTED_OUTPUT_SHAPE = [1, 8, 768]
40
+
41
+
42
+ def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
43
+ """
44
+ Shift input ids one token to the right.
45
+ """
46
+ shifted_input_ids = input_ids.new_zeros(input_ids.shape)
47
+ shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
48
+ shifted_input_ids[:, 0] = decoder_start_token_id
49
+
50
+ if pad_token_id is None:
51
+ raise ValueError("self.model.config.pad_token_id has to be defined.")
52
+ # replace possible -100 values in labels by `pad_token_id`
53
+ shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
54
+
55
+ return shifted_input_ids
56
+
57
+
58
+ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0):
59
+ """
60
+ Make causal mask used for bi-directional self-attention.
61
+ """
62
+ bsz, tgt_len = input_ids_shape
63
+ mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf")))
64
+ mask_cond = torch.arange(mask.size(-1))
65
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
66
+ mask = mask.to(dtype)
67
+
68
+ if past_key_values_length > 0:
69
+ mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], dim=-1)
70
+ return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
71
+
72
+
73
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
74
+ """
75
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
76
+ """
77
+ bsz, src_len = mask.size()
78
+ tgt_len = tgt_len if tgt_len is not None else src_len
79
+
80
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
81
+
82
+ inverted_mask = 1.0 - expanded_mask
83
+
84
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
85
+
86
+
87
+ class DeltalmLearnedPositionalEmbedding(nn.Embedding):
88
+ """
89
+ This module learns positional embeddings up to a fixed maximum size.
90
+ """
91
+
92
+ def __init__(self, num_embeddings: int, embedding_dim: int):
93
+ # Deltalm is set up so that if padding_idx is specified then offset the embedding ids by 2
94
+ # and adjust num_embeddings appropriately. Other models don't have this hack
95
+ self.offset = 2
96
+ super().__init__(num_embeddings + self.offset, embedding_dim)
97
+
98
+ def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0):
99
+ """`input_ids_shape` is expected to be [bsz x seqlen]."""
100
+ bsz, seq_len = input_ids_shape[:2]
101
+ positions = torch.arange(
102
+ past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
103
+ )
104
+ return super().forward(positions + self.offset)
105
+
106
+
107
+ class DeltalmAttention(nn.Module):
108
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
109
+
110
+ def __init__(
111
+ self,
112
+ embed_dim: int,
113
+ num_heads: int,
114
+ dropout: float = 0.0,
115
+ is_decoder: bool = False,
116
+ bias: bool = True,
117
+ ):
118
+ super().__init__()
119
+ self.embed_dim = embed_dim
120
+ self.num_heads = num_heads
121
+ self.dropout = dropout
122
+ self.head_dim = embed_dim // num_heads
123
+
124
+ if (self.head_dim * num_heads) != self.embed_dim:
125
+ raise ValueError(
126
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
127
+ f" and `num_heads`: {num_heads})."
128
+ )
129
+ self.scaling = self.head_dim**-0.5
130
+ self.is_decoder = is_decoder
131
+
132
+ self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
133
+ self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
134
+ self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
135
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
136
+
137
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
138
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
139
+
140
+ def forward(
141
+ self,
142
+ hidden_states: torch.Tensor,
143
+ key_value_states: Optional[torch.Tensor] = None,
144
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
145
+ attention_mask: Optional[torch.Tensor] = None,
146
+ layer_head_mask: Optional[torch.Tensor] = None,
147
+ output_attentions: bool = False,
148
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
149
+ """Input shape: Batch x Time x Channel"""
150
+
151
+ # if key_value_states are provided this layer is used as a cross-attention layer
152
+ # for the decoder
153
+ is_cross_attention = key_value_states is not None
154
+
155
+ bsz, tgt_len, _ = hidden_states.size()
156
+
157
+ # get query proj
158
+ query_states = self.q_proj(hidden_states) * self.scaling
159
+ # get key, value proj
160
+ if is_cross_attention and past_key_value is not None:
161
+ # reuse k,v, cross_attentions
162
+ key_states = past_key_value[0]
163
+ value_states = past_key_value[1]
164
+ elif is_cross_attention:
165
+ # cross_attentions
166
+ key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
167
+ value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
168
+ elif past_key_value is not None:
169
+ # reuse k, v, self_attention
170
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
171
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
172
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
173
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
174
+ else:
175
+ # self_attention
176
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
177
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
178
+
179
+ if self.is_decoder:
180
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
181
+ # Further calls to cross_attention layer can then reuse all cross-attention
182
+ # key/value_states (first "if" case)
183
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
184
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
185
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
186
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
187
+ past_key_value = (key_states, value_states)
188
+
189
+ proj_shape = (bsz * self.num_heads, -1, self.head_dim)
190
+ query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
191
+ key_states = key_states.view(*proj_shape)
192
+ value_states = value_states.view(*proj_shape)
193
+
194
+ src_len = key_states.size(1)
195
+ attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
196
+
197
+ if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
198
+ raise ValueError(
199
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
200
+ f" {attn_weights.size()}"
201
+ )
202
+
203
+ if attention_mask is not None:
204
+ if attention_mask.size() != (bsz, 1, tgt_len, src_len):
205
+ raise ValueError(
206
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
207
+ )
208
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
209
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
210
+
211
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
212
+
213
+ if layer_head_mask is not None:
214
+ if layer_head_mask.size() != (self.num_heads,):
215
+ raise ValueError(
216
+ f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
217
+ f" {layer_head_mask.size()}"
218
+ )
219
+ attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
220
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
221
+
222
+ if output_attentions:
223
+ # this operation is a bit awkward, but it's required to
224
+ # make sure that attn_weights keeps its gradient.
225
+ # In order to do so, attn_weights have to be reshaped
226
+ # twice and have to be reused in the following
227
+ attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
228
+ attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
229
+ else:
230
+ attn_weights_reshaped = None
231
+
232
+ attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
233
+
234
+ attn_output = torch.bmm(attn_probs, value_states)
235
+
236
+ if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
237
+ raise ValueError(
238
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
239
+ f" {attn_output.size()}"
240
+ )
241
+
242
+ attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
243
+ attn_output = attn_output.transpose(1, 2)
244
+
245
+ # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
246
+ # partitioned aross GPUs when using tensor-parallelism.
247
+ attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
248
+
249
+ attn_output = self.out_proj(attn_output)
250
+
251
+ return attn_output, attn_weights_reshaped, past_key_value
252
+
253
+
254
+ class DeltalmEncoderLayer(nn.Module):
255
+ def __init__(self, config: DeltalmConfig):
256
+ super().__init__()
257
+ self.embed_dim = config.d_model
258
+ self.self_attn = DeltalmAttention(
259
+ embed_dim=self.embed_dim,
260
+ num_heads=config.encoder_attention_heads,
261
+ dropout=config.attention_dropout,
262
+ )
263
+ self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
264
+ self.dropout = config.dropout
265
+ self.activation_fn = ACT2FN[config.activation_function]
266
+ self.activation_dropout = config.activation_dropout
267
+ self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
268
+ self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
269
+ self.final_layer_norm = nn.LayerNorm(self.embed_dim)
270
+
271
+ def forward(
272
+ self,
273
+ hidden_states: torch.FloatTensor,
274
+ attention_mask: torch.FloatTensor,
275
+ layer_head_mask: torch.FloatTensor,
276
+ output_attentions: Optional[bool] = False,
277
+ ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]:
278
+ """
279
+ Args:
280
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
281
+ attention_mask (`torch.FloatTensor`): attention mask of size
282
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
283
+ layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
284
+ `(encoder_attention_heads,)`.
285
+ output_attentions (`bool`, *optional*):
286
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
287
+ returned tensors for more detail.
288
+ """
289
+ residual = hidden_states
290
+ hidden_states, attn_weights, _ = self.self_attn(
291
+ hidden_states=hidden_states,
292
+ attention_mask=attention_mask,
293
+ layer_head_mask=layer_head_mask,
294
+ output_attentions=output_attentions,
295
+ )
296
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
297
+ hidden_states = residual + hidden_states
298
+ hidden_states = self.self_attn_layer_norm(hidden_states)
299
+
300
+ residual = hidden_states
301
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
302
+ hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
303
+ hidden_states = self.fc2(hidden_states)
304
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
305
+ hidden_states = residual + hidden_states
306
+ hidden_states = self.final_layer_norm(hidden_states)
307
+
308
+ if hidden_states.dtype == torch.float16 and (
309
+ torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
310
+ ):
311
+ clamp_value = torch.finfo(hidden_states.dtype).max - 1000
312
+ hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
313
+
314
+ outputs = (hidden_states,)
315
+
316
+ if output_attentions:
317
+ outputs += (attn_weights,)
318
+
319
+ return outputs
320
+
321
+
322
+ class DeltalmDecoderLayer(nn.Module):
323
+ def __init__(self, config: DeltalmConfig):
324
+ super().__init__()
325
+ self.embed_dim = config.d_model
326
+
327
+ self.self_attn = DeltalmAttention(
328
+ embed_dim=self.embed_dim,
329
+ num_heads=config.decoder_attention_heads,
330
+ dropout=config.attention_dropout,
331
+ is_decoder=True,
332
+ )
333
+ self.dropout = config.dropout
334
+ self.activation_fn = ACT2FN[config.activation_function]
335
+ self.activation_dropout = config.activation_dropout
336
+
337
+ self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
338
+ self.encoder_attn = DeltalmAttention(
339
+ self.embed_dim,
340
+ config.decoder_attention_heads,
341
+ dropout=config.attention_dropout,
342
+ is_decoder=True,
343
+ )
344
+ self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
345
+ self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
346
+ self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
347
+ self.fc3 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
348
+ self.fc4 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
349
+
350
+ self.ffn_layer_norm = nn.LayerNorm(self.embed_dim)
351
+ self.final_layer_norm = nn.LayerNorm(self.embed_dim)
352
+
353
+ def forward(
354
+ self,
355
+ hidden_states: torch.Tensor,
356
+ attention_mask: Optional[torch.Tensor] = None,
357
+ encoder_hidden_states: Optional[torch.Tensor] = None,
358
+ encoder_attention_mask: Optional[torch.Tensor] = None,
359
+ layer_head_mask: Optional[torch.Tensor] = None,
360
+ cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
361
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
362
+ output_attentions: Optional[bool] = False,
363
+ use_cache: Optional[bool] = True,
364
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
365
+ """
366
+ Args:
367
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
368
+ attention_mask (`torch.FloatTensor`): attention mask of size
369
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
370
+ encoder_hidden_states (`torch.FloatTensor`):
371
+ cross attention input to the layer of shape `(batch, seq_len, embed_dim)`
372
+ encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
373
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
374
+ layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
375
+ `(encoder_attention_heads,)`.
376
+ cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of
377
+ size `(decoder_attention_heads,)`.
378
+ past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states
379
+ output_attentions (`bool`, *optional*):
380
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
381
+ returned tensors for more detail.
382
+ """
383
+ residual = hidden_states
384
+
385
+ # Self Attention
386
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
387
+ self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
388
+ # add present self-attn cache to positions 1,2 of present_key_value tuple
389
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
390
+ hidden_states=hidden_states,
391
+ past_key_value=self_attn_past_key_value,
392
+ attention_mask=attention_mask,
393
+ layer_head_mask=layer_head_mask,
394
+ output_attentions=output_attentions,
395
+ )
396
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
397
+ hidden_states = residual + hidden_states
398
+ hidden_states = self.self_attn_layer_norm(hidden_states)
399
+
400
+ # Add another ffn after self-attention to keep the structure same to encoder-layer
401
+ residual = hidden_states
402
+ hidden_states = self.activation_fn(self.fc3(hidden_states))
403
+ hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
404
+ hidden_states = self.fc4(hidden_states)
405
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
406
+ hidden_states = residual + hidden_states
407
+ hidden_states = self.ffn_layer_norm(hidden_states)
408
+
409
+ # Cross-Attention Block
410
+ cross_attn_present_key_value = None
411
+ cross_attn_weights = None
412
+ if encoder_hidden_states is not None:
413
+ residual = hidden_states
414
+
415
+ # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
416
+ cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
417
+ hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
418
+ hidden_states=hidden_states,
419
+ key_value_states=encoder_hidden_states,
420
+ attention_mask=encoder_attention_mask,
421
+ layer_head_mask=cross_attn_layer_head_mask,
422
+ past_key_value=cross_attn_past_key_value,
423
+ output_attentions=output_attentions,
424
+ )
425
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
426
+ hidden_states = residual + hidden_states
427
+ hidden_states = self.encoder_attn_layer_norm(hidden_states)
428
+
429
+ # add cross-attn to positions 3,4 of present_key_value tuple
430
+ present_key_value = present_key_value + cross_attn_present_key_value
431
+
432
+ # Fully Connected
433
+ residual = hidden_states
434
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
435
+ hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
436
+ hidden_states = self.fc2(hidden_states)
437
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
438
+ hidden_states = residual + hidden_states
439
+ hidden_states = self.final_layer_norm(hidden_states)
440
+
441
+ outputs = (hidden_states,)
442
+
443
+ if output_attentions:
444
+ outputs += (self_attn_weights, cross_attn_weights)
445
+
446
+ if use_cache:
447
+ outputs += (present_key_value,)
448
+
449
+ return outputs
450
+
451
+
452
+ class DeltalmPretrainedModel(PreTrainedModel):
453
+ config_class = DeltalmConfig
454
+ base_model_prefix = "model"
455
+ supports_gradient_checkpointing = True
456
+
457
+ def _init_weights(self, module):
458
+ std = self.config.init_std
459
+ if isinstance(module, nn.Linear):
460
+ module.weight.data.normal_(mean=0.0, std=std)
461
+ if module.bias is not None:
462
+ module.bias.data.zero_()
463
+ elif isinstance(module, nn.Embedding):
464
+ module.weight.data.normal_(mean=0.0, std=std)
465
+ if module.padding_idx is not None:
466
+ module.weight.data[module.padding_idx].zero_()
467
+
468
+ def _set_gradient_checkpointing(self, module, value=False):
469
+ if isinstance(module, (DeltalmDecoder, DeltalmEncoder)):
470
+ module.gradient_checkpointing = value
471
+
472
+
473
+ class DeltalmDecoder(DeltalmPretrainedModel):
474
+ """
475
+ Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`DeltalmDecoderLayer`]
476
+ Args:
477
+ config: DeltalmConfig
478
+ embed_tokens (nn.Embedding): output embedding
479
+ """
480
+
481
+ def __init__(self, config: DeltalmConfig, embed_tokens: Optional[nn.Embedding] = None):
482
+ super().__init__(config)
483
+ self.dropout = config.dropout
484
+ self.layerdrop = config.decoder_layerdrop
485
+ self.padding_idx = config.pad_token_id
486
+ self.max_target_positions = config.max_position_embeddings
487
+ self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
488
+
489
+ if embed_tokens is not None:
490
+ self.embed_tokens = embed_tokens
491
+ else:
492
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
493
+
494
+ self.embed_positions = DeltalmLearnedPositionalEmbedding(
495
+ config.max_position_embeddings,
496
+ config.d_model,
497
+ )
498
+ self.layers = nn.ModuleList([DeltalmDecoderLayer(config) for _ in range(config.decoder_layers)])
499
+ self.layernorm_embedding = nn.LayerNorm(config.d_model)
500
+
501
+ self.gradient_checkpointing = False
502
+ # Initialize weights and apply final processing
503
+ self.post_init()
504
+
505
+ # fairseq实现了一个 nn.init.normal_(self.output_projection.weight, mean=0, std=self.output_embed_dim ** -0.5) 对最后的output权重做正态分布转换?
506
+
507
+ def get_input_embeddings(self):
508
+ return self.embed_tokens
509
+
510
+ def set_input_embeddings(self, value):
511
+ self.embed_tokens = value
512
+
513
+ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
514
+ # create causal mask
515
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
516
+ combined_attention_mask = None
517
+ if input_shape[-1] > 1:
518
+ combined_attention_mask = _make_causal_mask(
519
+ input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
520
+ ).to(inputs_embeds.device)
521
+
522
+ if attention_mask is not None:
523
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
524
+ expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
525
+ combined_attention_mask = (
526
+ expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
527
+ )
528
+
529
+ return combined_attention_mask
530
+
531
+ def forward(
532
+ self,
533
+ input_ids: torch.LongTensor = None,
534
+ attention_mask: Optional[torch.Tensor] = None,
535
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
536
+ encoder_attention_mask: Optional[torch.LongTensor] = None,
537
+ head_mask: Optional[torch.Tensor] = None,
538
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
539
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
540
+ inputs_embeds: Optional[torch.FloatTensor] = None,
541
+ use_cache: Optional[bool] = None,
542
+ output_attentions: Optional[bool] = None,
543
+ output_hidden_states: Optional[bool] = None,
544
+ return_dict: Optional[bool] = None,
545
+ ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
546
+ r"""
547
+ Args:
548
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
549
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
550
+ provide it.
551
+ Indices can be obtained using [`DeltalmTokenizer`]. See [`PreTrainedTokenizer.encode`] and
552
+ [`PreTrainedTokenizer.__call__`] for details.
553
+ [What are input IDs?](../glossary#input-ids)
554
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
555
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
556
+ - 1 for tokens that are **not masked**,
557
+ - 0 for tokens that are **masked**.
558
+ [What are attention masks?](../glossary#attention-mask)
559
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
560
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
561
+ of the decoder.
562
+ encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
563
+ Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values
564
+ selected in `[0, 1]`:
565
+ - 1 for tokens that are **not masked**,
566
+ - 0 for tokens that are **masked**.
567
+ [What are attention masks?](../glossary#attention-mask)
568
+ head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
569
+ Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
570
+ - 1 indicates the head is **not masked**,
571
+ - 0 indicates the head is **masked**.
572
+ cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
573
+ Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing
574
+ cross-attention on hidden heads. Mask values selected in `[0, 1]`:
575
+ - 1 indicates the head is **not masked**,
576
+ - 0 indicates the head is **masked**.
577
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
578
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
579
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
580
+ shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
581
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
582
+ cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
583
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
584
+ that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
585
+ all `decoder_input_ids` of shape `(batch_size, sequence_length)`. inputs_embeds (`torch.FloatTensor` of
586
+ shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing
587
+ `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more
588
+ control over how to convert `input_ids` indices into associated vectors than the model's internal
589
+ embedding lookup matrix.
590
+ output_attentions (`bool`, *optional*):
591
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
592
+ returned tensors for more detail.
593
+ output_hidden_states (`bool`, *optional*):
594
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
595
+ for more detail.
596
+ return_dict (`bool`, *optional*):
597
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
598
+ """
599
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
600
+ output_hidden_states = (
601
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
602
+ )
603
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
604
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
605
+
606
+ # retrieve input_ids and inputs_embeds
607
+ if input_ids is not None and inputs_embeds is not None:
608
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
609
+ elif input_ids is not None:
610
+ input_shape = input_ids.size()
611
+ input_ids = input_ids.view(-1, input_shape[-1])
612
+ elif inputs_embeds is not None:
613
+ input_shape = inputs_embeds.size()[:-1]
614
+ else:
615
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
616
+
617
+ # past_key_values_length
618
+ past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
619
+
620
+ if inputs_embeds is None:
621
+ inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
622
+
623
+ attention_mask = self._prepare_decoder_attention_mask(
624
+ attention_mask, input_shape, inputs_embeds, past_key_values_length
625
+ )
626
+
627
+ # expand encoder attention mask
628
+ if encoder_hidden_states is not None and encoder_attention_mask is not None:
629
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
630
+ encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
631
+
632
+ # embed positions
633
+ positions = self.embed_positions(input_shape, past_key_values_length)
634
+
635
+ hidden_states = inputs_embeds + positions
636
+ hidden_states = self.layernorm_embedding(hidden_states)
637
+
638
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
639
+
640
+ # decoder layers
641
+ all_hidden_states = () if output_hidden_states else None
642
+ all_self_attns = () if output_attentions else None
643
+ all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
644
+ next_decoder_cache = () if use_cache else None
645
+
646
+ # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
647
+ for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
648
+ if attn_mask is not None:
649
+ if attn_mask.size()[0] != (len(self.layers)):
650
+ raise ValueError(
651
+ f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
652
+ f" {head_mask.size()[0]}."
653
+ )
654
+
655
+ for idx, decoder_layer in enumerate(self.layers):
656
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
657
+ if output_hidden_states:
658
+ all_hidden_states += (hidden_states,)
659
+ dropout_probability = random.uniform(0, 1)
660
+ if self.training and (dropout_probability < self.layerdrop):
661
+ continue
662
+
663
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
664
+
665
+ if self.gradient_checkpointing and self.training:
666
+
667
+ if use_cache:
668
+ logger.warning(
669
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
670
+ )
671
+ use_cache = False
672
+
673
+ def create_custom_forward(module):
674
+ def custom_forward(*inputs):
675
+ # None for past_key_value
676
+ return module(*inputs, output_attentions, use_cache)
677
+
678
+ return custom_forward
679
+
680
+ layer_outputs = torch.utils.checkpoint.checkpoint(
681
+ create_custom_forward(decoder_layer),
682
+ hidden_states,
683
+ attention_mask,
684
+ encoder_hidden_states,
685
+ encoder_attention_mask,
686
+ head_mask[idx] if head_mask is not None else None,
687
+ cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
688
+ None,
689
+ )
690
+ else:
691
+
692
+ layer_outputs = decoder_layer(
693
+ hidden_states,
694
+ attention_mask=attention_mask,
695
+ encoder_hidden_states=encoder_hidden_states,
696
+ encoder_attention_mask=encoder_attention_mask,
697
+ layer_head_mask=(head_mask[idx] if head_mask is not None else None),
698
+ cross_attn_layer_head_mask=(
699
+ cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
700
+ ),
701
+ past_key_value=past_key_value,
702
+ output_attentions=output_attentions,
703
+ use_cache=use_cache,
704
+ )
705
+ hidden_states = layer_outputs[0]
706
+
707
+ if use_cache:
708
+ next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)
709
+
710
+ if output_attentions:
711
+ all_self_attns += (layer_outputs[1],)
712
+
713
+ if encoder_hidden_states is not None:
714
+ all_cross_attentions += (layer_outputs[2],)
715
+
716
+ # add hidden states from the last decoder layer
717
+ if output_hidden_states:
718
+ all_hidden_states += (hidden_states,)
719
+
720
+ next_cache = next_decoder_cache if use_cache else None
721
+ if not return_dict:
722
+ return tuple(
723
+ v
724
+ for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions]
725
+ if v is not None
726
+ )
727
+ return BaseModelOutputWithPastAndCrossAttentions(
728
+ last_hidden_state=hidden_states,
729
+ past_key_values=next_cache,
730
+ hidden_states=all_hidden_states,
731
+ attentions=all_self_attns,
732
+ cross_attentions=all_cross_attentions,
733
+ )
734
+
735
+
736
+ DELTALM_START_DOCSTRING = r"""
737
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
738
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
739
+ etc.)
740
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
741
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
742
+ and behavior.
743
+ Parameters:
744
+ config ([`DeltalmConfig`]):
745
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
746
+ load the weights associated with the model, only the configuration. Check out the
747
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
748
+ """
749
+
750
+ DELTALM_GENERATION_EXAMPLE = r"""
751
+ Summarization example:
752
+ ```python
753
+ >>> from transformers import DeltalmTokenizer, DeltalmForConditionalGeneration
754
+ >>> model = DeltalmForConditionalGeneration.from_pretrained("facebook/deltalm-large-cnn")
755
+ >>> tokenizer = DeltalmTokenizer.from_pretrained("facebook/deltalm-large-cnn")
756
+ >>> ARTICLE_TO_SUMMARIZE = (
757
+ ... "PG&E stated it scheduled the blackouts in response to forecasts for high winds "
758
+ ... "amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were "
759
+ ... "scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow."
760
+ ... )
761
+ >>> inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors="pt")
762
+ >>> # Generate Summary
763
+ >>> summary_ids = model.generate(inputs["input_ids"], num_beams=2, min_length=0, max_length=20)
764
+ >>> tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
765
+ 'PG&E scheduled the blackouts in response to forecasts for high winds amid dry conditions'
766
+ ```
767
+ Mask filling example:
768
+ ```python
769
+ >>> from transformers import DeltalmTokenizer, DeltalmForConditionalGeneration
770
+ >>> tokenizer = DeltalmTokenizer.from_pretrained("facebook/deltalm-base")
771
+ >>> model = DeltalmForConditionalGeneration.from_pretrained("facebook/deltalm-base")
772
+ >>> TXT = "My friends are <mask> but they eat too many carbs."
773
+ >>> input_ids = tokenizer([TXT], return_tensors="pt")["input_ids"]
774
+ >>> logits = model(input_ids).logits
775
+ >>> masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item()
776
+ >>> probs = logits[0, masked_index].softmax(dim=0)
777
+ >>> values, predictions = probs.topk(5)
778
+ >>> tokenizer.decode(predictions).split()
779
+ ['not', 'good', 'healthy', 'great', 'very']
780
+ ```
781
+ """
782
+
783
+ DELTALM_INPUTS_DOCSTRING = r"""
784
+ Args:
785
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
786
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
787
+ it.
788
+ Indices can be obtained using [`DeltalmTokenizer`]. See [`PreTrainedTokenizer.encode`] and
789
+ [`PreTrainedTokenizer.__call__`] for details.
790
+ [What are input IDs?](../glossary#input-ids)
791
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
792
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
793
+ - 1 for tokens that are **not masked**,
794
+ - 0 for tokens that are **masked**.
795
+ [What are attention masks?](../glossary#attention-mask)
796
+ decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
797
+ Indices of decoder input sequence tokens in the vocabulary.
798
+ Indices can be obtained using [`DeltalmTokenizer`]. See [`PreTrainedTokenizer.encode`] and
799
+ [`PreTrainedTokenizer.__call__`] for details.
800
+ [What are decoder input IDs?](../glossary#decoder-input-ids)
801
+ Deltalm uses the `eos_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values`
802
+ is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).
803
+ For translation and summarization training, `decoder_input_ids` should be provided. If no
804
+ `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right
805
+ for denoising pre-training following the paper.
806
+ decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
807
+ Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
808
+ be used by default.
809
+ If you want to change padding behavior, you should read [`modeling_deltalm._prepare_decoder_attention_mask`]
810
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
811
+ information on the default strategy.
812
+ head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
813
+ Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`:
814
+ - 1 indicates the head is **not masked**,
815
+ - 0 indicates the head is **masked**.
816
+ decoder_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
817
+ Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in `[0, 1]`:
818
+ - 1 indicates the head is **not masked**,
819
+ - 0 indicates the head is **masked**.
820
+ cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
821
+ Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in `[0,
822
+ 1]`:
823
+ - 1 indicates the head is **not masked**,
824
+ - 0 indicates the head is **masked**.
825
+ encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*):
826
+ Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
827
+ `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of
828
+ hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
829
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
830
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
831
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
832
+ `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
833
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
834
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
835
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
836
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
837
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`. inputs_embeds (`torch.FloatTensor` of shape
838
+ `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing `input_ids` you
839
+ can choose to directly pass an embedded representation. This is useful if you want more control over how to
840
+ convert `input_ids` indices into associated vectors than the model's internal embedding lookup matrix.
841
+ decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*):
842
+ Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded
843
+ representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be
844
+ input (see `past_key_values`). This is useful if you want more control over how to convert
845
+ `decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix.
846
+ If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value
847
+ of `inputs_embeds`.
848
+ use_cache (`bool`, *optional*):
849
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
850
+ `past_key_values`).
851
+ output_attentions (`bool`, *optional*):
852
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
853
+ tensors for more detail.
854
+ output_hidden_states (`bool`, *optional*):
855
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
856
+ more detail.
857
+ return_dict (`bool`, *optional*):
858
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
859
+ """
860
+
861
+
862
+ class DeltalmEncoder(DeltalmPretrainedModel):
863
+ """
864
+ Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
865
+ [`DeltalmEncoderLayer`].
866
+ Args:
867
+ config: DeltalmConfig
868
+ embed_tokens (nn.Embedding): output embedding
869
+ """
870
+
871
+ def __init__(self, config: DeltalmConfig, embed_tokens: Optional[nn.Embedding] = None):
872
+ super().__init__(config)
873
+
874
+ self.dropout = config.dropout
875
+ self.layerdrop = config.encoder_layerdrop
876
+
877
+ embed_dim = config.d_model
878
+ self.padding_idx = config.pad_token_id
879
+ self.max_source_positions = config.max_position_embeddings
880
+ self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
881
+
882
+ if embed_tokens is not None:
883
+ self.embed_tokens = embed_tokens
884
+ else:
885
+ self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)
886
+
887
+ self.embed_positions = DeltalmLearnedPositionalEmbedding(
888
+ config.max_position_embeddings,
889
+ embed_dim,
890
+ )
891
+ self.layers = nn.ModuleList([DeltalmEncoderLayer(config) for _ in range(config.encoder_layers)])
892
+ self.layernorm_embedding = nn.LayerNorm(embed_dim)
893
+
894
+ self.gradient_checkpointing = False
895
+ if config.encoder_normalize_before:
896
+ self.layer_norm = nn.LayerNorm(embed_dim)
897
+ else:
898
+ self.layer_norm = None
899
+ # Initialize weights and apply final processing
900
+ self.post_init()
901
+
902
+ def get_input_embeddings(self):
903
+ return self.embed_tokens
904
+
905
+ def set_input_embeddings(self, value):
906
+ self.embed_tokens = value
907
+
908
+ def forward(
909
+ self,
910
+ input_ids: torch.LongTensor = None,
911
+ attention_mask: Optional[torch.Tensor] = None,
912
+ head_mask: Optional[torch.Tensor] = None,
913
+ inputs_embeds: Optional[torch.FloatTensor] = None,
914
+ output_attentions: Optional[bool] = None,
915
+ output_hidden_states: Optional[bool] = None,
916
+ return_dict: Optional[bool] = None,
917
+ ) -> Union[Tuple, BaseModelOutput]:
918
+ r"""
919
+ Args:
920
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
921
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
922
+ provide it.
923
+ Indices can be obtained using [`DeltalmTokenizer`]. See [`PreTrainedTokenizer.encode`] and
924
+ [`PreTrainedTokenizer.__call__`] for details.
925
+ [What are input IDs?](../glossary#input-ids)
926
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
927
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
928
+ - 1 for tokens that are **not masked**,
929
+ - 0 for tokens that are **masked**.
930
+ [What are attention masks?](../glossary#attention-mask)
931
+ head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
932
+ Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
933
+ - 1 indicates the head is **not masked**,
934
+ - 0 indicates the head is **masked**.
935
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
936
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
937
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
938
+ than the model's internal embedding lookup matrix.
939
+ output_attentions (`bool`, *optional*):
940
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
941
+ returned tensors for more detail.
942
+ output_hidden_states (`bool`, *optional*):
943
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
944
+ for more detail.
945
+ return_dict (`bool`, *optional*):
946
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
947
+ """
948
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
949
+ output_hidden_states = (
950
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
951
+ )
952
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
953
+
954
+ # retrieve input_ids and inputs_embeds
955
+ if input_ids is not None and inputs_embeds is not None:
956
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
957
+ elif input_ids is not None:
958
+ input_shape = input_ids.size()
959
+ input_ids = input_ids.view(-1, input_shape[-1])
960
+ elif inputs_embeds is not None:
961
+ input_shape = inputs_embeds.size()[:-1]
962
+ else:
963
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
964
+
965
+ if inputs_embeds is None:
966
+ inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
967
+
968
+ embed_pos = self.embed_positions(input_shape)
969
+
970
+ hidden_states = inputs_embeds + embed_pos
971
+ hidden_states = self.layernorm_embedding(hidden_states)
972
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
973
+
974
+ # expand attention_mask
975
+ if attention_mask is not None:
976
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
977
+ attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype)
978
+
979
+ encoder_states = () if output_hidden_states else None
980
+ all_attentions = () if output_attentions else None
981
+
982
+ # check if head_mask has a correct number of layers specified if desired
983
+ if head_mask is not None:
984
+ if head_mask.size()[0] != (len(self.layers)):
985
+ raise ValueError(
986
+ f"The head_mask should be specified for {len(self.layers)} layers, but it is for"
987
+ f" {head_mask.size()[0]}."
988
+ )
989
+
990
+ for idx, encoder_layer in enumerate(self.layers):
991
+ if output_hidden_states:
992
+ encoder_states = encoder_states + (hidden_states,)
993
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
994
+ dropout_probability = random.uniform(0, 1)
995
+ if self.training and (dropout_probability < self.layerdrop): # skip the layer
996
+ layer_outputs = (None, None)
997
+ else:
998
+ if self.gradient_checkpointing and self.training:
999
+
1000
+ def create_custom_forward(module):
1001
+ def custom_forward(*inputs):
1002
+ return module(*inputs, output_attentions)
1003
+
1004
+ return custom_forward
1005
+
1006
+ layer_outputs = torch.utils.checkpoint.checkpoint(
1007
+ create_custom_forward(encoder_layer),
1008
+ hidden_states,
1009
+ attention_mask,
1010
+ (head_mask[idx] if head_mask is not None else None),
1011
+ )
1012
+ else:
1013
+ layer_outputs = encoder_layer(
1014
+ hidden_states,
1015
+ attention_mask,
1016
+ layer_head_mask=(head_mask[idx] if head_mask is not None else None),
1017
+ output_attentions=output_attentions,
1018
+ )
1019
+
1020
+ hidden_states = layer_outputs[0]
1021
+
1022
+ if output_attentions:
1023
+ all_attentions = all_attentions + (layer_outputs[1],)
1024
+
1025
+ if self.layer_norm is not None:
1026
+ hidden_states = self.layer_norm(hidden_states)
1027
+ # hidden_states = self.layernorm_embedding(hidden_states)
1028
+
1029
+ if output_hidden_states:
1030
+ encoder_states = encoder_states + (hidden_states,)
1031
+
1032
+ if not return_dict:
1033
+ return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
1034
+ return BaseModelOutput(
1035
+ last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
1036
+ )
1037
+
1038
+
1039
+ class DeltalmModel(DeltalmPretrainedModel):
1040
+ def __init__(self, config: DeltalmConfig):
1041
+ super().__init__(config)
1042
+
1043
+ padding_idx, vocab_size = config.pad_token_id, config.vocab_size
1044
+ self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx)
1045
+
1046
+ self.encoder = DeltalmEncoder(config, self.shared)
1047
+ self.decoder = DeltalmDecoder(config, self.shared)
1048
+
1049
+ # Initialize weights and apply final processing
1050
+ self.post_init()
1051
+
1052
+ def get_input_embeddings(self):
1053
+ return self.shared
1054
+
1055
+ def set_input_embeddings(self, value):
1056
+ self.shared = value
1057
+ self.encoder.embed_tokens = self.shared
1058
+ self.decoder.embed_tokens = self.shared
1059
+
1060
+ def get_encoder(self):
1061
+ return self.encoder
1062
+
1063
+ def get_decoder(self):
1064
+ return self.decoder
1065
+
1066
+ @add_start_docstrings_to_model_forward(DELTALM_INPUTS_DOCSTRING)
1067
+ # @add_code_sample_docstrings(
1068
+ # processor_class=_TOKENIZER_FOR_DOC,
1069
+ # checkpoint=_CHECKPOINT_FOR_DOC,
1070
+ # output_type=Seq2SeqModelOutput,
1071
+ # config_class=_CONFIG_FOR_DOC,
1072
+ # expected_output=_EXPECTED_OUTPUT_SHAPE,
1073
+ # )
1074
+ def forward(
1075
+ self,
1076
+ input_ids: torch.LongTensor = None,
1077
+ attention_mask: Optional[torch.Tensor] = None,
1078
+ decoder_input_ids: Optional[torch.LongTensor] = None,
1079
+ decoder_attention_mask: Optional[torch.LongTensor] = None,
1080
+ head_mask: Optional[torch.Tensor] = None,
1081
+ decoder_head_mask: Optional[torch.Tensor] = None,
1082
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
1083
+ encoder_outputs: Optional[List[torch.FloatTensor]] = None,
1084
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1085
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1086
+ decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
1087
+ use_cache: Optional[bool] = None,
1088
+ output_attentions: Optional[bool] = None,
1089
+ output_hidden_states: Optional[bool] = None,
1090
+ return_dict: Optional[bool] = None,
1091
+ ) -> Union[Tuple, Seq2SeqModelOutput]:
1092
+
1093
+ # different to other models, Deltalm automatically creates decoder_input_ids from
1094
+ # input_ids if no decoder_input_ids are provided
1095
+ if decoder_input_ids is None and decoder_inputs_embeds is None:
1096
+ if input_ids is None:
1097
+ raise ValueError(
1098
+ "If no `decoder_input_ids` or `decoder_inputs_embeds` are "
1099
+ "passed, `input_ids` cannot be `None`. Please pass either "
1100
+ "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`."
1101
+ )
1102
+
1103
+ decoder_input_ids = shift_tokens_right(
1104
+ input_ids, self.config.pad_token_id, self.config.decoder_start_token_id
1105
+ )
1106
+
1107
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1108
+ output_hidden_states = (
1109
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1110
+ )
1111
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1112
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1113
+
1114
+ if encoder_outputs is None:
1115
+ encoder_outputs = self.encoder(
1116
+ input_ids=input_ids,
1117
+ attention_mask=attention_mask,
1118
+ head_mask=head_mask,
1119
+ inputs_embeds=inputs_embeds,
1120
+ output_attentions=output_attentions,
1121
+ output_hidden_states=output_hidden_states,
1122
+ return_dict=return_dict,
1123
+ )
1124
+ # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
1125
+ elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
1126
+ encoder_outputs = BaseModelOutput(
1127
+ last_hidden_state=encoder_outputs[0],
1128
+ hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
1129
+ attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
1130
+ )
1131
+
1132
+ # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
1133
+ decoder_outputs = self.decoder(
1134
+ input_ids=decoder_input_ids,
1135
+ attention_mask=decoder_attention_mask,
1136
+ encoder_hidden_states=encoder_outputs[0],
1137
+ encoder_attention_mask=attention_mask,
1138
+ head_mask=decoder_head_mask,
1139
+ cross_attn_head_mask=cross_attn_head_mask,
1140
+ past_key_values=past_key_values,
1141
+ inputs_embeds=decoder_inputs_embeds,
1142
+ use_cache=use_cache,
1143
+ output_attentions=output_attentions,
1144
+ output_hidden_states=output_hidden_states,
1145
+ return_dict=return_dict,
1146
+ )
1147
+
1148
+ if not return_dict:
1149
+ return decoder_outputs + encoder_outputs
1150
+
1151
+ logger.debug("last_hidden_state.size: %s", decoder_outputs.last_hidden_state)
1152
+ return Seq2SeqModelOutput(
1153
+ last_hidden_state=decoder_outputs.last_hidden_state,
1154
+ past_key_values=decoder_outputs.past_key_values,
1155
+ decoder_hidden_states=decoder_outputs.hidden_states,
1156
+ decoder_attentions=decoder_outputs.attentions,
1157
+ cross_attentions=decoder_outputs.cross_attentions,
1158
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
1159
+ encoder_hidden_states=encoder_outputs.hidden_states,
1160
+ encoder_attentions=encoder_outputs.attentions,
1161
+ )
1162
+
1163
+
1164
+ @add_start_docstrings(
1165
+ "The DELTALM Model with a language modeling head. Can be used for translation.", DELTALM_START_DOCSTRING
1166
+ )
1167
+ class DeltalmForConditionalGeneration(DeltalmPretrainedModel):
1168
+ base_model_prefix = "model"
1169
+ _keys_to_ignore_on_load_missing = [r"final_logits_bias", r"lm_head.weight"]
1170
+
1171
+ def __init__(self, config: DeltalmConfig):
1172
+ super().__init__(config)
1173
+ self.model = DeltalmModel(config)
1174
+ self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings)))
1175
+ self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False)
1176
+
1177
+ # Initialize weights and apply final processing
1178
+ self.post_init()
1179
+
1180
+ def get_encoder(self):
1181
+ return self.model.get_encoder()
1182
+
1183
+ def get_decoder(self):
1184
+ return self.model.get_decoder()
1185
+
1186
+ def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding:
1187
+ new_embeddings = super().resize_token_embeddings(new_num_tokens)
1188
+ self._resize_final_logits_bias(new_num_tokens)
1189
+ return new_embeddings
1190
+
1191
+ def _resize_final_logits_bias(self, new_num_tokens: int) -> None:
1192
+ logger.debug("Debug: coming to _resize_final_logits_bias")
1193
+ old_num_tokens = self.final_logits_bias.shape[-1]
1194
+ if new_num_tokens <= old_num_tokens:
1195
+ new_bias = self.final_logits_bias[:, :new_num_tokens]
1196
+ else:
1197
+ extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device)
1198
+ new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
1199
+ self.register_buffer("final_logits_bias", new_bias)
1200
+
1201
+ def get_output_embeddings(self):
1202
+ return self.lm_head
1203
+
1204
+ def set_output_embeddings(self, new_embeddings):
1205
+ self.lm_head = new_embeddings
1206
+
1207
+ @add_start_docstrings_to_model_forward(DELTALM_INPUTS_DOCSTRING)
1208
+ @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
1209
+ @add_end_docstrings(DELTALM_GENERATION_EXAMPLE)
1210
+ def forward(
1211
+ self,
1212
+ input_ids: torch.LongTensor = None,
1213
+ attention_mask: Optional[torch.Tensor] = None,
1214
+ decoder_input_ids: Optional[torch.LongTensor] = None,
1215
+ decoder_attention_mask: Optional[torch.LongTensor] = None,
1216
+ head_mask: Optional[torch.Tensor] = None,
1217
+ decoder_head_mask: Optional[torch.Tensor] = None,
1218
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
1219
+ encoder_outputs: Optional[List[torch.FloatTensor]] = None,
1220
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1221
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1222
+ decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
1223
+ labels: Optional[torch.LongTensor] = None,
1224
+ use_cache: Optional[bool] = None,
1225
+ output_attentions: Optional[bool] = None,
1226
+ output_hidden_states: Optional[bool] = None,
1227
+ return_dict: Optional[bool] = None,
1228
+ ) -> Union[Tuple, Seq2SeqLMOutput]:
1229
+ r"""
1230
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1231
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1232
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1233
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1234
+ Returns:
1235
+ """
1236
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1237
+
1238
+ logger.debug("Comming to Generation!")
1239
+
1240
+ if labels is not None:
1241
+ logger.debug("Debug: *************** Before label ***************** ")
1242
+ logger.debug("Debug: %s", labels.size())
1243
+ if use_cache:
1244
+ logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
1245
+ use_cache = False
1246
+ if decoder_input_ids is None and decoder_inputs_embeds is None:
1247
+ decoder_input_ids = shift_tokens_right(
1248
+ labels, self.config.pad_token_id, self.config.decoder_start_token_id
1249
+ )
1250
+
1251
+ logger.debug("Debug: ************ After labels ************")
1252
+ logger.debug("Debug: %s", labels.size())
1253
+
1254
+ outputs = self.model(
1255
+ input_ids,
1256
+ attention_mask=attention_mask,
1257
+ decoder_input_ids=decoder_input_ids,
1258
+ encoder_outputs=encoder_outputs,
1259
+ decoder_attention_mask=decoder_attention_mask,
1260
+ head_mask=head_mask,
1261
+ decoder_head_mask=decoder_head_mask,
1262
+ cross_attn_head_mask=cross_attn_head_mask,
1263
+ past_key_values=past_key_values,
1264
+ inputs_embeds=inputs_embeds,
1265
+ decoder_inputs_embeds=decoder_inputs_embeds,
1266
+ use_cache=use_cache,
1267
+ output_attentions=output_attentions,
1268
+ output_hidden_states=output_hidden_states,
1269
+ return_dict=return_dict,
1270
+ )
1271
+ lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias
1272
+ # print(self.lm_head)
1273
+ logger.debug("Debug: logit_size: %s", lm_logits.size())
1274
+
1275
+ # logger.debug("Debug: change logit size: ", lm_logits.view(-1, self.config.vocab_size).size())
1276
+ # logger.debug("Debug: change label size: ", labels.view(-1).size())
1277
+ masked_lm_loss = None
1278
+
1279
+ if labels is not None:
1280
+ # logger.debug("Debug: model label_size: %s", labels.size())
1281
+ # loss_fct = CrossEntropyLoss()
1282
+ # masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))
1283
+ loss_fct = CrossEntropyLoss()
1284
+ masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))
1285
+ # label_smoothing = self.config.label_smoothing
1286
+ # # logger.debug("Debug: label.size: ", )
1287
+ # if label_smoothing == 0:
1288
+ # # compute label smoothed loss
1289
+ # loss_fct = CrossEntropyLoss()
1290
+ # masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))
1291
+ # else:
1292
+ # m = torch.nn.LogSoftmax(dim=-1)
1293
+ # lprobs = m(lm_logits.float())
1294
+ # # lprobs = m(lm_logits)
1295
+ # # # torch.set_printoptions(linewidth=200)
1296
+ # loss_fn = label_smoothed_nll_loss
1297
+ # masked_lm_loss, _ = loss_fn(lprobs.view(-1, lprobs.size(-1)), labels.view(-1), label_smoothing, self.config.pad_token_id)
1298
+
1299
+ if not return_dict:
1300
+ logger.debug("Debug: not return dict")
1301
+ output = (lm_logits,) + outputs[1:]
1302
+ return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
1303
+
1304
+ return Seq2SeqLMOutput(
1305
+ loss=masked_lm_loss,
1306
+ logits=lm_logits,
1307
+ past_key_values=outputs.past_key_values,
1308
+ decoder_hidden_states=outputs.decoder_hidden_states,
1309
+ decoder_attentions=outputs.decoder_attentions,
1310
+ cross_attentions=outputs.cross_attentions,
1311
+ encoder_last_hidden_state=outputs.encoder_last_hidden_state,
1312
+ encoder_hidden_states=outputs.encoder_hidden_states,
1313
+ encoder_attentions=outputs.encoder_attentions,
1314
+ )
1315
+
1316
+ def prepare_inputs_for_generation(
1317
+ self,
1318
+ decoder_input_ids,
1319
+ past=None,
1320
+ attention_mask=None,
1321
+ head_mask=None,
1322
+ decoder_head_mask=None,
1323
+ cross_attn_head_mask=None,
1324
+ use_cache=None,
1325
+ encoder_outputs=None,
1326
+ **kwargs
1327
+ ):
1328
+ # cut decoder_input_ids if past is used
1329
+ if past is not None:
1330
+ decoder_input_ids = decoder_input_ids[:, -1:]
1331
+
1332
+ return {
1333
+ "input_ids": None, # encoder_outputs is defined. input_ids not needed
1334
+ "encoder_outputs": encoder_outputs,
1335
+ "past_key_values": past,
1336
+ "decoder_input_ids": decoder_input_ids,
1337
+ "attention_mask": attention_mask,
1338
+ "head_mask": head_mask,
1339
+ "decoder_head_mask": decoder_head_mask,
1340
+ "cross_attn_head_mask": cross_attn_head_mask,
1341
+ "use_cache": use_cache, # change this to avoid caching (presumably for debugging)
1342
+ }
1343
+
1344
+ def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
1345
+ return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
1346
+
1347
+ @staticmethod
1348
+ def _reorder_cache(past, beam_idx):
1349
+ reordered_past = ()
1350
+ for layer_past in past:
1351
+ # cached cross_attention states don't have to be reordered -> they are always the same
1352
+ reordered_past += (
1353
+ tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
1354
+ )
1355
+ return reordered_past
1356
+
1357
+
1358
+ class DeltalmDecoderWrapper(DeltalmPretrainedModel):
1359
+ """
1360
+ This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is
1361
+ used in combination with the [`EncoderDecoderModel`] framework.
1362
+ """
1363
+
1364
+ def __init__(self, config):
1365
+ super().__init__(config)
1366
+ self.decoder = DeltalmDecoder(config)
1367
+
1368
+ def forward(self, *args, **kwargs):
1369
+ return self.decoder(*args, **kwargs)
1370
+
1371
+
1372
+ class DeltalmForCausalLM(DeltalmPretrainedModel):
1373
+ def __init__(self, config):
1374
+ config = copy.deepcopy(config)
1375
+ config.is_decoder = True
1376
+ config.is_encoder_decoder = False
1377
+ super().__init__(config)
1378
+ self.model = DeltalmDecoderWrapper(config)
1379
+
1380
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1381
+
1382
+ # Initialize weights and apply final processing
1383
+ self.post_init()
1384
+
1385
+ def get_input_embeddings(self):
1386
+ return self.model.decoder.embed_tokens
1387
+
1388
+ def set_input_embeddings(self, value):
1389
+ self.model.decoder.embed_tokens = value
1390
+
1391
+ def get_output_embeddings(self):
1392
+ return self.lm_head
1393
+
1394
+ def set_output_embeddings(self, new_embeddings):
1395
+ self.lm_head = new_embeddings
1396
+
1397
+ def set_decoder(self, decoder):
1398
+ self.model.decoder = decoder
1399
+
1400
+ def get_decoder(self):
1401
+ return self.model.decoder
1402
+
1403
+ @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
1404
+ def forward(
1405
+ self,
1406
+ input_ids: torch.LongTensor = None,
1407
+ attention_mask: Optional[torch.Tensor] = None,
1408
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
1409
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
1410
+ head_mask: Optional[torch.Tensor] = None,
1411
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
1412
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1413
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1414
+ labels: Optional[torch.LongTensor] = None,
1415
+ use_cache: Optional[bool] = None,
1416
+ output_attentions: Optional[bool] = None,
1417
+ output_hidden_states: Optional[bool] = None,
1418
+ return_dict: Optional[bool] = None,
1419
+ ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
1420
+ r"""
1421
+ Args:
1422
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
1423
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
1424
+ provide it.
1425
+ Indices can be obtained using [`DeltalmTokenizer`]. See [`PreTrainedTokenizer.encode`] and
1426
+ [`PreTrainedTokenizer.__call__`] for details.
1427
+ [What are input IDs?](../glossary#input-ids)
1428
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
1429
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
1430
+ - 1 for tokens that are **not masked**,
1431
+ - 0 for tokens that are **masked**.
1432
+ [What are attention masks?](../glossary#attention-mask)
1433
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
1434
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
1435
+ if the model is configured as a decoder.
1436
+ encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
1437
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used
1438
+ in the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
1439
+ head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
1440
+ Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
1441
+ - 1 indicates the head is **not masked**,
1442
+ - 0 indicates the head is **masked**.
1443
+ cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
1444
+ Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:
1445
+ - 1 indicates the head is **not masked**,
1446
+ - 0 indicates the head is **masked**.
1447
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
1448
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
1449
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
1450
+ shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional
1451
+ tensors are only required when the model is used as a decoder in a Sequence to Sequence model.
1452
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
1453
+ cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
1454
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
1455
+ that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
1456
+ all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
1457
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1458
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1459
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1460
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1461
+ use_cache (`bool`, *optional*):
1462
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
1463
+ (see `past_key_values`).
1464
+ - 1 for tokens that are **not masked**,
1465
+ - 0 for tokens that are **masked**.
1466
+ output_attentions (`bool`, *optional*):
1467
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
1468
+ returned tensors for more detail.
1469
+ output_hidden_states (`bool`, *optional*):
1470
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
1471
+ for more detail.
1472
+ return_dict (`bool`, *optional*):
1473
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1474
+ Returns:
1475
+ Example:
1476
+ ```python
1477
+ >>> from transformers import DeltalmTokenizer, DeltalmForCausalLM
1478
+ >>> tokenizer = DeltalmTokenizer.from_pretrained("facebook/deltalm-base")
1479
+ >>> model = DeltalmForCausalLM.from_pretrained("facebook/deltalm-base", add_cross_attention=False)
1480
+ >>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder."
1481
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
1482
+ >>> outputs = model(**inputs)
1483
+ >>> logits = outputs.logits
1484
+ >>> expected_shape = [1, inputs.input_ids.shape[-1], model.config.vocab_size]
1485
+ >>> list(logits.shape) == expected_shape
1486
+ True
1487
+ ```"""
1488
+
1489
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1490
+ output_hidden_states = (
1491
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1492
+ )
1493
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1494
+
1495
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1496
+ outputs = self.model.decoder(
1497
+ input_ids=input_ids,
1498
+ attention_mask=attention_mask,
1499
+ encoder_hidden_states=encoder_hidden_states,
1500
+ encoder_attention_mask=encoder_attention_mask,
1501
+ head_mask=head_mask,
1502
+ cross_attn_head_mask=cross_attn_head_mask,
1503
+ past_key_values=past_key_values,
1504
+ inputs_embeds=inputs_embeds,
1505
+ use_cache=use_cache,
1506
+ output_attentions=output_attentions,
1507
+ output_hidden_states=output_hidden_states,
1508
+ return_dict=return_dict,
1509
+ )
1510
+
1511
+ logits = self.lm_head(outputs[0])
1512
+
1513
+ loss = None
1514
+ if labels is not None:
1515
+ loss_fct = CrossEntropyLoss()
1516
+ loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
1517
+
1518
+ if not return_dict:
1519
+ output = (logits,) + outputs[1:]
1520
+ return (loss,) + output if loss is not None else output
1521
+
1522
+ return CausalLMOutputWithCrossAttentions(
1523
+ loss=loss,
1524
+ logits=logits,
1525
+ past_key_values=outputs.past_key_values,
1526
+ hidden_states=outputs.hidden_states,
1527
+ attentions=outputs.attentions,
1528
+ cross_attentions=outputs.cross_attentions,
1529
+ )
1530
+
1531
+ def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, use_cache=None, **kwargs):
1532
+ # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
1533
+ if attention_mask is None:
1534
+ attention_mask = input_ids.new_ones(input_ids.shape)
1535
+
1536
+ if past:
1537
+ input_ids = input_ids[:, -1:]
1538
+ # first step, decoder_cached_states are empty
1539
+ return {
1540
+ "input_ids": input_ids, # encoder_outputs is defined. input_ids not needed
1541
+ "attention_mask": attention_mask,
1542
+ "past_key_values": past,
1543
+ "use_cache": use_cache,
1544
+ }
1545
+
1546
+ @staticmethod
1547
+ def _reorder_cache(past, beam_idx):
1548
+ reordered_past = ()
1549
+ for layer_past in past:
1550
+ reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
1551
+ return reordered_past