Update app.py
Browse files
app.py
CHANGED
@@ -70,6 +70,7 @@ tokenizer = AutoTokenizer.from_pretrained(
|
|
70 |
, token= token,)
|
71 |
|
72 |
|
|
|
73 |
|
74 |
model = AutoModelForCausalLM.from_pretrained(model_id, token= token,
|
75 |
# torch_dtype= torch.uint8,
|
@@ -78,12 +79,14 @@ model = AutoModelForCausalLM.from_pretrained(model_id, token= token,
|
|
78 |
attn_implementation="flash_attention_2",
|
79 |
low_cpu_mem_usage=True,
|
80 |
|
81 |
-
device_map='cuda',
|
|
|
82 |
|
83 |
)
|
84 |
|
85 |
|
86 |
#
|
|
|
87 |
|
88 |
|
89 |
# device_map = infer_auto_device_map(model, max_memory={0: "79GB", "cpu":"65GB" })
|
@@ -104,7 +107,7 @@ def respond(
|
|
104 |
top_p,
|
105 |
):
|
106 |
messages = [{"role": "user", "content": "Hello, how are you?"}]
|
107 |
-
input_ids = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt").to('cuda')
|
108 |
## <BOS_TOKEN><|START_OF_TURN_TOKEN|><|USER_TOKEN|>Hello, how are you?<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>
|
109 |
# with autocast():
|
110 |
gen_tokens = model.generate(
|
|
|
70 |
, token= token,)
|
71 |
|
72 |
|
73 |
+
accelerator = Accelerator()
|
74 |
|
75 |
model = AutoModelForCausalLM.from_pretrained(model_id, token= token,
|
76 |
# torch_dtype= torch.uint8,
|
|
|
79 |
attn_implementation="flash_attention_2",
|
80 |
low_cpu_mem_usage=True,
|
81 |
|
82 |
+
# device_map='cuda',
|
83 |
+
device_map=accelerator.device_map,
|
84 |
|
85 |
)
|
86 |
|
87 |
|
88 |
#
|
89 |
+
model = accelerator.prepare(model)
|
90 |
|
91 |
|
92 |
# device_map = infer_auto_device_map(model, max_memory={0: "79GB", "cpu":"65GB" })
|
|
|
107 |
top_p,
|
108 |
):
|
109 |
messages = [{"role": "user", "content": "Hello, how are you?"}]
|
110 |
+
input_ids = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt").to(accelerator.device) #.to('cuda')
|
111 |
## <BOS_TOKEN><|START_OF_TURN_TOKEN|><|USER_TOKEN|>Hello, how are you?<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>
|
112 |
# with autocast():
|
113 |
gen_tokens = model.generate(
|