---
library_name: transformers
license: apache-2.0
language:
- en
datasets:
- Skylion007/openwebtext
metrics:
- perplexity
- mauve
base_model:
- kuleshov-group/mdlm-owt
---

## Using SDTT
- We release 3 groups of models:
  1. The **baseline students**, distilled with the `kld`, `mse` and `tvd` objectives from a teacher trained for 1M steps.
  2. The **students from the scaling experiments**, with sizes `sm`, `md` and `large`, distilled from teachers trained for 400k steps.
  3. The **teachers from the scaling experiments**, with sizes `sm`, `md` and `large`, before any distillation.
- To load these models, first install our code:
```bash
git clone https://github.com/jdeschena/sdtt.git
cd sdtt
pip install -r requirements.txt
pip install flash-attn
pip install --pre torchdata --index-url https://download.pytorch.org/whl/nightly/cpu
pip install -e .
```
- You can then import our models, sample from them, and evaluate them:

#### Load the baseline students
```python
from sdtt import load_small_student

student = load_small_student(loss="kld", round=7)  # load the kld student after the last distillation round
student = load_small_student(loss="mse", round=2)  # load the mse student after the second distillation round
student = load_small_student(loss="tvd", round=1)  # load the tvd student after the first distillation round
```

#### Load the students from the scaling experiment
```python
from sdtt import load_scaling_student

student = load_scaling_student(size="sm", round=7)  # load the small student after the last distillation round
student = load_scaling_student(size="md", round=1)  # load the medium student after the first distillation round
student = load_scaling_student(size="large", round=3)  # load the large student after the third distillation round
```

#### Load the teachers from the scaling experiment
```python
from sdtt import load_scaling_teacher

teacher = load_scaling_teacher(size="sm")  # load the small teacher
teacher = load_scaling_teacher(size="md")  # load the medium teacher
teacher = load_scaling_teacher(size="large")  # load the large teacher
```

#### Sample from the pretrained models
```python
from sdtt import load_small_student, load_scaling_student, load_scaling_teacher
import torch

model = load_small_student(loss="kld", round=7)  # load a model, see above
model.cuda()  # put the model on the GPU

# Unconditional generation
tokens = model.sample(
    n_samples=8,
    num_steps=256,
    seq_len=1024,
    verbose=True,
)
# Detokenize
uncond_text = model.tokenizer.batch_decode(tokens)

# Conditional generation, based on a prompt
# Prepare the prompt
prompt = "Today is a great day. The sun is shining,"
prompt_tokens = model.tokenizer(prompt)["input_ids"]
prompt_tokens.insert(0, model.tokenizer.bos_token_id)
prompt_tokens = torch.tensor(prompt_tokens, device="cuda")
prompt_len = len(prompt_tokens)

def project_fn(x):
    # Clamp the first prompt_len tokens of every sample to the prompt
    x[:, :prompt_len] = prompt_tokens
    return x  # Don't forget to return

tokens = model.sample(
    n_samples=8,
    num_steps=256,
    seq_len=1024,
    verbose=True,
    project_fn=project_fn,
)

cond_text = model.tokenizer.batch_decode(tokens)
```
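
The model card lists perplexity and MAUVE as metrics, but the snippet above only shows sampling. Below is a minimal sketch of one way to score the generated samples with an off-the-shelf causal language model via `transformers`. The choice of `gpt2-large` as the scorer and the `generative_perplexity` helper are illustrative assumptions, not necessarily the evaluation protocol used for SDTT; see the [SDTT repository](https://github.com/jdeschena/sdtt) for the official evaluation code.

```python
# Sketch: generative perplexity of sampled text under an off-the-shelf scorer.
# NOTE: gpt2-large and this helper are illustrative assumptions, not the
# official SDTT evaluation pipeline.
import math
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

scorer_tok = AutoTokenizer.from_pretrained("gpt2-large")
scorer = AutoModelForCausalLM.from_pretrained("gpt2-large").cuda().eval()

@torch.no_grad()
def generative_perplexity(texts, max_length=1024):
    # Exponentiated average per-token negative log-likelihood under the scorer.
    total_nll, total_tokens = 0.0, 0
    for text in texts:
        ids = scorer_tok(text, return_tensors="pt", truncation=True, max_length=max_length)["input_ids"].cuda()
        loss = scorer(ids, labels=ids).loss  # mean NLL; labels are shifted internally
        n_pred = ids.shape[1] - 1            # number of predicted tokens
        total_nll += loss.item() * n_pred
        total_tokens += n_pred
    return math.exp(total_nll / total_tokens)

print(generative_perplexity(uncond_text))  # uncond_text from the sampling example above
```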

For more details, please see our GitHub repository: [SDTT](https://github.com/jdeschena/sdtt).

## Model Details
Our checkpoints are distilled from [MDLM](https://github.com/kuleshov-group/mdlm) checkpoints. We release small (169M), medium (424M) and large (863M) checkpoints.
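
If you want to double-check which checkpoint you loaded, a parameter count is a quick sanity check. The sketch below assumes the object returned by the loaders above behaves like a standard `torch.nn.Module` (an illustrative assumption, not a documented part of the `sdtt` API), and the exact count may not match the rounded sizes above.

```python
# Sketch: rough parameter count of a loaded checkpoint.
# Assumes the loaded object exposes torch.nn.Module parameters (illustrative assumption).
from sdtt import load_scaling_teacher

model = load_scaling_teacher(size="md")
n_params = sum(p.numel() for p in model.parameters())
print(f"{n_params / 1e6:.0f}M parameters")  # should be in the ballpark of the 424M medium size
```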

## Citation

Please cite our work using the BibTeX entry below:

**BibTeX:**

```
@article{deschenaux2024beyond,
  title={Beyond Autoregression: Fast LLMs via Self-Distillation Through Time},
  author={Deschenaux, Justin and Gulcehre, Caglar},
  journal={arXiv preprint arXiv:TODO},
  year={2024}
}
```

## Contact
Justin Deschenaux ([email protected])