zhaoqf123 commited on
Commit
3d854f8
·
1 Parent(s): 1d240ba

add gradient checkpointing for the final_layernorm module.

Browse files

Without this, when tuning with LoRA + gradient checkpointing, the last transformer layer, i.e., layer-27's LoRA weights won't be updated!

For example, if we use this callback to log the weight change of LoRA weights in each layer, we will find that no weight update for the last layer in TensorBoard.

```
class ParamsTensorBoardCallback(TensorBoardCallback):
def __init__(self, tb_writer=None, params=None, process_name=lambda x:x):
super().__init__(tb_writer)
self.params = params
self._process_name = process_name

def on_step_end(self, args, state, control, **kwargs):
if state.global_step % args.logging_steps == 0:
dict_ = {}
model = kwargs["model"]
for name in self.params:
param = model.get_parameter(name)
param = param.flatten()
name_p = self._process_name(name)
dict_tmp = {
f"{name_p}_mean": param.mean().item(),
f"{name_p}_max": param.max().item(),
f"{name_p}_q75": param.quantile(0.75).item(),
f"{name_p}_q25": param.quantile(0.25).item(),
f"{name_p}_min": param.min().item(),
f"{name_p}_median": param.median().item(),
f"{name_p}_std": param.std().item(),
}
dict_.update(dict_tmp)
self.on_log(args, state, control, logs=dict_, **kwargs)

def get_params_for_logging(model):
ls_params = []
for name, param in model.named_parameters():
if param.requires_grad:
ls_params.append(name)
return ls_params

ls_params = get_params_for_logging(model)
tb_cb = ParamsTensorBoardCallback(
None, ls_params, process_name=getattr(utils, param_name_trimmer_name)()
)

trainer = Trainer(
model=model,
train_dataset=train_data,
eval_dataset=val_data,
args=args,
data_collator=data_collator,
callbacks=[tb_cb]
)
```

Files changed (1) hide show
  1. modeling_chatglm.py +4 -1
modeling_chatglm.py CHANGED
@@ -1012,7 +1012,10 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
1012
  all_self_attentions = all_self_attentions + (layer_ret[2 if use_cache else 1],)
1013
 
1014
  # Final layer norm.
1015
- hidden_states = self.final_layernorm(hidden_states)
 
 
 
1016
 
1017
  if output_hidden_states:
1018
  all_hidden_states = all_hidden_states + (hidden_states,)
 
1012
  all_self_attentions = all_self_attentions + (layer_ret[2 if use_cache else 1],)
1013
 
1014
  # Final layer norm.
1015
+ if self.gradient_checkpointing and self.training:
1016
+ hidden_states = torch.utils.checkpoint.checkpoint(self.final_layernorm, hidden_states)
1017
+ else:
1018
+ hidden_states = self.final_layernorm(hidden_states)
1019
 
1020
  if output_hidden_states:
1021
  all_hidden_states = all_hidden_states + (hidden_states,)