metadata
license: apache-2.0
language:
  - en
datasets:
  - Skylion007/openwebtext
metrics:
  - perplexity
  - mauve

Using SDTT

  • We released three groups of models:
    1. The baseline students, distilled with the KLD, MSE and TVD objectives from a model trained for 1M steps.
    2. The students from the scaling experiments, in sizes sm, md and large, distilled from models trained for 400k steps.
    3. The teachers from the scaling experiments, in sizes sm, md and large, before any distillation.
  • To load those models, first install our code:
git clone https://github.com/jdeschena/sdtt.git
cd sdtt
pip install -r requirements.txt
pip install flash-attn
pip install --pre torchdata --index-url https://download.pytorch.org/whl/nightly/cpu
pip install -e .
  • You can then import our models, sample from them and evaluate them:

Load the baseline students

from sdtt import load_small_student
student = load_small_student(loss="kld", round=7)  # load the kld student after the last distillation round
student = load_small_student(loss="mse", round=2)  # load the mse student after the second distillation round
student = load_small_student(loss="tvd", round=1)  # load the tvd student after the first distillation round

Load the students from the scaling experiment

from sdtt import load_scaling_student
student = load_scaling_student(size="sm", round=7)  # load small student after the last distillation round
student = load_scaling_student(size="md", round=1)   # load medium student after the first distillation round
student = load_scaling_student(size="large", round=3)  # load large student after the third distillation round

Load the teachers from the scaling experiment

from sdtt import load_scaling_teacher
teacher = load_scaling_teacher(size="sm")  # load small teacher
teacher = load_scaling_teacher(size="md")  # load medium teacher
teacher = load_scaling_teacher(size="large")  # load large teacher

Sample from the pretrained models

from sdtt import load_small_student, load_scaling_student, load_scaling_teacher
import torch

model = load_small_student(loss="kld", round=7)  # load model, see above
model.cuda()  # put model on gpu

# Unconditional generation
tokens = model.sample(
    n_samples=8,
    num_steps=256,
    seq_len=1024,
    verbose=True,
)
# Detokenize
uncond_text = model.tokenizer.batch_decode(tokens)

# Conditional generation, based on a prompt
# Prepare a prompt
prompt = "Today is a great day. The sun is shining,"
prompt_tokens = model.tokenizer(prompt)["input_ids"]
prompt_tokens.insert(0, model.tokenizer.bos_token_id)
prompt_tokens = torch.tensor(prompt_tokens, device="cuda")
prompt_len = len(prompt_tokens)

def project_fn(x):
    # Overwrite the first prompt_len positions (BOS + prompt) of every sample with the prompt
    x[:, :prompt_len] = prompt_tokens
    return x  # Don't forget to return

tokens = model.sample(
    n_samples=8,
    num_steps=256,
    seq_len=1024,
    verbose=True,
    project_fn=project_fn
)

cond_text = model.tokenizer.batch_decode(tokens)
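The real sampler lives in the SDTT repository. As a rough intuition for what `num_steps` and the `project_fn` hook do, here is a toy sketch in plain Python with no model at all. The `MASK` sentinel, the `toy_sample` function, the random "denoiser" and the reveal schedule (an equal share of the still-masked positions per step) are hypothetical simplifications, not SDTT's implementation. The sketch only illustrates that a projection applied after every step keeps the prompt positions fixed while the rest of the sequence is filled in.

```python
import random

MASK = -1  # hypothetical sentinel marking a masked (not yet generated) position


def toy_sample(n_samples, num_steps, seq_len, vocab_size=50, project_fn=None, seed=0):
    """Toy masked-diffusion-style sampler; NOT the SDTT implementation.

    Starts from a fully masked batch and, at each of `num_steps` steps,
    fills in an equal share of the still-masked positions with random
    tokens (a stand-in for the denoiser's predictions). If `project_fn`
    is given, it is applied before the first step and after every step,
    so fixed positions such as a prompt are always overwritten.
    """
    rng = random.Random(seed)
    x = [[MASK] * seq_len for _ in range(n_samples)]
    if project_fn is not None:
        x = project_fn(x)
    for step in range(num_steps):
        steps_left = num_steps - step
        for row in x:
            masked = [i for i, tok in enumerate(row) if tok == MASK]
            # ceil(len(masked) / steps_left): nothing is left masked after the last step
            k = -(-len(masked) // steps_left)
            for i in rng.sample(masked, k):
                row[i] = rng.randrange(vocab_size)
        if project_fn is not None:
            x = project_fn(x)
    return x


# Usage: condition on a hypothetical 3-token prompt, mirroring project_fn above
prompt_tokens = [7, 8, 9]
prompt_len = len(prompt_tokens)


def project_fn(x):
    # Overwrite the first prompt_len positions of every sample with the prompt
    for row in x:
        row[:prompt_len] = prompt_tokens
    return x


tokens = toy_sample(n_samples=2, num_steps=4, seq_len=10, project_fn=project_fn)
uncond = toy_sample(n_samples=2, num_steps=4, seq_len=10)
print(tokens)
```

In this toy, setting `num_steps=seq_len` reveals one token per sample per step, while fewer steps reveal more tokens at once.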

For more details, please see our GitHub repository: SDTT (https://github.com/jdeschena/sdtt).

Model Details

Our checkpoints are distilled from MDLM (Masked Diffusion Language Model) checkpoints. We release small (169M parameters), medium (424M) and large (863M) checkpoints.
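The baseline students were distilled with KLD, MSE and TVD objectives. As a hedged illustration only, the sketch below computes those three divergences between a teacher's and a student's predicted distribution over the vocabulary at a single position. The function names and the plain-list, per-position formulation are assumptions for illustration; the exact targets, reductions and weighting used in SDTT (for example, whether MSE is taken on probabilities or on logits) may differ.

```python
import math


def kld(teacher, student, eps=1e-12):
    """KL(teacher || student); eps guards against log(0) and division by zero."""
    return sum(p * math.log((p + eps) / (q + eps)) for p, q in zip(teacher, student) if p > 0)


def mse(teacher, student):
    """Mean squared error between the two probability vectors."""
    return sum((p - q) ** 2 for p, q in zip(teacher, student)) / len(teacher)


def tvd(teacher, student):
    """Total variation distance: half the L1 distance between the distributions."""
    return 0.5 * sum(abs(p - q) for p, q in zip(teacher, student))


# Toy distributions over a 3-token vocabulary (hypothetical values)
teacher = [0.7, 0.2, 0.1]
student = [0.5, 0.3, 0.2]
for name, fn in [("kld", kld), ("mse", mse), ("tvd", tvd)]:
    print(name, round(fn(teacher, student), 4))
```

All three are zero when the student matches the teacher exactly; they differ in how strongly they penalize mismatches on low-probability tokens.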

Citation

Please cite our work using the BibTeX entry below:

BibTeX:

@article{deschenaux2024beyond,
  title={Beyond Autoregression: Fast LLMs via Self-Distillation Through Time},
  author={Deschenaux, Justin and Gulcehre, Caglar},
  journal={arXiv preprint arXiv:TODO},
  year={2024}
}

Contact

Justin Deschenaux ([email protected])