metadata
license: artistic-2.0
language:
  - en
tags:
  - '16384'
  - 16k

mega-encoder-small-16k-v1

This is a "huggingface-native" pretrained encoder-only model with a 16384-token context length. The architecture is MEGA (Moving Average Equipped Gated Attention).

Numbers

Although it is a long-context model evaluated on a short-context benchmark (GLUE), MEGA holds up decently:

Model                       Params  Context (tokens)  GLUE avg
mega-encoder-small-16k-v1   122M    16384             0.777
bert-base-uncased           110M    512               0.7905
roberta-base                125M    514               0.86
bert-plus-L8-4096-v1.0      88.1M   4096              0.8278
mega-wikitext103            7.0M    10000             0.48
GLUE Details
Model                       Params  Context  Avg     CoLA   SST2   MRPC    STSB    QQP    MNLI   QNLI   RTE
mega-encoder-small-16k-v1   122M    16384    0.777   0.454  0.914  0.8404  0.906   0.894  0.806  0.842  0.556
bert-base-uncased           110M    512      0.7905  0.521  0.935  0.889   0.858   0.712  0.84   0.905  0.664
roberta-base                125M    514      0.86    0.64   0.95   0.9     0.91    0.92   0.88   0.93   0.79
bert-plus-L8-4096-v1.0      88.1M   4096     0.8278  0.6272 0.906  0.8659  0.9207  0.906  0.832  0.9    0.6643
mega-wikitext103            7M      10000    0.480   0.00   0.732  0.748   -0.087  0.701  0.54   0.598  0.513

The evaluations for MEGA and bert-plus are logged in this open wandb project; the reported values are the maximum observed on each validation set. Values for the other models are taken from their papers.
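As a quick sanity check, for the mega-encoder-small-16k-v1 row the Avg column matches the unweighted mean of the eight GLUE task scores. A minimal sketch reproducing it (the dict name `mega_scores` is ours; the values are copied from the table):

```python
from statistics import mean

# Per-task GLUE validation scores for mega-encoder-small-16k-v1, copied from the table.
mega_scores = {
    "CoLA": 0.454, "SST2": 0.914, "MRPC": 0.8404, "STSB": 0.906,
    "QQP": 0.894, "MNLI": 0.806, "QNLI": 0.842, "RTE": 0.556,
}

# Unweighted (macro) average across the eight tasks.
glue_avg = mean(mega_scores.values())
print(f"GLUE avg: {glue_avg:.3f}")
```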

Design

Architecture

This encoder model has 8 layers, hidden size 768, and a feedforward ratio of 3x. The resulting total size is 122M params.

Architecture Details

Details:

  1. We use a hidden size of 768 and a 3x hidden:feedforward ratio.
    • This contrasts with the 2x ratio used in the paper.
  2. To handle the long context, we use MEGA's chunking mechanism with a chunk length of 1024. As a result, VRAM usage grows linearly with sequence length beyond 1024 tokens (in multiples of the chunk length) rather than quadratically.
  3. EMA dimension: we use an EMA dimension of 32 in the interest of modeling long and (potentially) complex sequences.
  4. We use 8 layers and a context length of 16384 tokens.
  5. We use "simple" relative positional embeddings instead of the rotary embeddings touted in the paper.
    • This choice came from examining the detailed logs of models trained/evaluated on the LRA benchmark: the models geared towards encoder-type tasks all use the simple relative positional embeddings.
    • In previous experiments using rotary positional embeddings with MEGA as an encoder, we observed poor performance and inexplicable "walls".
  6. BART tokenizer: we use the tokenizer from facebook/bart-large.
    • This choice was motivated mostly by the desire to use the MEGA encoder together with a decoder model in the HF EncoderDecoderModel class in a "huggingface-native" way. BART is supported as a decoder for this class, and BART's tokenizer has the necessary preprocessing for encoder training.
    • Example usage of MEGA+BART to create an encoder-decoder here
    • The tokenizer's vocab is exactly the same as RoBERTa's.
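Item 2 can be made concrete with a back-of-the-envelope count of attention-score entries per layer. This is a sketch under simplifying assumptions: attention is computed only within each 1024-token chunk, sequence lengths are multiples of the chunk length, and the EMA, projection and feedforward costs are ignored. The helper names are illustrative, not part of the MEGA implementation.

```python
import math

CHUNK_LEN = 1024  # MEGA chunk length used by this model
MAX_CTX = 16384   # maximum context length of this model


def num_chunks(seq_len: int, chunk_len: int = CHUNK_LEN) -> int:
    """Chunks needed to cover seq_len tokens (a partial last chunk would be padded)."""
    return math.ceil(seq_len / chunk_len)


def chunked_attention_entries(seq_len: int, chunk_len: int = CHUNK_LEN) -> int:
    """Attention-score entries when each token attends only within its own chunk."""
    # chunk_len x chunk_len scores per chunk, so the total is linear in seq_len.
    return num_chunks(seq_len, chunk_len) * chunk_len * chunk_len


def full_attention_entries(seq_len: int) -> int:
    """Attention-score entries for ordinary full (quadratic) attention."""
    return seq_len * seq_len


for n in (1024, 4096, MAX_CTX):
    print(
        f"{n:>6} tokens: {num_chunks(n):>2} chunks, "
        f"chunked={chunked_attention_entries(n):,} full={full_attention_entries(n):,}"
    )
```

At 16384 tokens this gives 16 chunks and 16x fewer score entries than full attention; doubling the sequence length doubles the chunked count instead of quadrupling it.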

Training

This model was trained with the transformers package. You can find (mostly unorganized) training runs on wandb here.

Training Details
  1. Multi-task training: the majority of training is "standard" MLM, with no next-sentence prediction, etc. However, in the interest of pretraining a useful encoder for fine-tuning on various tasks, we mix in such tasks between several of the MLM phases, carrying over the model's backbone to the next training phase.
    • An example would be multiple-choice tuning on the swag dataset.
  2. MLM mask ratio of 40% by default: following Wettig et al. 2022, we mask 40% of tokens. This ratio is decreased slightly for training at longer sequence lengths (8192+) to encourage the model to learn to leverage the available context in its predictions.
  3. Mixed precision (AMP) with bf16.
  4. Gradient checkpointing implementation: training this (or a similar) model at a context length of 8192 or longer becomes quite VRAM-intensive, despite the linear increase in memory usage.
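The 40% masking in item 2 can be illustrated with a minimal, dependency-free sketch. It is a simplification, not the training code used here: it masks an exact 40% of the non-special tokens and replaces every selected token with `<mask>`. In `transformers`, the corresponding knob is the `mlm_probability` argument of `DataCollatorForLanguageModeling`, which instead samples positions independently and applies BERT's 80/10/10 mask/random/keep split. The function name `mask_tokens` is ours.

```python
import random

MASK_RATIO = 0.40    # default MLM masking ratio stated in this card
IGNORE_INDEX = -100  # label value ignored by PyTorch's cross-entropy loss


def mask_tokens(token_ids, mask_token_id, special_ids, ratio=MASK_RATIO, seed=0):
    """Mask a fixed fraction of non-special tokens for MLM.

    Returns (inputs, labels): labels keep the original id at masked
    positions and IGNORE_INDEX everywhere else. Simplification: every
    selected token is replaced by the mask token (no 80/10/10 split).
    """
    rng = random.Random(seed)
    candidates = [i for i, t in enumerate(token_ids) if t not in special_ids]
    n_mask = round(len(candidates) * ratio)
    chosen = set(rng.sample(candidates, n_mask))
    inputs = [mask_token_id if i in chosen else t for i, t in enumerate(token_ids)]
    labels = [t if i in chosen else IGNORE_INDEX for i, t in enumerate(token_ids)]
    return inputs, labels


# Illustrative ids: 0/1/2 are <s>/<pad>/</s> and 50264 is <mask> in the RoBERTa/BART vocab.
ids = [0, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000, 2]
inputs, labels = mask_tokens(ids, mask_token_id=50264, special_ids={0, 1, 2})
print(inputs)
print(labels)
```

With 10 maskable tokens, exactly 4 are masked; lowering `ratio` for longer sequences leaves more of the context visible to the model.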

Usage

This is a pretrained model intended to be fine-tuned on various encoder-compatible tasks. However, if you are interested in testing inference with this model or have a deep passion for predicting mask tokens, you can use the following code:

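import json

from transformers import pipeline

# Load the pretrained checkpoint in a fill-mask pipeline.
pipe = pipeline("fill-mask", model="BEE-spoke-data/mega-encoder-small-16k-v1")

# The tokenizer (from facebook/bart-large) uses <mask> as its mask token.
text = "I love to <mask> memes."
result = pipe(text)  # list of dicts with score, token, token_str, sequence
print(json.dumps(result, indent=2))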

Gradient checkpointing implementation

When fine-tuning this model on a downstream task, gradient checkpointing makes training at 16384 context quite feasible. By installing the transformers fork below and passing gradient_checkpointing=True in the training arguments, you should be able to fine-tune at batch size 1 with VRAM to spare on a single RTX 3090/4090.

pip uninstall -y transformers
pip install -U git+https://github.com/pszemraj/transformers.git@mega-gradient-checkpointing
pip install -U huggingface-hub
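For reference, a sketch of training arguments in the spirit of the setup above. The keys are existing `transformers.TrainingArguments` parameters, but the values (including the output directory and accumulation steps) are illustrative assumptions, not the authors' settings; gradient checkpointing for MEGA still requires the fork installed above.

```python
# Illustrative fine-tuning settings for a single 24 GB GPU (e.g. RTX 3090/4090).
# Keys are real TrainingArguments parameters; the values are assumptions.
training_kwargs = {
    "output_dir": "mega-16k-finetune",  # hypothetical output directory
    "per_device_train_batch_size": 1,   # batch size 1 at 16384 tokens
    "gradient_accumulation_steps": 16,  # illustrative; gives a larger effective batch
    "gradient_checkpointing": True,     # needs the forked transformers for MEGA
    "bf16": True,                       # matches the bf16 AMP used in pretraining
}

effective_batch = (
    training_kwargs["per_device_train_batch_size"]
    * training_kwargs["gradient_accumulation_steps"]
)
print(f"effective batch size: {effective_batch}")

# With the forked transformers installed:
# from transformers import TrainingArguments
# args = TrainingArguments(**training_kwargs)
```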

If there is sufficient interest, we can look at opening a PR into the official transformers repo.

Citation

If you find this useful, please consider citing this DOI; it would make us happy.

@misc{beespoke_data_2024,
    author       = {Peter Szemraj and Vincent Haines and {BEEspoke Data}},
    title        = {mega-encoder-small-16k-v1 (Revision 1476bcf)},
    year         = 2024,
    url          = {https://huggingface.co./BEE-spoke-data/mega-encoder-small-16k-v1},
    doi          = {10.57967/hf/1837},
    publisher    = {Hugging Face}
}