Aakash Vardhan committed
Commit bb8f386 · 1 Parent(s): 8a47fff
- app.py +3 -2
- config.yaml +2 -1
app.py CHANGED
@@ -1,6 +1,6 @@
 import gradio as gr
 import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
+from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, BitsAndBytesConfig
 
 from config import load_config
 
@@ -19,7 +19,8 @@ if "torch_dtype" in model_config:
     elif model_config["torch_dtype"] == "bfloat16":
         model_config["torch_dtype"] = torch.bfloat16
 
-
+quantization_config = BitsAndBytesConfig(load_in_8bit=True)
+model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=quantization_config, **model_config)
 
 checkpoint_model = "checkpoint_dir/checkpoint-650"
 
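The change above switches the app to 8-bit weight loading via bitsandbytes. Below is a minimal, self-contained sketch of that loading pattern, not the Space's exact code: the model id is a placeholder, the tokenizer and generate calls are added for illustration, and it assumes a CUDA device, since bitsandbytes 8-bit quantization generally does not run on CPU.

# Minimal sketch of 8-bit loading with transformers + bitsandbytes.
# "facebook/opt-350m" is a placeholder model id, not the one used in this Space.
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

model_name = "facebook/opt-350m"  # placeholder

quantization_config = BitsAndBytesConfig(load_in_8bit=True)

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=quantization_config,  # weights are quantized to int8 at load time
    device_map="auto",                        # place layers on the available GPU(s)
)

prompt = "Hello, world"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

In the commit itself, the remaining keyword arguments come from config.yaml via **model_config rather than being written out inline.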
config.yaml CHANGED
@@ -4,4 +4,5 @@ model_config:
   trust_remote_code: True
   use_cache: True
   attn_implementation: "eager"
-  device_map: "cpu"
+  device_map: "cpu"
+  load_in_8bit: True
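These keys reach app.py as keyword arguments to from_pretrained through **model_config. The loader itself (config.py's load_config) is not part of this diff, so the sketch below is only an assumption of how such a helper typically reads the file with PyYAML.

# Hypothetical loader; the Space's actual config.py is not shown in this commit.
import yaml  # PyYAML

def load_config(path="config.yaml"):
    """Return the top-level mapping from the YAML file."""
    with open(path) as f:
        return yaml.safe_load(f)

config = load_config()
model_config = config["model_config"]
# After this commit, model_config includes:
# {"trust_remote_code": True, "use_cache": True, "attn_implementation": "eager",
#  "device_map": "cpu", "load_in_8bit": True}
# plus whatever is defined on the earlier lines of the block
# (e.g. the torch_dtype entry that app.py checks).

With this layout, load_in_8bit is also forwarded to from_pretrained through **model_config, in addition to the explicit BitsAndBytesConfig constructed in app.py.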