madhavatreplit
committed on
Commit
•
1e1a20a
1
Parent(s):
b6d9ff2
Update README.md for flash attn
Browse files
README.md
CHANGED
@@ -105,10 +105,16 @@ triton==2.0.0.dev20221202
|
|
105 |
|
106 |
Then, move the model to `bfloat16` and use it as follows:
|
107 |
```python
|
108 |
-
from transformers import AutoModelForCausalLM
|
|
|
|
|
|
|
|
|
|
|
|
|
109 |
|
110 |
# load model
|
111 |
-
model = AutoModelForCausalLM.from_pretrained('replit/replit-code-v1-3b',
|
112 |
model.to(device='cuda:0', dtype=torch.bfloat16)
|
113 |
|
114 |
# forward pass
|
|
|
105 |
|
106 |
Then, move the model to `bfloat16` and use it as follows:
|
107 |
```python
|
108 |
+
from transformers import AutoModelForCausalLM, AutoConfig
|
109 |
+
|
110 |
+
config = AutoConfig.from_pretrained(
|
111 |
+
"replit/replit-code-v1-3b",
|
112 |
+
trust_remote_code=True
|
113 |
+
)
|
114 |
+
config.attn_config['attn_impl'] = 'triton'
|
115 |
|
116 |
# load model
|
117 |
+
model = AutoModelForCausalLM.from_pretrained('replit/replit-code-v1-3b', config=config, trust_remote_code=True)
|
118 |
model.to(device='cuda:0', dtype=torch.bfloat16)
|
119 |
|
120 |
# forward pass
|