kaiokendev committed
Commit 67bf26a · 1 Parent(s): 65084ac

Upload lora

Browse files:
- README.md +31 -0
- adapter_config.json +19 -0
- adapter_model.bin +3 -0
- llama_rope_scaled_monkey_patch.py +63 -0
README.md
CHANGED
---
license: mit
---

### SuperHOT Prototype 2 w/ 4-8K Context

This is a second prototype of SuperHOT, this time with 4K context and no RLHF. In my testing it can go all the way to 6K without breaking down, and I made the change with the intention of reaching 8K, so I assume it will go to 8K even though I only trained on 4K sequences.

In order to use the 8K context, you will need to apply the monkeypatch I have added in this repo -- without it, it will not work. The patch is very simple, and you can make the changes yourself (a loading sketch follows the list):
- Increase `max_position_embeddings` to 8192 to stretch the sinusoidal encoding
- Stretch the frequency steps by a scale of `0.25`
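
For reference, a minimal loading sketch, assuming the `llama_rope_scaled_monkey_patch.py` from this repo is importable and that you load the base model with Hugging Face `transformers` and `peft` (the model and adapter paths are placeholders):

```python
from llama_rope_scaled_monkey_patch import replace_llama_rope_with_scaled_rope
from transformers import LlamaForCausalLM, LlamaTokenizer
from peft import PeftModel

# Apply the patch before the model is built so the scaled rotary embedding
# class is the one that gets instantiated.
replace_llama_rope_with_scaled_rope()

tokenizer = LlamaTokenizer.from_pretrained("path/to/llama-base")  # placeholder path
base = LlamaForCausalLM.from_pretrained("path/to/llama-base")     # placeholder path
model = PeftModel.from_pretrained(base, "path/to/this-lora")      # placeholder path
```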

The intuition is to keep the model within the learned positions of the pre-trained model, since the model may be overfit on the token-position relationship (not my idea -- it is [Ofir Press's](https://ofir.io/)). By interpolating the encodings, we remain within the bounds of the pre-trained model and work with the overfitting rather than against it. The monkeypatch will work on the pre-trained model without fine-tuning, but you will want to fine-tune, as the results are not that good without it.
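
To make the interpolation concrete, here is a small sketch of the arithmetic (a head dimension of 128 is assumed here): positions are multiplied by 0.25 before the rotary angles are computed, so position 8191 is encoded like position 2047.75, still inside the 0-2047 range seen during pre-training.

```python
import torch

# Sketch of the scaled rotary angles only, not the full patch; head dim of 128 assumed.
dim, base, scale = 128, 10000, 0.25
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
t = torch.arange(8192).float() * scale         # 0.00, 0.25, ..., 2047.75
angles = torch.einsum("i,j->ij", t, inv_freq)  # same einsum as the monkeypatch
```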

It can probably be even better than this with a few other modifications I am testing (swapping softmax for ReLU, increasing the head dimension).

In my testing, I tried random positional encoding, but I was not able to replicate the results of [Jianlin Su](https://kexue.fm/archives/9444), so maybe I did it incorrectly. I also tried shifted positions, log n scaling, log-sigmoid, and increasing the head dimension, though this dilated RoPE (DoPE :)) is the only one that worked for me consistently -- note that these are all based on fine-tuning, since the goal is to extend the context of the pre-trained model. Pre-training will paint a different picture.

I trained the LoRA with the following configuration (a peft config sketch follows the list):
- 1200 samples (~400 samples over 2048 sequence length)
- learning rate of 3e-4
- 3 epochs
- The exported modules are:
  - q_proj
  - k_proj
  - v_proj
  - o_proj
  - all bias
- Rank = 2
- Alpha = 8
- no dropout
- weight decay of 0.1
- AdamW with beta1 of 0.9, beta2 of 0.99, and epsilon of 1e-5
- Trained on the 4-bit base model
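
For reference, a `peft` `LoraConfig` sketch that mirrors the hyperparameters above and the `adapter_config.json` shipped in this repo (the data pipeline and optimizer settings are not shown):

```python
from peft import LoraConfig

# Mirrors the configuration listed above / adapter_config.json.
lora_config = LoraConfig(
    r=2,
    lora_alpha=8,
    lora_dropout=0.0,
    bias="all",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    task_type="CAUSAL_LM",
)
```
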
adapter_config.json
ADDED
{
    "base_model_name_or_path": "",
    "bias": "all",
    "fan_in_fan_out": false,
    "inference_mode": true,
    "init_lora_weights": true,
    "lora_alpha": 8,
    "lora_dropout": 0,
    "modules_to_save": null,
    "peft_type": "LORA",
    "r": 2,
    "target_modules": [
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj"
    ],
    "task_type": "CAUSAL_LM"
}
adapter_model.bin
ADDED
version https://git-lfs.github.com/spec/v1
oid sha256:76133dc631ac8dc28341c45f8f469cc603174cb1d16c728f65b33778f8f497e4
size 17579562
llama_rope_scaled_monkey_patch.py
ADDED
import torch
import transformers
import transformers.models.llama.modeling_llama


class ScaledRotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
        self.register_buffer("inv_freq", inv_freq)

        max_position_embeddings = 8192

        # Build here to make `torch.jit.trace` work.
        self.max_seq_len_cached = max_position_embeddings
        t = torch.arange(
            self.max_seq_len_cached,
            device=self.inv_freq.device,
            dtype=self.inv_freq.dtype,
        )

        self.scale = 1 / 4
        t *= self.scale

        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer(
            "cos_cached", emb.cos()[None, None, :, :], persistent=False
        )
        self.register_buffer(
            "sin_cached", emb.sin()[None, None, :, :], persistent=False
        )

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
        if seq_len > self.max_seq_len_cached:
            self.max_seq_len_cached = seq_len
            t = torch.arange(
                self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype
            )
            # Apply the same position scaling as in `__init__` so the rebuilt
            # cache stays consistent with the interpolated encoding.
            t *= self.scale
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            # Different from paper, but it uses a different permutation in order to obtain the same calculation
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            self.register_buffer(
                "cos_cached", emb.cos()[None, None, :, :], persistent=False
            )
            self.register_buffer(
                "sin_cached", emb.sin()[None, None, :, :], persistent=False
            )
        return (
            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        )


def replace_llama_rope_with_scaled_rope():
    transformers.models.llama.modeling_llama.LlamaRotaryEmbedding = (
        ScaledRotaryEmbedding
    )