add gradient checkpointing for the final_layernorm module.
#77
by zhaoqf123 · opened
modeling_chatglm.py (+4 -1)
@@ -1012,7 +1012,10 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
                 all_self_attentions = all_self_attentions + (layer_ret[2 if use_cache else 1],)
 
         # Final layer norm.
-        hidden_states = self.final_layernorm(hidden_states)
+        if self.gradient_checkpointing and self.training:
+            hidden_states = torch.utils.checkpoint.checkpoint(self.final_layernorm, hidden_states)
+        else:
+            hidden_states = self.final_layernorm(hidden_states)
 
         if output_hidden_states:
             all_hidden_states = all_hidden_states + (hidden_states,)
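For reference, a minimal standalone sketch of the pattern this patch introduces, using only PyTorch; the LayerNorm size and tensor shapes below are illustrative assumptions, not values taken from modeling_chatglm.py. torch.utils.checkpoint.checkpoint runs the wrapped module without storing its intermediate activations and recomputes them during the backward pass, trading a little extra compute for lower peak memory:

import torch
import torch.utils.checkpoint

# Illustrative stand-in for ChatGLMModel.final_layernorm (hidden size is hypothetical).
final_layernorm = torch.nn.LayerNorm(4096)
final_layernorm.train()

# In the model, hidden_states arrives from the transformer layers and already
# requires grad, which the default (reentrant) checkpoint implementation relies
# on to rebuild the autograd graph during recomputation.
hidden_states = torch.randn(2, 8, 4096, requires_grad=True)

# Forward stores no intermediates for this module; backward recomputes them.
output = torch.utils.checkpoint.checkpoint(final_layernorm, hidden_states)
output.sum().backward()

Note that the patched branch only checkpoints when both self.gradient_checkpointing and self.training are set, so inference and non-checkpointed training keep the plain self.final_layernorm(hidden_states) call and pay no recomputation cost.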