Runtime autograd error due to inplace operations
#4 by xianbin
Error
While fine-tuning the Gemma2 models with TRL, the following error was encountered:
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [CUDABFloat16Type [1, 308, 256000]], which is output 0 of TanhBackward0, is at version 1; expected version 0 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
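As the hint in the traceback suggests, anomaly detection can be switched on around the failing step to get a second traceback that points at the forward op whose output was later modified in place. A minimal sketch; trainer here stands in for whatever TRL trainer object the run uses:

import torch

# Anomaly detection is slow, so enable it only while debugging.
torch.autograd.set_detect_anomaly(True)

trainer.train()  # hypothetical TRL trainer; the RuntimeError now also prints
                 # the forward trace of the op that produced the modified tensor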
Cause
This was traced to in-place operations in the Gemma2 model definition that modify a variable needed for gradient computation.
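The failure mode is easy to reproduce outside the model; the toy below (not Gemma2 code) trips the same TanhBackward0 version check, because tanh saves its own output for the backward pass and a later in-place op bumps that tensor's version counter:

import torch

x = torch.randn(4, requires_grad=True)
y = torch.tanh(x)      # TanhBackward0 saves y (its own output) for the backward pass
y.mul_(2.0)            # in-place op bumps y's version counter from 0 to 1
y.sum().backward()     # RuntimeError: ... output 0 of TanhBackward0 is at version 1; expected version 0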
Possible solution
The following lines should be modified in diff_gemma2.py (and, by extension, modeling_gemma2.py).
Lines 163-165:
attention_mask *= torch.tril(
    torch.ones_like(attention_mask),
    diagonal=(self.sliding_window - cache_position[-1]),
)
Replacement:
attention_mask = torch.mul(
    attention_mask,
    torch.tril(
        torch.ones_like(attention_mask),
        diagonal=(self.sliding_window - cache_position[-1]),
    ),
)
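Unlike *=, torch.mul allocates a new tensor instead of bumping the version counter of the existing one, which is exactly what autograd's check looks at. A standalone way to see the difference (illustrative only, not Gemma2 code):

import torch

mask = torch.ones(3, 3)
mask *= torch.tril(torch.ones_like(mask))                  # in-place multiply
print(mask._version)                                       # 1: the version counter was bumped

mask = torch.ones(3, 3)
out = torch.mul(mask, torch.tril(torch.ones_like(mask)))   # out-of-place multiply
print(mask._version, out._version)                         # 0 0: fresh result, original untouched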
Lines 119-121:
attn_weights.div_(self.config.attn_logit_softcapping)
attn_weights = torch.tanh(attn_weights)
attn_weights.mul_(self.config.attn_logit_softcapping)
Replacement:
attn_weights = torch.div(attn_weights, self.config.attn_logit_softcapping)
attn_weights = self.attn_weights_tanh(attn_weights)
attn_weights = torch.mul(attn_weights, self.config.attn_logit_softcapping)
Place this in the __init__ of Gemma2Attention:
self.attn_weights_tanh = nn.Tanh()
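For context, the new attribute would sit alongside the existing ones in Gemma2Attention's __init__; a rough sketch (the surrounding projections are abbreviated):

from torch import nn

class Gemma2Attention(nn.Module):
    def __init__(self, config, layer_idx=None):
        super().__init__()
        self.config = config
        # ... existing q/k/v/o projections and rotary embedding setup ...
        self.attn_weights_tanh = nn.Tanh()  # out-of-place tanh for attention logit softcapping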
Lines 202-204:
logits.div_(self.config.final_logit_softcapping)
logits = torch.tanh(logits)
logits.mul_(self.config.final_logit_softcapping)
Replacement:
logits = torch.div(logits, self.config.final_logit_softcapping)
logits = self.final_logit_tanh(logits)
logits = torch.mul(logits, self.config.final_logit_softcapping)
Place this in the __init__ of Gemma2ForCausalLM:
self.final_logit_tanh = nn.Tanh()
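And likewise for Gemma2ForCausalLM; a sketch of where the module would go inside modeling_gemma2.py (the real constructor also wires up the rest of the model):

from torch import nn

class Gemma2ForCausalLM(Gemma2PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.model = Gemma2Model(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.final_logit_tanh = nn.Tanh()  # out-of-place tanh for final logit softcapping
        self.post_init()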
Yes, will fix this in a bit!