abhi-mosaic committed
Commit: 2f88b1b
Parent(s): 40e5047
Update README.md
README.md CHANGED
@@ -39,14 +39,16 @@ It includes options for many training efficiency features such as [FlashAttention
 
 ```python
 import transformers
-model = transformers.AutoModelForCausalLM.from_pretrained('mosaicml/mpt-7b-storywriter', trust_remote_code=True
+model = transformers.AutoModelForCausalLM.from_pretrained('mosaicml/mpt-7b-storywriter', trust_remote_code=True)
 ```
 
-To use the optimized triton implementation of FlashAttention, you can load with `attn_impl='triton'` and move the model to `bfloat16`:
-
+To use the optimized [triton implementation](https://github.com/openai/triton) of FlashAttention, you can load the model with `attn_impl='triton'` and move the model to `bfloat16`:
 ```python
-
-
+config = transformers.AutoConfig.from_pretrained('mosaicml/mpt-7b-storywriter', trust_remote_code=True)
+config.attn_config['attn_impl'] = 'triton'
+
+model = transformers.AutoModelForCausalLM.from_pretrained('mosaicml/mpt-7b-storywriter', config=config, torch_dtype=torch.bfloat16, trust_remote_code=True)
+model.to(device='cuda:0')
 ```
 
 Although the model was trained with a sequence length of 2048 and finetuned with a sequence length of 65536,
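One detail worth noting about the new snippet: it passes `torch.bfloat16` but only imports `transformers`, so it assumes `torch` has already been imported. Below is a minimal self-contained sketch of the updated loading path; the appended generation call and the tokenizer choice are illustrative assumptions, not part of this diff.

```python
# Minimal self-contained sketch of the updated loading path.
# Assumes a CUDA GPU and the triton FlashAttention dependencies are installed.
import torch  # the diffed snippet uses torch.bfloat16 but never imports torch
import transformers

name = 'mosaicml/mpt-7b-storywriter'

config = transformers.AutoConfig.from_pretrained(name, trust_remote_code=True)
config.attn_config['attn_impl'] = 'triton'  # switch to the triton attention kernel

model = transformers.AutoModelForCausalLM.from_pretrained(
    name,
    config=config,
    torch_dtype=torch.bfloat16,  # load the weights directly in bfloat16
    trust_remote_code=True,
)
model.to(device='cuda:0')

# Illustrative usage (assumption: the repo exposes a tokenizer via AutoTokenizer).
tokenizer = transformers.AutoTokenizer.from_pretrained(name)
inputs = tokenizer('Once upon a time, ', return_tensors='pt').to('cuda:0')
with torch.no_grad():
    output = model.generate(**inputs, max_new_tokens=50)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```

Passing `torch_dtype=torch.bfloat16` at load time materializes the weights in bfloat16 directly, rather than loading them in full precision and casting afterwards.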