Yibin Lei committed
Commit b5975ee · 1 Parent(s): 2322527

Upload bidirectional implementation

Files changed (1)
  1. bidirectional_mistral.py +257 -0
bidirectional_mistral.py ADDED
@@ -0,0 +1,257 @@
"""
This file is adapted from https://github.com/McGill-NLP/llm2vec.
"""

import torch

from transformers import (
    MistralModel,
    MistralPreTrainedModel,
    MistralForCausalLM,
    MistralConfig,
)
from transformers.models.mistral.modeling_mistral import (
    MistralDecoderLayer,
    MistralRMSNorm,
    MistralAttention,
    MistralFlashAttention2,
    MistralSdpaAttention,
    MistralMLP,
)
from torch import nn
from transformers.utils import logging
from transformers.cache_utils import Cache, StaticCache, SlidingWindowCache

from transformers.modeling_attn_mask_utils import AttentionMaskConverter

from peft import PeftModel

logger = logging.get_logger(__name__)


def is_transformers_attn_greater_or_equal_4_43_1():
    import importlib.metadata
    from packaging import version
    from transformers.utils.import_utils import _is_package_available
    if not _is_package_available("transformers"):
        return False

    return version.parse(importlib.metadata.version("transformers")) >= version.parse(
        "4.43.1"
    )

class ModifiedMistralAttention(MistralAttention):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Disable the causal flag so every token can attend to all positions.
        self.is_causal = False


class ModifiedMistralFlashAttention2(MistralFlashAttention2):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.is_causal = False


class ModifiedMistralSdpaAttention(MistralSdpaAttention):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.is_causal = False


MISTRAL_ATTENTION_CLASSES = {
    "eager": ModifiedMistralAttention,
    "flash_attention_2": ModifiedMistralFlashAttention2,
    "sdpa": ModifiedMistralSdpaAttention,
}


class ModifiedMistralDecoderLayer(MistralDecoderLayer):
    def __init__(self, config: MistralConfig, layer_idx: int):
        nn.Module.__init__(self)
        self.hidden_size = config.hidden_size

        self.self_attn = MISTRAL_ATTENTION_CLASSES[config._attn_implementation](
            config, layer_idx
        )

        self.mlp = MistralMLP(config)
        self.input_layernorm = MistralRMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )
        self.post_attention_layernorm = MistralRMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )


class MistralBiModel(MistralModel):
    _no_split_modules = ["ModifiedMistralDecoderLayer"]

    def __init__(self, config: MistralConfig):
        if not is_transformers_attn_greater_or_equal_4_43_1():
            raise ValueError(
                "The current implementation of MistralBiModel follows modeling_mistral.py of transformers version >= 4.43.1"
            )
        MistralPreTrainedModel.__init__(self, config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(
            config.vocab_size, config.hidden_size, self.padding_idx
        )
        # Only the Flash Attention 2 backend is supported by this adaptation.
        assert config._attn_implementation == "flash_attention_2"
        self.layers = nn.ModuleList(
            [
                ModifiedMistralDecoderLayer(config, layer_idx)
                for layer_idx in range(config.num_hidden_layers)
            ]
        )
        self._attn_implementation = config._attn_implementation
        self.norm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    # Adapted from transformers.models.mistral.modeling_mistral.MistralModel._update_causal_mask
    def _update_causal_mask(
        self,
        attention_mask: torch.Tensor,
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: Cache,
        use_cache: bool,
        output_attentions: bool,
    ):
        if self._attn_implementation == "flash_attention_2":
            if attention_mask is not None and use_cache:
                is_padding_right = (
                    attention_mask[:, -1].sum().item() != input_tensor.size()[0]
                )
                if is_padding_right:
                    raise ValueError(
                        "You are attempting to perform batched generation with padding_side='right'"
                        " this may lead to unexpected behaviour for Flash Attention version of Mistral. Make sure to "
                        " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
                    )
            if attention_mask is not None and 0.0 in attention_mask:
                return attention_mask
            return None

        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
        # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
        # to infer the attention mask.

        # cache_position must be valid here no matter which cache we use
        past_seen_tokens = cache_position[0] if past_key_values is not None else 0
        using_static_cache = isinstance(past_key_values, StaticCache)
        using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)

        # if (
        #     self.config._attn_implementation == "sdpa"
        #     and not (using_static_cache or using_sliding_window_cache)
        #     and not output_attentions
        # ):
        #     if AttentionMaskConverter._ignore_causal_mask_sdpa(
        #         attention_mask,
        #         inputs_embeds=input_tensor,
        #         past_key_values_length=past_seen_tokens,
        #         sliding_window=self.config.sliding_window,
        #         is_training=self.training,
        #     ):
        #         return None

        dtype, device = input_tensor.dtype, input_tensor.device
        min_dtype = torch.finfo(dtype).min
        sequence_length = input_tensor.shape[1]
        # SlidingWindowCache
        if using_sliding_window_cache:
            target_length = max(sequence_length, self.config.sliding_window)
        # StaticCache
        elif using_static_cache:
            target_length = past_key_values.get_max_length()
        # DynamicCache or no cache
        else:
            target_length = (
                attention_mask.shape[-1]
                if isinstance(attention_mask, torch.Tensor)
                else past_seen_tokens + sequence_length + 1
            )

        if attention_mask is not None and attention_mask.dim() == 4:
            # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing
            if attention_mask.max() != 0:
                raise ValueError(
                    "Custom 4D attention mask should be passed in inverted form with max==0`"
                )
            causal_mask = attention_mask
        else:
            # Start from zeros instead of min_dtype so every position stays visible (bidirectional attention).
            causal_mask = torch.zeros(
                (sequence_length, target_length), dtype=dtype, device=device
            )  # causal_mask = torch.full(
            #     (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
            # )
            exclude_mask = torch.arange(
                target_length, device=device
            ) > cache_position.reshape(-1, 1)
            if self.config.sliding_window is not None:
                if (
                    not using_sliding_window_cache
                    or sequence_length > self.config.sliding_window
                ):
                    exclude_mask.bitwise_or_(
                        torch.arange(target_length, device=device)
                        <= (cache_position.reshape(-1, 1) - self.config.sliding_window)
                    )
            causal_mask *= exclude_mask
            causal_mask = causal_mask[None, None, :, :].expand(
                input_tensor.shape[0], 1, -1, -1
            )
            if attention_mask is not None:
                causal_mask = (
                    causal_mask.clone()
                )  # copy to contiguous memory for in-place edit
                if attention_mask.dim() == 2:
                    mask_length = attention_mask.shape[-1]
                    padding_mask = (
                        causal_mask[:, :, :, :mask_length]
                        + attention_mask[:, None, None, :]
                    )
                    padding_mask = padding_mask == 0
                    causal_mask[:, :, :, :mask_length] = causal_mask[
                        :, :, :, :mask_length
                    ].masked_fill(padding_mask, min_dtype)

        if (
            self.config._attn_implementation == "sdpa"
            and attention_mask is not None
            and attention_mask.device.type == "cuda"
            and not output_attentions
        ):
            causal_mask = AttentionMaskConverter._unmask_unattended(
                causal_mask, min_dtype
            )

        return causal_mask


class MistralBiForCausalLM(MistralForCausalLM):
    def __init__(self, config):
        MistralPreTrainedModel.__init__(self, config)
        self.model = MistralBiModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    # getter for PEFT model
    def get_model_for_peft(self):
        return self.model

    # setter for PEFT model
    def set_model_for_peft(self, model: PeftModel):
        self.model = model

    # save the PEFT model
    def save_peft_model(self, path):
        self.model.save_pretrained(path)
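
For reference, a file like this is typically used by loading a Mistral checkpoint into MistralBiForCausalLM and pooling the bidirectional hidden states into text embeddings. The sketch below is a minimal, non-authoritative example: the checkpoint id is a hypothetical placeholder, and it assumes transformers >= 4.43.1, flash-attn, and a CUDA device, since the MistralBiModel constructor asserts the flash_attention_2 backend.

import torch
from transformers import AutoTokenizer
from bidirectional_mistral import MistralBiForCausalLM

model_id = "path/to/bidirectional-mistral-checkpoint"  # hypothetical checkpoint id

tokenizer = AutoTokenizer.from_pretrained(model_id)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # Mistral tokenizers often lack a pad token

model = MistralBiForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",  # required by the assert in MistralBiModel
).to("cuda").eval()

texts = ["Bidirectional Mistral as a text encoder.", "Causal masking is disabled."]
batch = tokenizer(texts, padding=True, return_tensors="pt").to("cuda")

with torch.no_grad():
    # The inner MistralBiModel returns token-level hidden states with full (non-causal) attention.
    hidden = model.model(**batch).last_hidden_state  # (batch, seq_len, hidden_size)

# Mean-pool over non-padding tokens to get one embedding per input text.
mask = batch["attention_mask"].unsqueeze(-1).to(hidden.dtype)
embeddings = (hidden * mask).sum(dim=1) / mask.sum(dim=1)
print(embeddings.shape)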