---
library_name: transformers
license: apache-2.0
language:
- en
datasets:
- Skylion007/openwebtext
metrics:
- perplexity
- mauve
base_model:
- kuleshov-group/mdlm-owt
---

## Using SDTT
- We release three groups of models:
  1. The **baseline students**, distilled with the `kld`, `mse` and `tvd` objectives from a teacher trained for 1M steps.
  2. The **students from the scaling experiments**, with sizes `sm`, `md` and `large`, distilled from teachers trained for 400k steps.
  3. The **teachers from the scaling experiments**, with sizes `sm`, `md` and `large`, before any distillation.
- To load these models, first install our code:
```bash
git clone https://github.com/jdeschena/sdtt.git
cd sdtt
pip install -r requirements.txt
pip install flash-attn
pip install --pre torchdata --index-url https://download.pytorch.org/whl/nightly/cpu
pip install -e .
```
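
After installing, you can quickly check that the package is importable by loading the entry points used in the examples below (a minimal sanity check; no GPU is needed for the import itself):
```python
# Minimal sanity check: these are the loaders used throughout this README.
from sdtt import load_small_student, load_scaling_student, load_scaling_teacher

print("sdtt is installed and the loaders are importable")
```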

- You can then import our models, sample from them, and evaluate them:

#### Load the baseline students
```python
from sdtt import load_small_student

student = load_small_student(loss="kld", round=7)  # load the kld student after the last distillation round
student = load_small_student(loss="mse", round=2)  # load the mse student after the second distillation round
student = load_small_student(loss="tvd", round=1)  # load the tvd student after the first distillation round
```

#### Load the students from the scaling experiment
```python
from sdtt import load_scaling_student

student = load_scaling_student(size="sm", round=7)  # load the small student after the last distillation round
student = load_scaling_student(size="md", round=1)  # load the medium student after the first distillation round
student = load_scaling_student(size="large", round=3)  # load the large student after the third distillation round
```

#### Load the teachers from the scaling experiment
```python
from sdtt import load_scaling_teacher

teacher = load_scaling_teacher(size="sm")  # load the small teacher
teacher = load_scaling_teacher(size="md")  # load the medium teacher
teacher = load_scaling_teacher(size="large")  # load the large teacher
```

#### Sample from the pretrained models
```python
from sdtt import load_small_student, load_scaling_student, load_scaling_teacher
import torch

model = load_small_student(loss="kld", round=7)  # load a model, see above
model.cuda()  # move the model to the GPU

# Unconditional generation
tokens = model.sample(
    n_samples=8,
    num_steps=256,
    seq_len=1024,
    verbose=True,
)
# Detokenize
uncond_text = model.tokenizer.batch_decode(tokens)

# Conditional generation, based on a prompt
# Prepare the prompt
prompt = "Today is a great day. The sun is shining,"
prompt_tokens = model.tokenizer(prompt)["input_ids"]
prompt_tokens.insert(0, model.tokenizer.bos_token_id)
prompt_tokens = torch.tensor(prompt_tokens, device="cuda")
prompt_len = len(prompt_tokens)

def project_fn(x):
    # Clamp the first prompt_len tokens of every sample to the prompt
    x[:, :prompt_len] = prompt_tokens
    return x  # don't forget to return x

tokens = model.sample(
    n_samples=8,
    num_steps=256,
    seq_len=1024,
    verbose=True,
    project_fn=project_fn,
)

cond_text = model.tokenizer.batch_decode(tokens)
```
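
For the full evaluation setup (perplexity and MAUVE), see the GitHub repository linked below. As a self-contained illustration, the sketch below scores the `uncond_text` decoded above with an off-the-shelf GPT-2 Large from `transformers` (generative perplexity); it is a minimal sketch, not the repository's evaluation code:
```python
# Score the generated samples with an off-the-shelf autoregressive LM (GPT-2 Large).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

eval_tok = AutoTokenizer.from_pretrained("gpt2-large")
eval_lm = AutoModelForCausalLM.from_pretrained("gpt2-large").cuda().eval()

nlls = []
for text in uncond_text:  # the texts decoded from the samples above
    ids = eval_tok(text, return_tensors="pt", truncation=True, max_length=1024)["input_ids"].cuda()
    with torch.no_grad():
        out = eval_lm(ids, labels=ids)  # mean token cross-entropy for this sample
    nlls.append(out.loss)

# Average the per-sample losses and exponentiate to get a generative perplexity estimate.
print("Generative perplexity:", torch.exp(torch.stack(nlls).mean()).item())
```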

For more details, please see our GitHub repository: [SDTT](https://github.com/jdeschena/sdtt)

## Model Details
Our checkpoints are distilled from [MDLM](https://github.com/kuleshov-group/mdlm) checkpoints. We release small (169M), medium (424M) and large (863M) checkpoints.
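
To confirm which size you loaded, a quick parameter count works (a minimal sketch, assuming the loaders return a standard `torch.nn.Module`):
```python
from sdtt import load_scaling_student

student = load_scaling_student(size="md", round=1)
# Count the parameters; this should be roughly 424M for the medium checkpoint.
n_params = sum(p.numel() for p in student.parameters())
print(f"{n_params / 1e6:.0f}M parameters")
```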

## Citation
Please cite our work using the BibTeX entry below:

**BibTeX:**

```
@article{deschenaux2024beyond,
  title={Beyond Autoregression: Fast LLMs via Self-Distillation Through Time},
  author={Deschenaux, Justin and Gulcehre, Caglar},
  journal={arXiv preprint arXiv:TODO},
  year={2024}
}
```

## Contact
Justin Deschenaux ([email protected])