---
license: apache-2.0
language:
- en
datasets:
- Skylion007/openwebtext
metrics:
- perplexity
- mauve
---

# Self-Distillation Through Time (SDTT)

SDTT is a distillation method for diffusion language models. Recent diffusion language models such as [SEDD](https://huggingface.co/louaaron/sedd-small) and [MDLM](https://huggingface.co/kuleshov-group/mdlm-owt) achieve strong results. However, sampling from them is slow because their non-causal architecture rules out KV caching. We therefore devise a novel distillation method that reduces the inference latency of discrete diffusion models. After distillation, we can sample up to 8x faster than GPT-2, even though GPT-2 uses KV caching.

Find more details below and on [our GitHub repo](https://github.com/jdeschena/sdtt).

## Using SDTT

- We released 3 groups of models:
  1. The **baseline students**, distilled with the `kld`, `mse` and `tvd` objectives from a model trained for 1M steps.
  2. The **students from the scaling experiments**, with sizes `sm`, `md` and `large`, distilled from models trained for 400k steps.
  3. The **teachers from the scaling experiments**, with sizes `sm`, `md` and `large`, before any distillation.
- To load those models, first install our code:

```bash
git clone https://github.com/jdeschena/sdtt.git
cd sdtt
pip install -r requirements.txt
pip install flash-attn
pip install --pre torchdata --index-url https://download.pytorch.org/whl/nightly/cpu
pip install -e .
```
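
The `kld`, `mse` and `tvd` labels of the baseline students name the divergence that pulls the student's predictions toward the teacher's. The sketch below illustrates these three divergences on two toy categorical distributions over a 3-token vocabulary. It is not the loss code from the sdtt repository. The helper names are hypothetical. Two further choices are assumptions: the KL direction (teacher to student), and applying MSE to probabilities rather than logits. The real objectives operate on batched tensors of per-token distributions.

```python
import math
from typing import Sequence

# Hypothetical helpers (not from the sdtt repository): each one compares the
# teacher's and the student's categorical distributions for a single token.


def kld(teacher: Sequence[float], student: Sequence[float], eps: float = 1e-12) -> float:
    """KL(teacher || student); this direction is an assumption."""
    # Terms with teacher probability 0 contribute 0 by convention (0 * log 0 = 0).
    return sum(p * math.log(p / max(q, eps)) for p, q in zip(teacher, student) if p > 0)


def mse(teacher: Sequence[float], student: Sequence[float]) -> float:
    """Mean squared error, assumed here to act on probabilities rather than logits."""
    return sum((p - q) ** 2 for p, q in zip(teacher, student)) / len(teacher)


def tvd(teacher: Sequence[float], student: Sequence[float]) -> float:
    """Total variation distance: half the L1 distance between the distributions."""
    return 0.5 * sum(abs(p - q) for p, q in zip(teacher, student))


# Toy distributions over a 3-token vocabulary (illustrative values only).
teacher = [0.7, 0.2, 0.1]
student = [0.5, 0.3, 0.2]

for name, fn in [("kld", kld), ("mse", mse), ("tvd", tvd)]:
    print(f"{name}: {fn(teacher, student):.4f}")
```

TVD always lies between 0 and 1. KL divergence is unbounded when the student assigns near-zero probability to a token the teacher favors, which is why the sketch clamps the student probability with `eps`.
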

- You can then import our models, sample from them and evaluate them:

#### Load the baseline students

```python
from sdtt import load_small_student

student = load_small_student(loss="kld", round=7)  # load the kld student after the last distillation round
student = load_small_student(loss="mse", round=2)  # load the mse student after the second distillation round
student = load_small_student(loss="tvd", round=1)  # load the tvd student after the first distillation round
```

#### Load the students from the scaling experiment

```python
from sdtt import load_scaling_student

student = load_scaling_student(size="sm", round=7)  # load small student after the last distillation round
student = load_scaling_student(size="md", round=1)  # load medium student after the first distillation round
student = load_scaling_student(size="large", round=3)  # load large student after the third distillation round
```

#### Load the teachers from the scaling experiment

```python
from sdtt import load_scaling_teacher

teacher = load_scaling_teacher(size="sm")  # load small teacher
teacher = load_scaling_teacher(size="md")  # load medium teacher
teacher = load_scaling_teacher(size="large")  # load large teacher
```

#### Sample from the pretrained models

```python
from sdtt import load_small_student, load_scaling_student, load_scaling_teacher
import torch

model = load_small_student(loss="kld", round=7)  # load model, see above
model.cuda()  # put model on GPU

# Unconditional generation
tokens = model.sample(
    n_samples=8,
    num_steps=256,
    seq_len=1024,
    verbose=True,
)
# Detokenize
uncond_text = model.tokenizer.batch_decode(tokens)

# Conditional generation, based on a prompt
# Prepare a prompt, prefixed with the BOS token
prompt = "Today is a great day. The sun is shining,"
prompt_tokens = model.tokenizer(prompt)["input_ids"]
prompt_tokens.insert(0, model.tokenizer.bos_token_id)
prompt_tokens = torch.tensor(prompt_tokens, device="cuda")
prompt_len = len(prompt_tokens)

def project_fn(x):
    # Overwrite the first prompt_len tokens of every sample with the prompt
    x[:, :prompt_len] = prompt_tokens
    return x  # Don't forget to return the projected tokens

tokens = model.sample(
    n_samples=8,
    num_steps=256,
    seq_len=1024,
    verbose=True,
    project_fn=project_fn,
)
cond_text = model.tokenizer.batch_decode(tokens)
```

For more details, please see our GitHub repository: [SDTT](https://github.com/jdeschena/sdtt)

## Model Details

Our checkpoints are distilled from [MDLM](https://github.com/kuleshov-group/mdlm) checkpoints. We release small (169M), medium (424M) and large (863M) checkpoints.

## Citation

Please cite our work using the BibTeX below:

**BibTeX:**

```
@article{deschenaux2024autoregressionfastllmsselfdistillation,
  title={Beyond Autoregression: Fast LLMs via Self-Distillation Through Time},
  author={Deschenaux, Justin and Gulcehre, Caglar},
  eprint={2410.21035},
  archivePrefix={arXiv},
  primaryClass={cs.LG},
  url={https://arxiv.org/abs/2410.21035},
}
```

## Contact

Justin Deschenaux (justin.deschenaux@epfl.ch)