---
license: artistic-2.0
language:
- en
tags:
- '16384'
- 16k
---

# mega-encoder-small-16k-v1

This is a "huggingface-native" pretrained encoder-only model with a context length of 16384 tokens. The model architecture is [MEGA](https://arxiv.org/abs/2209.10655).

## Numbers

Despite being a long-context model evaluated on a short-context benchmark, MEGA holds up decently:

| Model                     | Size  |   CTX |    Avg | 
| :------------------------ | :---- | ----: | -----: | 
| mega-encoder-small-16k-v1 | 122M  | 16384 |  0.777 | 
| bert-base-uncased         | 110M  |   512 | 0.7905 | 
| roberta-base              | 125M  |   514 |   0.86 | 
| [bert-plus-L8-4096-v1.0](https://huggingface.co./BEE-spoke-data/bert-plus-L8-4096-v1.0)    | 88.1M |  4096 | 0.8278 | 
| [mega-wikitext103](https://huggingface.co./mnaylor/mega-base-wikitext)    | 7.0M |  10000 | 0.48 | 

<details>
<summary><strong>GLUE Details</strong></summary>
  
| Model                     | Size  |   CTX |    Avg |   CoLA |  SST2 |   MRPC |   STSB |   QQP |  MNLI |  QNLI |    RTE |
| :------------------------ | :---- | ----: | -----: | -----: | ----: | -----: | -----: | ----: | ----: | ----: | -----: |
| mega-encoder-small-16k-v1 | 122M  | 16384 |  0.777 |  0.454 | 0.914 | 0.8404 |  0.906 | 0.894 | 0.806 | 0.842 |  0.556 |
| bert-base-uncased         | 110M  |   512 | 0.7905 |  0.521 | 0.935 |  0.889 |  0.858 | 0.712 |  0.84 | 0.905 |  0.664 |
| roberta-base              | 125M  |   514 |   0.86 |   0.64 |  0.95 |    0.9 |   0.91 |  0.92 |  0.88 |  0.93 |   0.79 |
| bert-plus-L8-4096-v1.0    | 88.1M |  4096 | 0.8278 | 0.6272 | 0.906 | 0.8659 | 0.9207 | 0.906 | 0.832 |   0.9 | 0.6643 |
| mega-wikitext103          | 7M    | 10000 |  0.480 |   0.00 | 0.732 |  0.748 | -0.087 | 0.701 |  0.54 | 0.598 |  0.513 |

The evals for MEGA/bert-plus can be found in [this open wandb project](https://wandb.ai/pszemraj/glue-benchmarking) and are taken as the max observed values on the validation sets. The values for other models are taken as reported in their papers.
</details>
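
As a quick sanity check on the table, the `Avg` column appears to be consistent with a plain unweighted mean over the eight GLUE task scores. The sketch below recomputes it for two rows using the values copied from the details table; the unweighted-mean interpretation is our assumption, not something the eval logs state explicitly.

```python
from statistics import mean

# Per-task validation scores copied from the GLUE details table
# (order: CoLA, SST2, MRPC, STSB, QQP, MNLI, QNLI, RTE).
glue_scores = {
    "mega-encoder-small-16k-v1": [0.454, 0.914, 0.8404, 0.906, 0.894, 0.806, 0.842, 0.556],
    "bert-plus-L8-4096-v1.0": [0.6272, 0.906, 0.8659, 0.9207, 0.906, 0.832, 0.9, 0.6643],
}

# Assumption: "Avg" is the unweighted mean over the eight tasks.
glue_avg = {name: mean(scores) for name, scores in glue_scores.items()}

for name, avg in glue_avg.items():
    print(f"{name}: {avg:.4f}")
```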

## Design

### Architecture

This encoder model has 8 layers, hidden size 768, and a feedforward ratio of 3x. The resulting total size is 122M params.

<details>
<summary><strong>Architecture Details</strong></summary>
  
Details:

1. We use a hidden size of 768, and a 3x hidden:feedforward ratio.
	- This contrasts with the 2x ratio used in the paper
2. To handle the long context, we use MEGA's chunking mechanism with a chunk length of 1024. As a result, VRAM usage increases linearly with each additional multiple of this chunk length past 1024 tokens.
3. EMA dimension: we use an EMA dimension of 32 in the interest of modeling long and (potentially) complex sequences
4. We use 8 layers, and a context length of 16384 tokens.
5. We use `"simple"` relative positional embeddings instead of the rotary embeddings touted in the paper.
	- This choice came from examining [the detailed logs of models](https://github.com/facebookresearch/mega/blob/aeaa4b44592cd1d60a9a34554e359eda2a62b03b/examples/mega/README.lra.md) trained/evaluated on [the LRA benchmark](https://paperswithcode.com/sota/long-range-modeling-on-lra). Models geared towards encoder-type tasks all use the simple relative positional embeddings
	- We observed poor performance and inexplicable 'walls' in previous experiments using rotary positional embeddings with MEGA as an encoder
6. BART tokenizer: we use the tokenizer from `facebook/bart-large`
	- This choice was motivated mostly by the desire to use the MEGA encoder in combination with a decoder model in the [HF EncoderDecoderModel class](https://huggingface.co./docs/transformers/model_doc/encoder-decoder) in a "huggingface-native" way. BART is supported as a decoder for this class, **and** BART's tokenizer has the necessary preprocessing for encoder training.
		- Example usage of MEGA+BART to create an encoder-decoder [here](https://colab.research.google.com/gist/pszemraj/4bac8635361543b66207d73e4b25a13a/mega-encoder-small-16k-v1-for-text2text.ipynb)
	- The tokenizer's vocab is **exactly** the same as Roberta's
</details>
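
To make the linear-scaling claim about chunking concrete, here is a minimal back-of-the-envelope sketch. It only counts attention-score entries (per head, per layer) and assumes each chunk attends solely within itself; it ignores EMA state, activations, and everything else that contributes to real VRAM usage. The helper names are hypothetical and not part of any library.

```python
import math

CHUNK_LEN = 1024  # chunk length stated in the architecture details
MAX_CTX = 16384   # maximum context length of this model


def num_chunks(seq_len: int, chunk_len: int = CHUNK_LEN) -> int:
    """Number of chunks needed to cover seq_len tokens (the last chunk may need padding)."""
    return math.ceil(seq_len / chunk_len)


def chunked_attention_entries(seq_len: int, chunk_len: int = CHUNK_LEN) -> int:
    """Attention-score entries when each chunk only attends within itself."""
    return num_chunks(seq_len, chunk_len) * chunk_len * chunk_len


def full_attention_entries(seq_len: int) -> int:
    """Attention-score entries for ordinary full self-attention."""
    return seq_len * seq_len


for seq_len in (1024, 4096, MAX_CTX):
    chunked = chunked_attention_entries(seq_len)
    full = full_attention_entries(seq_len)
    print(
        f"{seq_len:>6} tokens: {num_chunks(seq_len):>2} chunks, "
        f"chunked={chunked:,} full={full:,} ratio={full // chunked}x"
    )
```

Under these assumptions, doubling the sequence length doubles the chunked count but quadruples the full-attention count, which is why 16384 tokens remain tractable.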


### Training

This model was trained with the transformers package. You can find (mostly unorganized) [training runs on wandb here](https://wandb.ai/pszemraj/mega-tuning-longctx).

<details>
<summary><strong>Training Details</strong></summary>
  
1. **Multi-task training:** the majority of training is "standard" MLM, with no next-sentence prediction, etc. However, in the interest of pretraining a _useful_ encoder for fine-tuning on various tasks, we mix in such tasks between several of the MLM phases, carrying over the model's backbone to the next training phase.
	- An example would be multiple-choice tuning on the [swag](https://huggingface.co./datasets/swag) dataset
2. **MLM Mask Ratio 40% default:** we use 40% for the MLM ratio, following [Wettig et al. 2022](https://arxiv.org/abs/2202.08005). This is decreased slightly for training at longer sequences (8192+) to encourage the model to learn/leverage the available context in predictions.
3. AMP with bf16
4. **Gradient checkpointing implementation**: training this model (or similar ones) at a context length of 8192 or longer becomes quite VRAM-intensive despite the linear increase in memory usage.
</details>
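
For intuition about what a 40% mask ratio means in practice, below is a simplified, dependency-free sketch. It is not the training code: it masks an exact count of positions and always substitutes `<mask>`, whereas the standard `DataCollatorForLanguageModeling` in transformers (where the ratio is set via `mlm_probability`) samples each token independently and also uses random/unchanged replacements. The function names and the toy sentence are hypothetical.

```python
import random


def choose_mask_positions(num_tokens: int, mask_ratio: float = 0.40, seed: int = 0) -> list[int]:
    """Pick round(num_tokens * mask_ratio) distinct positions, uniformly at random."""
    rng = random.Random(seed)
    num_to_mask = round(num_tokens * mask_ratio)
    return sorted(rng.sample(range(num_tokens), num_to_mask))


def apply_mask(tokens: list[str], positions: list[int], mask_token: str = "<mask>") -> list[str]:
    """Return a copy of tokens with the chosen positions replaced by mask_token."""
    masked = list(tokens)
    for pos in positions:
        masked[pos] = mask_token
    return masked


# Toy whitespace "tokenization" of a 10-word sentence; 40% means 4 masked tokens.
tokens = "the quick brown fox jumps over the lazy dog today".split()
positions = choose_mask_positions(len(tokens))
masked_tokens = apply_mask(tokens, positions)
print(masked_tokens)
```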

## Usage

This is a pretrained model intended to be [fine-tuned on various encoder-compatible tasks](https://github.com/huggingface/transformers/tree/831bc25d8fdb85768402f772cf65cc3d7872b211/examples/pytorch). However, if you are interested in testing inference with this model or have a deep passion for predicting mask tokens, you can use the following code:

```python
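import json
from transformers import pipeline

# Load the fill-mask pipeline with the pretrained checkpoint from the Hub
pipe = pipeline("fill-mask", model="BEE-spoke-data/mega-encoder-small-16k-v1")
text = "I love to <mask> memes."
# Each result holds a candidate token, its score, and the filled-in sequence
result = pipe(text)
print(json.dumps(result, indent=2))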
```

### Gradient checkpointing implementation

If you fine-tune this model on `<task>`, gradient checkpointing makes training at 16384 context quite feasible. After installing the transformers fork below and passing `gradient_checkpointing=True` in the training args, you should be able to fine-tune at batch size 1 with VRAM to spare on a single 3090/4090.

```sh
pip uninstall -y transformers
pip install -U git+https://github.com/pszemraj/transformers.git@mega-gradient-checkpointing
pip install -U huggingface-hub
```

If there is sufficient interest, we can look at making a PR into the official repo.
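
As a rough sketch of the training arguments described above: only `gradient_checkpointing=True` and the batch size of 1 come from this card. The output path, the accumulation steps, and the `bf16` flag are illustrative assumptions you should adapt to your task and hardware.

```python
# Keyword arguments intended for transformers.TrainingArguments.
# Only gradient_checkpointing and the batch size come from the text above;
# the remaining values are illustrative assumptions.
training_kwargs = {
    "output_dir": "./mega-finetune",    # hypothetical output path
    "per_device_train_batch_size": 1,   # batch size 1, as described above
    "gradient_accumulation_steps": 16,  # assumption: recover a larger effective batch
    "gradient_checkpointing": True,     # trade extra compute for lower activation memory
    "bf16": True,                       # assumption: mirrors the bf16 AMP used in pretraining
}


def build_training_args():
    """Create TrainingArguments; transformers is imported lazily so the dict is usable alone."""
    from transformers import TrainingArguments

    return TrainingArguments(**training_kwargs)


effective_batch = (
    training_kwargs["per_device_train_batch_size"]
    * training_kwargs["gradient_accumulation_steps"]
)
print(f"effective batch size per device: {effective_batch}")
```

Calling `build_training_args()` requires the fork installed above (for gradient checkpointing support with MEGA) and a GPU with bf16 support.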

## Citation

If you find this useful, please consider citing this DOI; it would make us happy.

```
@misc{beespoke_data_2024,
    author       = {Peter Szemraj and Vincent Haines and {BEEspoke Data}},
    title        = {mega-encoder-small-16k-v1 (Revision 1476bcf)},
    year         = 2024,
    url          = {https://huggingface.co./BEE-spoke-data/mega-encoder-small-16k-v1},
    doi          = {10.57967/hf/1837},
    publisher    = {Hugging Face}
}
```