rphrp1985 committed (verified)
Commit dd65e88 · 1 Parent(s): d2ef205

Update app.py

Files changed (1):
  1. app.py (+89 -24)

app.py CHANGED

@@ -75,11 +75,55 @@ model = AutoModelForCausalLM.from_pretrained(model_id, token= token,
 model = accelerator.prepare(model)
 
 
-# device_map = infer_auto_device_map(model, max_memory={0: "79GB", "cpu":"65GB" })
 
-# Load the model with the inferred device map
-# model = load_checkpoint_and_dispatch(model, model_id, device_map=device_map, no_split_module_classes=["GPTJBlock"])
-# model.half()
+
+################################################### BG REMOVER ###################################################
+
+import gradio as gr
+from gradio_imageslider import ImageSlider
+from loadimg import load_img
+import spaces
+from transformers import AutoModelForImageSegmentation
+import torch
+from torchvision import transforms
+
+# Trade a little matmul precision for speed ("high" enables TF32 on supported GPUs)
+torch.set_float32_matmul_precision(["high", "highest"][0])
+
+# Load the BiRefNet segmentation model used for background removal
+birefnet = AutoModelForImageSegmentation.from_pretrained(
+    "ZhengPeng7/BiRefNet", trust_remote_code=True
+)
+birefnet.to("cuda")
+
+# Preprocessing: resize to the model's 1024x1024 input and normalize with ImageNet statistics
+transform_image = transforms.Compose(
+    [
+        transforms.Resize((1024, 1024)),
+        transforms.ToTensor(),
+        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
+    ]
+)
+
+import base64
+from io import BytesIO
+from PIL import Image
+
+def convert_image_to_base64(image):
+    """
+    Convert a PIL Image with an alpha channel to a base64-encoded PNG string.
+    """
+    # Save the image into a BytesIO buffer; PNG preserves transparency
+    img_byte_array = BytesIO()
+    image.save(img_byte_array, format="PNG")
+    img_byte_array.seek(0)  # Reset the pointer to the beginning
+
+    # Encode the image bytes to base64
+    base64_str = base64.b64encode(img_byte_array.getvalue()).decode("utf-8")
+    return base64_str
+
 
 import json
 
@@ -106,27 +150,48 @@ def respond(
 
     messages= json_obj
 
-    input_ids = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt").to(accelerator.device)
-    input_ids2 = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True, return_tensors="pt")  # .to('cuda')
-    print(f"Converted input_ids dtype: {input_ids.dtype}")
-    input_str= str(input_ids2)
-    print('input str = ', input_str)
-
-    with torch.no_grad():
-        gen_tokens = model.generate(
-            input_ids,
-            max_new_tokens=max_tokens,
-            # do_sample=True,
-            temperature=temperature,
-        )
-
-    gen_text = tokenizer.decode(gen_tokens[0])
-    print(gen_text)
-    gen_text= gen_text.replace(input_str,'')
-    gen_text= gen_text.replace('<|eot_id|>','')
 
-    yield gen_text
-
+    try:
+        # If the request carries an image, run BiRefNet background removal instead of text generation
+        image = json_obj['image']
+        image = load_img(image, output_type="pil")
+        image = image.convert("RGB")
+
+        image_size = image.size
+
+        input_images = transform_image(image).unsqueeze(0).to("cuda")
+        # Prediction: sigmoid of the last output gives the foreground mask
+        with torch.no_grad():
+            preds = birefnet(input_images)[-1].sigmoid().cpu()
+        pred = preds[0].squeeze()
+        pred_pil = transforms.ToPILImage()(pred)
+        mask = pred_pil.resize(image_size)
+        image.putalpha(mask)
+        # respond() is a generator, so yield (not return) the encoded result
+        yield convert_image_to_base64(image)
+        return
+
+    except Exception as e:
+        # No 'image' key (or background removal failed): fall back to Llama 8B Instruct generation
+        print("using llama 8b instruct ", e)
+
+        input_ids = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt").to(accelerator.device)
+        input_ids2 = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True, return_tensors="pt")  # .to('cuda')
+        print(f"Converted input_ids dtype: {input_ids.dtype}")
+        input_str= str(input_ids2)
+        print('input str = ', input_str)
+
+        with torch.no_grad():
+            gen_tokens = model.generate(
+                input_ids,
+                max_new_tokens=max_tokens,
+                # do_sample=True,
+                temperature=temperature,
+            )
+
+        gen_text = tokenizer.decode(gen_tokens[0])
+        print(gen_text)
+        gen_text= gen_text.replace(input_str,'')
+        gen_text= gen_text.replace('<|eot_id|>','')
+
+        yield gen_text
+
 
 # messages = [
 # # {"role": "user", "content": "What is your favourite condiment?"},
 