---
license: artistic-2.0
language:
- en
tags:
- '16384'
- 16k
---
# mega-encoder-small-16k-v1
This is a "huggingface-native" pretrained encoder-only model with a 16384-token context length, based on the [MEGA](https://arxiv.org/abs/2209.10655) architecture.
## Numbers
Despite being a long-context model evaluated on a short-context benchmark (GLUE), MEGA holds up decently:
| Model | Size | CTX | Avg |
| :------------------------ | :---- | ----: | -----: |
| mega-encoder-small-16k-v1 | 122M | 16384 | 0.777 |
| bert-base-uncased | 110M | 512 | 0.7905 |
| roberta-base | 125M | 514 | 0.86 |
| [bert-plus-L8-4096-v1.0](https://huggingface.co./BEE-spoke-data/bert-plus-L8-4096-v1.0) | 88.1M | 4096 | 0.8278 |
| [mega-wikitext103](https://huggingface.co./mnaylor/mega-base-wikitext) | 7.0M | 10000 | 0.48 |
<details>
<summary><strong>GLUE Details</strong></summary>

| Model | Size | CTX | Avg | CoLA | SST2 | MRPC | STSB | QQP | MNLI | QNLI | RTE |
| :------------------------ | :---- | ----: | -----: | -----: | ----: | -----: | -----: | ----: | ----: | ----: | -----: |
| mega-encoder-small-16k-v1 | 122M | 16384 | 0.777 | 0.454 | 0.914 | 0.8404 | 0.906 | 0.894 | 0.806 | 0.842 | 0.556 |
| bert-base-uncased | 110M | 512 | 0.7905 | 0.521 | 0.935 | 0.889 | 0.858 | 0.712 | 0.84 | 0.905 | 0.664 |
| roberta-base | 125M | 514 | 0.86 | 0.64 | 0.95 | 0.9 | 0.91 | 0.92 | 0.88 | 0.93 | 0.79 |
| bert-plus-L8-4096-v1.0 | 88.1M | 4096 | 0.8278 | 0.6272 | 0.906 | 0.8659 | 0.9207 | 0.906 | 0.832 | 0.9 | 0.6643 |
| mega-wikitext103 | 7M | 10000 | 0.480 | 0.00 | 0.732 | 0.748 | -0.087 | 0.701 | 0.54 | 0.598 | 0.513 |

The MEGA and bert-plus evaluations can be found in [this open wandb project](https://wandb.ai/pszemraj/glue-benchmarking); their scores are the maximum values observed on the validation sets. Scores for the other models are taken as reported in their papers.
</details>
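The `Avg` column appears to be the unweighted mean of the eight task scores in the GLUE details table. This reproduces the reported averages for mega-encoder-small-16k-v1 and bert-base-uncased, but treating it as the exact aggregation used for every row is an assumption. A minimal sketch:

```python
from statistics import mean

# Per-task GLUE validation scores copied from the GLUE details table,
# in the order CoLA, SST2, MRPC, STSB, QQP, MNLI, QNLI, RTE.
GLUE_SCORES = {
    "mega-encoder-small-16k-v1": [0.454, 0.914, 0.8404, 0.906, 0.894, 0.806, 0.842, 0.556],
    "bert-base-uncased": [0.521, 0.935, 0.889, 0.858, 0.712, 0.84, 0.905, 0.664],
}


def glue_average(scores: list[float]) -> float:
    """Unweighted mean over the eight GLUE tasks (assumed definition of `Avg`)."""
    if len(scores) != 8:
        raise ValueError(f"expected 8 task scores, got {len(scores)}")
    return mean(scores)


for model, scores in GLUE_SCORES.items():
    print(f"{model}: {glue_average(scores):.4f}")
```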
## Design
### Architecture
This encoder model has 8 layers, a hidden size of 768, and a feedforward ratio of 3x, for a total of 122M parameters.
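Per the MEGA paper, the core of each layer is a multi-dimensional damped exponential moving average (EMA) that feeds a gated single-head attention. For one input channel x, the EMA expands each step into h components (the "EMA dimension", 32 for this model) via u_t = β·x_t, updates h_t = α ⊙ u_t + (1 − α ⊙ δ) ⊙ h_{t−1}, and projects back to a scalar with y_t = ηᵀh_t. The plain-Python sketch below runs that recurrence step by step for a single channel. The parameter values are made up (in the model they are learned), and practical implementations typically compute the EMA as a convolution rather than with this loop.

```python
import random


def damped_ema(x, alpha, delta, beta, eta):
    """Multi-dimensional damped EMA over one input channel (MEGA paper recurrence).

    x: sequence of scalars for a single channel.
    alpha, delta: length-h lists with values in (0, 1); beta, eta: length-h lists.
    Returns the output sequence y, one scalar per input step.
    """
    h = len(alpha)
    if not (len(delta) == len(beta) == len(eta) == h):
        raise ValueError("alpha, delta, beta and eta must all have length h")
    state = [0.0] * h
    ys = []
    for x_t in x:
        for k in range(h):
            u = beta[k] * x_t  # expand the scalar input into h components
            state[k] = alpha[k] * u + (1.0 - alpha[k] * delta[k]) * state[k]
        ys.append(sum(eta[k] * state[k] for k in range(h)))  # project back to a scalar
    return ys


# Illustrative (not learned) parameters with the model's EMA dimension of 32.
rng = random.Random(0)
EMA_DIM = 32
alpha = [rng.uniform(0.05, 0.95) for _ in range(EMA_DIM)]
delta = [rng.uniform(0.05, 0.95) for _ in range(EMA_DIM)]
beta = [rng.gauss(0.0, 1.0) for _ in range(EMA_DIM)]
eta = [rng.gauss(0.0, 1.0) / EMA_DIM for _ in range(EMA_DIM)]

# Response to a single impulse: each component decays at its own rate (1 - alpha * delta).
impulse_response = damped_ema([1.0] + [0.0] * 7, alpha, delta, beta, eta)
print([round(v, 4) for v in impulse_response])
```

With h = 1 and δ = β = η = 1, the recurrence reduces to the classic EMA h_t = α·x_t + (1 − α)·h_{t−1}; the extra components and the damping factor δ let each channel mix several decay rates.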
<details>
<summary><strong>Architecture Details</strong></summary>

1. We use a hidden size of 768 and a feedforward dimension 3x the hidden size.
   - This contrasts with the 2x ratio used in the paper.
2. To handle the long context, we use MEGA's chunking mechanism with a chunk length of 1024. Because attention is computed only within each chunk, VRAM usage grows linearly (rather than quadratically) with sequence length beyond 1024 tokens.
3. EMA dimension: we use an EMA dimension of 32 to help model long and potentially complex sequences.
4. We use 8 layers and a context length of 16384 tokens.
5. We use `"simple"` relative positional embeddings instead of the rotary embeddings touted in the paper.
   - This choice came from examining [the detailed logs of models](https://github.com/facebookresearch/mega/blob/aeaa4b44592cd1d60a9a34554e359eda2a62b03b/examples/mega/README.lra.md) trained/evaluated on [the LRA benchmark](https://paperswithcode.com/sota/long-range-modeling-on-lra). The models geared towards encoder-type tasks all use the simple relative positional embeddings.
   - We observed poor performance and inexplicable 'walls' in previous experiments using rotary positional embeddings with MEGA as an encoder.
6. BART tokenizer: we use the tokenizer from `facebook/bart-large`.
   - This choice was motivated mostly by the desire to use the MEGA encoder together with a decoder model in the [HF EncoderDecoderModel class](https://huggingface.co./docs/transformers/model_doc/encoder-decoder) in a "huggingface-native" way. BART is supported as a decoder for this class, **and** BART's tokenizer has the necessary preprocessing for encoder training.
     - An example of combining MEGA and BART into an encoder-decoder is available [here](https://colab.research.google.com/gist/pszemraj/4bac8635361543b66207d73e4b25a13a/mega-encoder-small-16k-v1-for-text2text.ipynb).
   - The tokenizer's vocab is **exactly** the same as RoBERTa's.
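
As a back-of-the-envelope check on the chunking point: full self-attention over n tokens stores n² attention scores, while chunks of length c that attend only within themselves store (n/c)·c² = n·c, which is linear in n. The sketch below counts score entries for one sequence and a single attention head (MEGA uses single-head attention), ignoring every other activation, so it illustrates the scaling rather than predicting actual VRAM. As a simplification, it assumes n is a multiple of the chunk size.

```python
from typing import Optional


def attention_score_entries(seq_len: int, chunk_size: Optional[int] = None) -> int:
    """Count attention-score entries for one sequence and one attention head.

    chunk_size=None (or seq_len <= chunk_size) means full attention over the
    whole sequence; otherwise each chunk of chunk_size tokens attends only
    within itself.
    """
    if chunk_size is None or seq_len <= chunk_size:
        return seq_len * seq_len
    if seq_len % chunk_size != 0:
        raise ValueError("this sketch assumes seq_len is a multiple of chunk_size")
    num_chunks = seq_len // chunk_size
    return num_chunks * chunk_size * chunk_size


CHUNK_SIZE = 1024  # chunk length used by this model

for n in (1024, 2048, 4096, 8192, 16384):
    full = attention_score_entries(n)
    chunked = attention_score_entries(n, CHUNK_SIZE)
    print(f"{n:>6} tokens: full={full:>13,} chunked={chunked:>12,} full/chunked={full // chunked}")
```

At 16384 tokens the chunked count is 16x smaller than the full-attention count, and doubling the context doubles it rather than quadrupling it.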
</details>
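A `"simple"` relative positional bias can be thought of as one learned scalar per relative offset between query and key, added to the attention logits, so the bias matrix is constant along each diagonal (a Toeplitz matrix). The plain-Python sketch below builds such a matrix from a vector of 2·max_positions − 1 weights. It illustrates the idea only: it is not the `transformers` implementation, the offset sign convention (j − i) is chosen for illustration, and the toy weights are made up.

```python
def simple_relative_bias(weights: list[float], seq_len: int) -> list[list[float]]:
    """Build a seq_len x seq_len additive attention bias from per-offset weights.

    `weights` holds one scalar per relative offset, so its length is
    2 * max_positions - 1; index max_positions - 1 is offset 0. Entry [i][j]
    (query i, key j) uses the weight for offset j - i, so every diagonal of the
    result is constant (a Toeplitz matrix).
    """
    if len(weights) % 2 == 0:
        raise ValueError("weights must have odd length 2 * max_positions - 1")
    max_positions = (len(weights) + 1) // 2
    if seq_len > max_positions:
        raise ValueError("seq_len exceeds max_positions")
    center = max_positions - 1
    return [[weights[center + (j - i)] for j in range(seq_len)] for i in range(seq_len)]


# Toy weights whose value equals their offset (-3..3), i.e. max_positions = 4,
# so the printed matrix shows which offset each entry reads.
toy_weights = [float(k) for k in range(-3, 4)]
for row in simple_relative_bias(toy_weights, 3):
    print(row)
# [0.0, 1.0, 2.0]
# [-1.0, 0.0, 1.0]
# [-2.0, -1.0, 0.0]
```

In a model, these weights would be learned and the bias added to the query-key scores before the softmax; rotary embeddings instead encode position by rotating the query and key vectors themselves.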
### Training
This model was trained with the transformers package. You can find (mostly unorganized) [training runs on wandb here](https://wandb.ai/pszemraj/mega-tuning-longctx).
<details>
<summary><strong>Training Details</strong></summary>

1. **Multi-task training:** the majority of training is "standard" MLM, with no next-sentence prediction, etc. However, to pretrain an encoder that is _useful_ for fine-tuning on various tasks, we mix in such fine-tuning tasks between several of the MLM phases, carrying over the model's backbone to the next training phase.
   - An example would be multiple-choice tuning on the [swag](https://huggingface.co./datasets/swag) dataset.
2. **MLM Mask Ratio 40% default:** we use 40% for the MLM ratio, following [Wettig et al. 2022](https://arxiv.org/abs/2202.08005). This is decreased slightly for training at longer sequences (8192+) to encourage the model to learn/leverage the available context in predictions.
3. **Mixed precision:** AMP with bf16.
4. **Gradient checkpointing implementation:** training this (or a similar) model at a context length of 8192 or longer becomes quite VRAM-intensive despite the linear growth in memory usage, which motivated adding gradient checkpointing support for MEGA.
</details>
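The 40% mask ratio controls how many tokens become prediction targets. In `transformers` it is commonly set via `DataCollatorForLanguageModeling(tokenizer=..., mlm_probability=0.4)`, but whether that collator was used here is an assumption. The dependency-free sketch below only illustrates what the ratio means: it masks exactly 40% of the non-special positions (the `transformers` collator instead samples each token independently and applies an 80/10/10 mask/random/keep split, which this sketch omits) and sets non-target labels to -100, the default `ignore_index` of PyTorch's cross-entropy loss. The token ids are hypothetical.

```python
import random
from typing import Optional


def mask_for_mlm(
    token_ids: list[int],
    mask_id: int,
    special_ids: set[int],
    mask_ratio: float = 0.4,
    seed: Optional[int] = None,
) -> tuple[list[int], list[int]]:
    """Mask a fixed fraction of non-special tokens for masked language modeling.

    Returns (masked_ids, labels). Exactly round(mask_ratio * n) of the n
    non-special positions are replaced by mask_id; labels hold the original id
    at those positions and -100 (ignored by the loss) everywhere else.
    """
    rng = random.Random(seed)
    candidates = [i for i, t in enumerate(token_ids) if t not in special_ids]
    chosen = set(rng.sample(candidates, round(mask_ratio * len(candidates))))
    masked_ids = [mask_id if i in chosen else t for i, t in enumerate(token_ids)]
    labels = [t if i in chosen else -100 for i, t in enumerate(token_ids)]
    return masked_ids, labels


# Hypothetical ids: 0 and 2 stand in for <s>/</s>, 9999 for <mask>.
ids = [0, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 2]
masked, labels = mask_for_mlm(ids, mask_id=9999, special_ids={0, 2}, seed=0)
print(masked)
print(labels)
```

Lowering `mask_ratio` for the 8192+ phases would leave more unmasked context per prediction; the exact lower value used is not stated here.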
## Usage
This is a pretrained model intended to be [fine-tuned on various encoder-compatible tasks](https://github.com/huggingface/transformers/tree/831bc25d8fdb85768402f772cf65cc3d7872b211/examples/pytorch). However, if you are interested in testing inference with this model or have a deep passion for predicting mask tokens, you can use the following code:
```python
import json
from transformers import pipeline
pipe = pipeline("fill-mask", model="BEE-spoke-data/mega-encoder-small-16k-v1")
text = "I love to <mask> memes."
result = pipe(text)
print(json.dumps(result, indent=2))
```
### Gradient checkpointing implementation
If fine-tuning this model on `<task>`, gradient checkpointing makes training at a 16384-token context quite feasible. After installing the transformers fork below and passing `gradient_checkpointing=True` in your `TrainingArguments`, you should be able to fine-tune at batch size 1 with VRAM to spare on a single RTX 3090/4090.
```sh
pip uninstall -y transformers
pip install -U git+https://github.com/pszemraj/transformers.git@mega-gradient-checkpointing
pip install -U huggingface-hub
```
If there is sufficient interest, we can look at making a PR into the official repo.
## Citation
If you find this useful, please consider citing this DOI; it would make us happy.
```
@misc{beespoke_data_2024,
author = {Peter Szemraj and Vincent Haines and {BEEspoke Data}},
title = {mega-encoder-small-16k-v1 (Revision 1476bcf)},
year = 2024,
url = {https://huggingface.co./BEE-spoke-data/mega-encoder-small-16k-v1},
doi = {10.57967/hf/1837},
publisher = {Hugging Face}
}
```