kaiokendev committed
Commit 0cad621 · 1 Parent(s): 67bf26a

Fix bug

Fix bug, t needs to be scaled if input > 8192
llama_rope_scaled_monkey_patch.py CHANGED

@@ -42,6 +42,7 @@ class ScaledRotaryEmbedding(torch.nn.Module):
         t = torch.arange(
             self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype
         )
+        t *= self.scale
         freqs = torch.einsum("i,j->ij", t, self.inv_freq)
         # Different from paper, but it uses a different permutation in order to obtain the same calculation
         emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
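For context, below is a minimal, self-contained sketch of a scaled rotary embedding with this fix applied: the position indices `t` are multiplied by `self.scale` before the frequency table is built, so indices past the pretrained context (e.g. inputs longer than 8192) are interpolated back into the position range the model was trained on. The constructor signature, default values, and usage shown here are assumptions for illustration, not the exact monkey patch from the repo.

```python
import torch


class ScaledRotaryEmbedding(torch.nn.Module):
    # Sketch of rotary embedding with linear position interpolation.
    # A `scale` < 1 compresses position indices so long sequences reuse
    # the position range seen during pretraining.
    def __init__(self, dim, max_position_embeddings=2048, base=10000, scale=0.25):
        super().__init__()
        self.scale = scale  # assumed default; pick so max_input * scale <= trained context
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        self.max_seq_len_cached = max_position_embeddings

    def forward(self, x, seq_len):
        if seq_len > self.max_seq_len_cached:
            self.max_seq_len_cached = seq_len
        t = torch.arange(
            self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype
        )
        t *= self.scale  # the fix: scale positions before computing frequencies
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
        return emb.cos()[None, None, :, :], emb.sin()[None, None, :, :]


# Hypothetical usage: with scale=0.25, positions 0..8191 map into [0, 2048)
rope = ScaledRotaryEmbedding(dim=128, scale=0.25)
cos, sin = rope(torch.zeros(1, 1, 8192, 128), seq_len=8192)
print(cos.shape)  # torch.Size([1, 1, 8192, 128])
```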