Aakash Vardhan committed on
Commit
bb8f386
·
1 Parent(s): 8a47fff
Files changed (2) hide show
  1. app.py +3 -2
  2. config.yaml +2 -1
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import gradio as gr
2
  import torch
3
- from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
4
 
5
  from config import load_config
6
 
@@ -19,7 +19,8 @@ if "torch_dtype" in model_config:
19
  elif model_config["torch_dtype"] == "bfloat16":
20
  model_config["torch_dtype"] = torch.bfloat16
21
 
22
- model = AutoModelForCausalLM.from_pretrained(model_name, **model_config)
 
23
 
24
  checkpoint_model = "checkpoint_dir/checkpoint-650"
25
 
 
1
  import gradio as gr
2
  import torch
3
+ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, BitsAndBytesConfig
4
 
5
  from config import load_config
6
 
 
19
  elif model_config["torch_dtype"] == "bfloat16":
20
  model_config["torch_dtype"] = torch.bfloat16
21
 
22
+ quantization_config = BitsAndBytesConfig(load_in_8bit=True)
23
+ model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=quantization_config, **model_config)
24
 
25
  checkpoint_model = "checkpoint_dir/checkpoint-650"
26
 
config.yaml CHANGED
@@ -4,4 +4,5 @@ model_config:
4
  trust_remote_code: True
5
  use_cache: True
6
  attn_implementation: "eager"
7
- device_map: "cpu"
 
 
4
  trust_remote_code: True
5
  use_cache: True
6
  attn_implementation: "eager"
7
+ device_map: "cpu"
8
+ load_in_8bit: True