feat: selective activation checkpointing (#16), opened by Markus28

Files changed:
- configuration_bert.py (+21 -1)
- modeling_bert.py (+12 -8)
configuration_bert.py (CHANGED)

@@ -55,6 +55,10 @@ class JinaBertConfig(PretrainedConfig):
         layer_norm_eps (`float`, *optional*, defaults to 1e-12):
             The epsilon used by the layer normalization layers.
         window_size (`tuple`, *optional*, defaults to `(-1, -1)`): If not the default, use local attention
+        activation_checkpoint_lvl (`int`, *optional*, defaults to `100`): How many layers to activation-checkpoint.
+            If larger than 0, the MLP activation checkpointing level is expected to be 0 for the first
+            `activation_checkpoint_lvl` layers. The activation checkpointing will only come into effect
+            after `model.gradient_checkpointing_enable()` is called.
     """

     model_type = "bert"

@@ -86,6 +90,7 @@ class JinaBertConfig(PretrainedConfig):
         emb_pooler=None,
         classifier_dropout=None,
         num_loras=5,
+        activation_checkpoint_lvl=100,
         **kwargs,
     ):
         assert 'position_embedding_type' not in kwargs

@@ -95,6 +100,20 @@ class JinaBertConfig(PretrainedConfig):
         if mlp_type == 'fused_mlp' and hidden_act not in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"]:
             raise ValueError('Fused MLP only supports approximate gelu')

+        if mlp_checkpoint_lvl != 0 and mlp_type != 'fused_mlp':
+            raise ValueError('MLP checkpointing only available for `fused_mlp`')
+
+        if activation_checkpoint_lvl > 0 and isinstance(mlp_checkpoint_lvl, int) and mlp_checkpoint_lvl > 0:
+            raise ValueError('Trying to use layer-wise activation checkpointing and MLP-checkpointing '
+                             'in every layer simultaneously. Either only use one of the techniques, '
+                             'or specify layer-wise MLP checkpointing.')
+        elif activation_checkpoint_lvl > 0 and isinstance(mlp_checkpoint_lvl, (list, tuple)):
+            for layer_idx, mlp_lvl in enumerate(mlp_checkpoint_lvl):
+                if layer_idx < activation_checkpoint_lvl and mlp_lvl > 0:
+                    raise ValueError(f'Layer {layer_idx} is being checkpointed as a whole and its MLP '
+                                     f'is being checkpointed. Either remove MLP checkpointing for this layer '
+                                     f'or reduce the `activation_checkpoint_lvl` appropriately.')
+
         self.vocab_size = vocab_size
         self.hidden_size = hidden_size
         self.num_hidden_layers = num_hidden_layers

@@ -118,4 +137,5 @@ class JinaBertConfig(PretrainedConfig):
         self.use_qk_norm = use_qk_norm
         self.emb_pooler = emb_pooler
         self.classifier_dropout = classifier_dropout
-        self.num_loras = num_loras
+        self.num_loras = num_loras
+        self.activation_checkpoint_lvl = activation_checkpoint_lvl
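The new validation ties `activation_checkpoint_lvl` to the existing `mlp_checkpoint_lvl` option: a layer that is checkpointed as a whole must not also checkpoint its MLP. Below is a condensed, standalone sketch of that constraint; the `check` helper and the example values are illustrative only and not part of the PR.

```python
# Illustrative only: mirrors the constraint enforced in JinaBertConfig.__init__ above.
def check(activation_checkpoint_lvl, mlp_checkpoint_lvl):
    # A single int applies MLP checkpointing to every layer, which clashes with
    # whole-layer checkpointing of the first `activation_checkpoint_lvl` layers.
    if activation_checkpoint_lvl > 0 and isinstance(mlp_checkpoint_lvl, int) and mlp_checkpoint_lvl > 0:
        return "rejected: both techniques applied to the same layers"
    # A per-layer list is fine as long as its first `activation_checkpoint_lvl` entries are 0.
    if activation_checkpoint_lvl > 0 and isinstance(mlp_checkpoint_lvl, (list, tuple)):
        for layer_idx, mlp_lvl in enumerate(mlp_checkpoint_lvl):
            if layer_idx < activation_checkpoint_lvl and mlp_lvl > 0:
                return f"rejected: layer {layer_idx} is double-checkpointed"
    return "accepted"

print(check(4, 1))                   # rejected: scalar MLP checkpointing in every layer
print(check(4, [1] + [0] * 11))      # rejected: layer 0 is double-checkpointed
print(check(4, [0] * 4 + [1] * 8))   # accepted: MLP checkpointing only from layer 4 on
```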
modeling_bert.py (CHANGED)

@@ -180,13 +180,17 @@ class BertEncoder(nn.Module):
             [create_block(config, layer_idx=i) for i in range(config.num_hidden_layers)]
         )
         self._grad_checkpointing = False
+        self._num_checkpointed_layers = config.activation_checkpoint_lvl

     @property
     def gradient_checkpointing(self):
         return self._grad_checkpointing

     @gradient_checkpointing.setter
-    def gradient_checkpointing(self, value):
+    def gradient_checkpointing(self, value: bool):
+        if value and self._num_checkpointed_layers <= 0:
+            raise ValueError('Trying to use activation checkpointing, but `activation_checkpoint_lvl` '
+                             'is set to zero.')
         self._grad_checkpointing = value

     def forward(self, hidden_states, key_padding_mask=None, subset_mask=None):

@@ -198,8 +202,8 @@ class BertEncoder(nn.Module):
             mixer_kwargs = (
                 {"key_padding_mask": key_padding_mask.bool()} if key_padding_mask is not None else None
             )
-            for layer in self.layers:
-                if self._grad_checkpointing:
+            for idx, layer in enumerate(self.layers):
+                if self._grad_checkpointing and idx < self._num_checkpointed_layers:
                     hidden_states = torch.utils.checkpoint.checkpoint(
                         layer,
                         hidden_states,

@@ -217,8 +221,8 @@ class BertEncoder(nn.Module):
             )
             mixer_kwargs = {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen_in_batch}
             if subset_mask is None:
-                for layer in self.layers:
-                    if self._grad_checkpointing:
+                for idx, layer in enumerate(self.layers):
+                    if self._grad_checkpointing and idx < self._num_checkpointed_layers:
                         hidden_states = torch.utils.checkpoint.checkpoint(
                             layer,
                             hidden_states,

@@ -229,8 +233,8 @@ class BertEncoder(nn.Module):
                         hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
                 hidden_states = pad_input(hidden_states, indices, batch, seqlen)
             else:
-                for layer in self.layers[:-1]:
-                    if self._grad_checkpointing:
+                for idx, layer in enumerate(self.layers[:-1]):
+                    if self._grad_checkpointing and idx < self._num_checkpointed_layers:
                         hidden_states = torch.utils.checkpoint.checkpoint(
                             layer,
                             hidden_states,

@@ -264,7 +268,7 @@ class BertEncoder(nn.Module):
                     "cu_seqlens_k": cu_seqlens,
                     "max_seqlen_k": max_seqlen_in_batch,
                 }
-                if self._grad_checkpointing:
+                if self._grad_checkpointing and len(self.layers) <= self._num_checkpointed_layers:
                     torch.utils.checkpoint.checkpoint(
                         self.layers[-1],
                         hidden_states_subset,