llmixer committed on
Commit
1f3dd72
1 Parent(s): dbea74d

Normalize hidden state after adding control vectors to preserve L2 norm

Browse files
Files changed (1) hide show
  1. exl2_wrapper.py +2 -0
exl2_wrapper.py CHANGED
@@ -78,7 +78,9 @@ class ExLlamaV2ModuleWrapper:
78
  def wrapped_forward(self, *args, **kwargs):
79
  x = self.module.forward(*args, **kwargs)
80
  try:
 
81
  x += self.control_vector[self.module.layer_idx].clone().to(x.device)
 
82
  except IndexError:
83
  pass
84
  return x
 
78
  def wrapped_forward(self, *args, **kwargs):
79
  x = self.module.forward(*args, **kwargs)
80
  try:
81
+ prev_norm = torch.norm(x, p=2)
82
  x += self.control_vector[self.module.layer_idx].clone().to(x.device)
83
+ x *= prev_norm / torch.norm(x, p=2)
84
  except IndexError:
85
  pass
86
  return x