|
--- |
|
license: artistic-2.0 |
|
language: |
|
- en |
|
tags: |
|
- '16384' |
|
- 16k |
|
--- |
|
|
|
# mega-encoder-small-16k-v1 |
|
|
|
This is a "huggingface-native" pretrained encoder-only model with 16384 context length. The model architecture is [MEGA](https://arxiv.org/abs/2209.10655). |
|
|
|
## Numbers |
|
|
|
Despite being a long-context model evaluated on a short-context benchmark (GLUE), MEGA holds up reasonably well:
|
|
|
| Model | Size | CTX | Avg | |
|
| :------------------------ | :---- | ----: | -----: | |
|
| mega-encoder-small-16k-v1 | 122M | 16384 | 0.777 | |
|
| bert-base-uncased | 110M | 512 | 0.7905 | |
|
| roberta-base | 125M | 514 | 0.86 | |
|
| [bert-plus-L8-4096-v1.0](https://huggingface.co./BEE-spoke-data/bert-plus-L8-4096-v1.0) | 88.1M | 4096 | 0.8278 | |
|
| [mega-wikitext103](https://huggingface.co./mnaylor/mega-base-wikitext) | 7.0M | 10000 | 0.48 | |
|
|
|
<details> |
|
<summary><strong>GLUE Details</strong></summary> |
|
|
|
| Model | Size | CTX | Avg | CoLA | SST2 | MRPC | STSB | QQP | MNLI | QNLI | RTE | |
|
| :------------------------ | :---- | ----: | -----: | -----: | ----: | -----: | -----: | ----: | ----: | ----: | -----: | |
|
| mega-encoder-small-16k-v1 | 122M | 16384 | 0.777 | 0.454 | 0.914 | 0.8404 | 0.906 | 0.894 | 0.806 | 0.842 | 0.556 | |
|
| bert-base-uncased | 110M | 512 | 0.7905 | 0.521 | 0.935 | 0.889 | 0.858 | 0.712 | 0.84 | 0.905 | 0.664 | |
|
| roberta-base | 125M | 514 | 0.86 | 0.64 | 0.95 | 0.9 | 0.91 | 0.92 | 0.88 | 0.93 | 0.79 | |
|
| bert-plus-L8-4096-v1.0 | 88.1M | 4096 | 0.8278 | 0.6272 | 0.906 | 0.8659 | 0.9207 | 0.906 | 0.832 | 0.9 | 0.6643 | |
|
| mega-wikitext103 | 7.0M | 10000 | 0.480 | 0.00 | 0.732 | 0.748 | -0.087 | 0.701 | 0.54 | 0.598 | 0.513 |
|
|
|
The evals for MEGA/bert-plus can be found in [this open wandb project](https://wandb.ai/pszemraj/glue-benchmarking) and are taken as the max observed values on the validation sets. The values for other models are taken as reported in their papers. |
|
</details> |
|
|
|
## Design |
|
|
|
### Architecture |
|
|
|
This encoder model has 8 layers, hidden size 768, and a feedforward ratio of 3x. The resulting total size is 122M params. |
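
The key hyperparameters can be read directly from the hosted config. A minimal sketch for inspecting them (attribute names follow the `MegaConfig` class in `transformers`; double-check them against your installed version):

```python
from transformers import AutoConfig

# inspect the architecture settings of the pretrained checkpoint
cfg = AutoConfig.from_pretrained("BEE-spoke-data/mega-encoder-small-16k-v1")
print(cfg.num_hidden_layers, cfg.hidden_size)  # number of layers and hidden size
print(cfg.max_positions, cfg.chunk_size)       # context length and chunk length
print(cfg.relative_positional_bias)            # positional embedding type
```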
|
|
|
<details> |
|
<summary><strong>Architecture Details</strong></summary> |
|
|
|
Details: |
|
|
|
1. We use a hidden size of 768, and a 3x hidden:feedforward ratio. |
|
- This contrasts with the 2x ratio used in the paper |
|
2. To handle the long context, we use MEGA's chunking mechanism with a chunk length of 1024. As a result, VRAM usage grows linearly with context length beyond 1024 tokens.
|
3. EMA dimension: we use an EMA dimension of 32 in the interest of modeling long and (potentially) complex sequences |
|
4. We use 8 layers, and a context length of 16384 tokens. |
|
5. We use `"simple"` relative positional embeddings instead of the rotary embeddings touted in the paper. |
|
- This choice came from examining [the detailed logs of models](https://github.com/facebookresearch/mega/blob/aeaa4b44592cd1d60a9a34554e359eda2a62b03b/examples/mega/README.lra.md) trained/evaluated on [the LRA benchmark](https://paperswithcode.com/sota/long-range-modeling-on-lra). Models geared towards encoder-type tasks all use the simple relative positional embeddings |
|
- We observed poor performance and inexplicable 'walls' in previous experiments that used rotary positional embeddings with MEGA as an encoder
|
6. BART tokenizer: we use the tokenizer from `facebook/bart-large` |
|
- This choice was motivated mostly by the desire to use the MEGA encoder in combination with a decoder model in the [HF EncoderDecoderModel class](https://huggingface.co./docs/transformers/model_doc/encoder-decoder) in a "huggingface-native" way. BART is supported as a decoder for this class, **and** BART's tokenizer includes the necessary preprocessing for encoder training.
|
  - Example usage of MEGA+BART to create an encoder-decoder model can be found [here](https://colab.research.google.com/gist/pszemraj/4bac8635361543b66207d73e4b25a13a/mega-encoder-small-16k-v1-for-text2text.ipynb); a brief sketch also follows this details block
|
- The tokenizer's vocab is **exactly** the same as RoBERTa's
|
</details> |
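
As noted in the tokenizer details above, the encoder can be paired with a BART decoder via the `EncoderDecoderModel` class. A minimal sketch (the decoder checkpoint and token-id settings here are illustrative; see the linked notebook for a complete working example):

```python
from transformers import AutoTokenizer, EncoderDecoderModel

# pair the MEGA encoder with a BART decoder (decoder checkpoint chosen for illustration)
model = EncoderDecoderModel.from_encoder_decoder_pretrained(
    "BEE-spoke-data/mega-encoder-small-16k-v1",
    "facebook/bart-large",
)
tokenizer = AutoTokenizer.from_pretrained("BEE-spoke-data/mega-encoder-small-16k-v1")

# the encoder-decoder wrapper needs these ids set explicitly before training/generation
model.config.decoder_start_token_id = tokenizer.bos_token_id
model.config.pad_token_id = tokenizer.pad_token_id
```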
|
|
|
|
|
### Training |
|
|
|
This model was trained with the transformers package. You can find (mostly unorganized) [training runs on wandb here](https://wandb.ai/pszemraj/mega-tuning-longctx). |
|
|
|
<details> |
|
<summary><strong>Training Details</strong></summary> |
|
|
|
1. **Multi-task training:** the majority of training is "standard" MLM, with no next-sentence prediction, etc. However, in the interest of pretraining a _useful_ encoder for fine-tuning on various tasks, we mix in such tasks between several of the MLM phases, carrying over the model's backbone to the next training phase.
|
- an example would be multiple-choice tuning on the [swag](https://huggingface.co./datasets/swag) dataset
|
2. **40% MLM masking ratio:** we use a 40% masking ratio for MLM, following [Wettig et al. 2022](https://arxiv.org/abs/2202.08005). This is decreased slightly when training at longer sequence lengths (8192+) to encourage the model to learn/leverage the available context in its predictions (see the data-collator sketch after this list).
|
3. **Mixed precision:** training uses AMP with bf16.
|
4. **Gradient checkpointing:** training this model (or similar models) at a context length of 8192 or longer is quite VRAM-intensive despite the linear growth in memory usage, so we rely on a gradient checkpointing implementation for MEGA (see the Usage section below).
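
As referenced above, masking uses the standard `transformers` data collator. A minimal sketch of the 40% setup (dataset preparation is omitted, and this is an illustration rather than the exact training script):

```python
from transformers import AutoTokenizer, DataCollatorForLanguageModeling

tokenizer = AutoTokenizer.from_pretrained("BEE-spoke-data/mega-encoder-small-16k-v1")

# 40% masking per Wettig et al. 2022; lowered somewhat for the 8192+ context phases
collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=True,
    mlm_probability=0.4,
)
```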
|
</details> |
|
|
|
## Usage |
|
|
|
This is a pretrained model intended to be [fine-tuned on various encoder-compatible tasks](https://github.com/huggingface/transformers/tree/831bc25d8fdb85768402f772cf65cc3d7872b211/examples/pytorch). However, if you are interested in testing inference with this model or have a deep passion for predicting mask tokens, you can use the following code: |
|
|
|
```python |
|
import json |
|
from transformers import pipeline |
|
|
|
pipe = pipeline("fill-mask", model="BEE-spoke-data/mega-encoder-small-16k-v1") |
|
text = "I love to <mask> memes." |
|
result = pipe(text) |
|
print(json.dumps(result, indent=2)) |
|
``` |
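
Beyond mask filling, the checkpoint loads with the standard `AutoModelFor*` task heads for fine-tuning. A minimal sketch for sequence classification (the `num_labels` value is a placeholder; verify that your `transformers` version still ships the MEGA task heads):

```python
from transformers import AutoModelForSequenceClassification, AutoTokenizer

model_name = "BEE-spoke-data/mega-encoder-small-16k-v1"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# num_labels is task-specific; 2 is used here purely as an example
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)
```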
|
|
|
### Gradient checkpointing implementation |
|
|
|
When fine-tuning this model on a downstream task, gradient checkpointing makes training at the full 16384 context quite feasible. Install the transformers fork below and pass `gradient_checkpointing=True` in the training arguments (a short sketch follows the install commands); you should be able to fine-tune at batch size 1 with VRAM to spare on a single 3090/4090.
|
|
|
```sh |
|
pip uninstall -y transformers |
|
pip install -U git+https://github.com/pszemraj/transformers.git@mega-gradient-checkpointing |
|
pip install -U huggingface-hub |
|
``` |
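
With the fork installed, enabling checkpointing is just the usual flag in `TrainingArguments`. A minimal sketch (the output directory, batch size, and other values are placeholders; adjust for your task and hardware):

```python
from transformers import TrainingArguments

# hypothetical fine-tuning arguments; tune these for your task and hardware
training_args = TrainingArguments(
    output_dir="mega-encoder-small-16k-finetune",
    per_device_train_batch_size=1,
    gradient_checkpointing=True,  # requires the fork above for MEGA
    bf16=True,
    learning_rate=2e-5,
    num_train_epochs=3,
)
```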
|
|
|
If there is sufficient interest, we can look into making a PR to the official `transformers` repo.
|
|
|
## Citation |
|
|
|
If you find this model useful, please consider citing the DOI below; it would make us happy.
|
|
|
```bibtex
|
@misc{beespoke_data_2024, |
|
author = {Peter Szemraj and Vincent Haines and {BEEspoke Data}}, |
|
title = {mega-encoder-small-16k-v1 (Revision 1476bcf)}, |
|
year = 2024, |
|
url = {https://huggingface.co./BEE-spoke-data/mega-encoder-small-16k-v1}, |
|
doi = {10.57967/hf/1837}, |
|
publisher = {Hugging Face} |
|
} |
|
``` |