Update app.py
app.py
CHANGED
@@ -75,11 +75,55 @@ model = AutoModelForCausalLM.from_pretrained(model_id, token= token,
 model = accelerator.prepare(model)
 
 
-# device_map = infer_auto_device_map(model, max_memory={0: "79GB", "cpu":"65GB" })
 
-
-
-
+
+################################################### BG REMOVER ###################################################
+
+
+import gradio as gr
+from gradio_imageslider import ImageSlider
+from loadimg import load_img
+import spaces
+from transformers import AutoModelForImageSegmentation
+import torch
+from torchvision import transforms
+
+torch.set_float32_matmul_precision(["high", "highest"][0])
+
+birefnet = AutoModelForImageSegmentation.from_pretrained(
+    "ZhengPeng7/BiRefNet", trust_remote_code=True
+)
+birefnet.to("cuda")
+
+transform_image = transforms.Compose(
+    [
+        transforms.Resize((1024, 1024)),
+        transforms.ToTensor(),
+        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
+    ]
+)
+
+
+
+
+import base64
+from io import BytesIO
+from PIL import Image
+
+def convert_image_to_base64(image):
+    """
+    Convert a PIL Image with alpha channel to a base64-encoded string.
+    """
+    # Save the image into a BytesIO buffer
+    img_byte_array = BytesIO()
+    image.save(img_byte_array, format="PNG")  # Use PNG for transparency
+    img_byte_array.seek(0)  # Reset the pointer to the beginning
+
+    # Encode the image bytes to base64
+    base64_str = base64.b64encode(img_byte_array.getvalue()).decode("utf-8")
+    return base64_str
+
+
 
 import json
 
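The convert_image_to_base64 helper added above serializes the RGBA cut-out as a PNG so the transparency survives the trip back to the caller. A minimal, self-contained sketch of that round trip (the helper is inlined here in condensed form; the half-transparent test image is made up for illustration):

```python
import base64
from io import BytesIO
from PIL import Image

def convert_image_to_base64(image):
    buf = BytesIO()
    image.save(buf, format="PNG")  # PNG keeps the alpha channel
    return base64.b64encode(buf.getvalue()).decode("utf-8")

# Build a half-transparent red square and encode it
original = Image.new("RGBA", (64, 64), (255, 0, 0, 128))
b64_str = convert_image_to_base64(original)

# Decode the string back into a PIL image and confirm transparency survived
restored = Image.open(BytesIO(base64.b64decode(b64_str)))
assert restored.mode == "RGBA"
assert restored.getpixel((0, 0)) == (255, 0, 0, 128)
print("round trip OK, string length:", len(b64_str))
```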
@@ -106,27 +150,48 @@ def respond(
 
     messages= json_obj
 
-
-
-
-
-
-
-    with torch.no_grad():
-        gen_tokens = model.generate(
-            input_ids,
-            max_new_tokens=max_tokens,
-            # do_sample=True,
-            temperature=temperature,
-        )
-
-    gen_text = tokenizer.decode(gen_tokens[0])
-    print(gen_text)
-    gen_text= gen_text.replace(input_str,'')
-    gen_text= gen_text.replace('<|eot_id|>','')
+    try:
+        image= json_obj['image']
+        image = load_img(image, output_type="pil")
+        image = image.convert("RGB")
+
+        image_size = image.size
 
-
-
+        input_images = transform_image(image).unsqueeze(0).to("cuda")
+        # Prediction
+        with torch.no_grad():
+            preds = birefnet(input_images)[-1].sigmoid().cpu()
+        pred = preds[0].squeeze()
+        pred_pil = transforms.ToPILImage()(pred)
+        mask = pred_pil.resize(image_size)
+        image.putalpha(mask)
+        return convert_image_to_base64(image)
+
+
+    except Exception as e:
+        print("using llama 8b instruct ", e)
+
+        input_ids = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt").to(accelerator.device)
+        input_ids2 = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True, return_tensors="pt") #.to('cuda')
+        print(f"Converted input_ids dtype: {input_ids.dtype}")
+        input_str= str(input_ids2)
+        print('input str = ', input_str)
+
+        with torch.no_grad():
+            gen_tokens = model.generate(
+                input_ids,
+                max_new_tokens=max_tokens,
+                # do_sample=True,
+                temperature=temperature,
+            )
+
+        gen_text = tokenizer.decode(gen_tokens[0])
+        print(gen_text)
+        gen_text= gen_text.replace(input_str,'')
+        gen_text= gen_text.replace('<|eot_id|>','')
+
+        yield gen_text
+
 
 # messages = [
 # # {"role": "user", "content": "What is your favourite condiment?"},
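For readability, here is the background-removal path that the hunk above adds to respond(), assembled into one standalone function. The name remove_background is introduced here purely for illustration; the sketch assumes the ZhengPeng7/BiRefNet checkpoint can be downloaded and that a CUDA device is available, just as the committed code does.

```python
import torch
from PIL import Image
from torchvision import transforms
from transformers import AutoModelForImageSegmentation

# Same model and preprocessing as the committed code
birefnet = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet", trust_remote_code=True
).to("cuda").eval()

transform_image = transforms.Compose([
    transforms.Resize((1024, 1024)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

def remove_background(image: Image.Image) -> Image.Image:
    """Return the input image with a BiRefNet-predicted alpha channel."""
    image = image.convert("RGB")
    image_size = image.size
    input_images = transform_image(image).unsqueeze(0).to("cuda")
    with torch.no_grad():
        preds = birefnet(input_images)[-1].sigmoid().cpu()  # last output is the mask
    mask = transforms.ToPILImage()(preds[0].squeeze()).resize(image_size)
    image.putalpha(mask)  # foreground stays opaque, background becomes transparent
    return image
```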
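The other branch of respond() is the Llama text fallback that runs when the payload has no usable image. A sketch of that step as a pure function; tokenizer and model stand for the objects created earlier in app.py (not shown in this diff), and the device argument is whatever the model was prepared on:

```python
import torch

def generate_reply(tokenizer, model, messages, device, max_tokens=256, temperature=0.7):
    # Build the prompt twice: token ids for generation, plain text for stripping the echo
    input_ids = tokenizer.apply_chat_template(
        messages, tokenize=True, add_generation_prompt=True, return_tensors="pt"
    ).to(device)
    prompt_text = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    with torch.no_grad():
        gen_tokens = model.generate(
            input_ids,
            max_new_tokens=max_tokens,
            temperature=temperature,
        )

    # Decode, then drop the echoed prompt and the Llama end-of-turn marker
    gen_text = tokenizer.decode(gen_tokens[0])
    return gen_text.replace(prompt_text, "").replace("<|eot_id|>", "")

# e.g. generate_reply(tokenizer, model, [{"role": "user", "content": "Hi"}], accelerator.device)
```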
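On the caller side, the image branch hands back a base64-encoded PNG string rather than an image object. A small sketch of decoding it to a file, assuming b64_png holds the string returned by respond() (the filename is illustrative):

```python
import base64

def save_cutout(b64_png: str, path: str = "cutout.png") -> None:
    """Decode the base64-encoded PNG produced by convert_image_to_base64 and write it to disk."""
    with open(path, "wb") as f:
        f.write(base64.b64decode(b64_png))
```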