Normalize hidden state after adding control vectors to preserve L2 norm
exl2_wrapper.py  (+2 -0)
@@ -78,7 +78,9 @@ class ExLlamaV2ModuleWrapper:
     def wrapped_forward(self, *args, **kwargs):
         x = self.module.forward(*args, **kwargs)
         try:
+            prev_norm = torch.norm(x, p=2)
             x += self.control_vector[self.module.layer_idx].clone().to(x.device)
+            x *= prev_norm / torch.norm(x, p=2)
         except IndexError:
             pass
         return x
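For illustration, here is a minimal standalone sketch of the same norm-preserving rescaling. The tensor shapes and the `control` vector are hypothetical; in exl2_wrapper.py the hidden state comes from self.module.forward(...) and the control vector from self.control_vector[self.module.layer_idx].

import torch

# Hypothetical shapes for illustration only.
x = torch.randn(1, 8, 4096)        # hidden state returned by the wrapped module
control = 0.5 * torch.randn(4096)  # per-layer control (steering) vector

prev_norm = torch.norm(x, p=2)            # L2 norm before steering
x = x + control                           # add the control vector (broadcasts over the last dim)
x = x * (prev_norm / torch.norm(x, p=2))  # rescale so the overall L2 norm is unchanged

# The steered hidden state keeps its original L2 norm (up to floating-point error).
assert torch.allclose(torch.norm(x, p=2), prev_norm)

Per the commit title, the intent is that adding the control vector changes the direction of the hidden state while the rescaling preserves its L2 norm.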