Update app.py
app.py CHANGED
@@ -114,6 +114,20 @@ GLEEmodel_swin = GLEE_Model(cfg_swin, None, device, None, True).to(device)
 GLEEmodel_swin.load_state_dict(checkpoints_swin, strict=False)
 GLEEmodel_swin.eval()
 
+
+cfg_eva02 = get_cfg()
+add_deeplab_config(cfg_eva02)
+add_glee_config(cfg_eva02)
+conf_files_swin = 'GLEE/configs/EVA02.yaml'
+checkpoints_eva = torch.load('GLEE/GLEE_{}.pth'.format(args.version))
+cfg_eva02.merge_from_file(conf_files_swin)
+GLEEmodel_eva02 = GLEE_Model(cfg_eva02, None, device, None, True).to(device)
+GLEEmodel_eva02.load_state_dict(checkpoints_eva, strict=False)
+GLEEmodel_eva02.eval()
+# inference_type = 'LSJ'
+
+
+
 pixel_mean = torch.Tensor( [123.675, 116.28, 103.53]).to(device).view(3, 1, 1)
 pixel_std = torch.Tensor([58.395, 57.12, 57.375]).to(device).view(3, 1, 1)
 normalizer = lambda x: (x - pixel_mean) / pixel_std
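This hunk registers the EVA02 ("GLEE-Pro") model alongside the existing Swin-L setup. The pixel_mean/pixel_std pair at the end of the hunk are the standard ImageNet channel statistics; a minimal, self-contained sketch of how such a normalizer acts on an image tensor (the input shape is illustrative):

```python
import torch

# Standard ImageNet channel statistics, shaped (3, 1, 1) to broadcast over (C, H, W)
pixel_mean = torch.tensor([123.675, 116.28, 103.53]).view(3, 1, 1)
pixel_std = torch.tensor([58.395, 57.12, 57.375]).view(3, 1, 1)
normalizer = lambda x: (x - pixel_mean) / pixel_std

# A hypothetical uint8 RGB image, converted to float before normalizing
img = torch.randint(0, 256, (3, 480, 640), dtype=torch.uint8).float()
normed = normalizer(img)  # per-channel roughly zero-mean, unit-variance
```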
@@ -130,16 +144,26 @@ TEXT_Y_OFFSET_SCALE = 1e-2
 if inference_type != 'LSJ':
     resizer = torchvision.transforms.Resize(inference_size,antialias=True)
     videoresizer = torchvision.transforms.Resize(video_inference_size,antialias=True)
+else:
+    resizer = torchvision.transforms.Resize(size = 1535, max_size=1536, antialias=True)
+    videoresizer = torchvision.transforms.Resize(size = 1535, max_size=1536, antialias=True)
+
 
 
 def segment_image(img, prompt_mode, categoryname, custom_category, expressiong, results_select, num_inst_select, threshold_select, mask_image_mix_ration, model_selection):
     torch.cuda.empty_cache()
     if model_selection == 'GLEE-Plus (SwinL)':
         GLEEmodel = GLEEmodel_swin
+        inference_type = 'resize_shot'
         print('use GLEE-Plus')
-    else:
+    elif model_selection == 'GLEE-Lite (R50)':
+        inference_type = 'resize_shot'
         GLEEmodel = GLEEmodel_r50
         print('use GLEE-Lite')
+    else:
+        GLEEmodel = GLEEmodel_eva02
+        print('use GLEE-Pro')
+        inference_type = 'LSJ'
 
     copyed_img = img['background'][:,:,:3].copy()
 
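With an integer size, torchvision.transforms.Resize scales the shorter edge to that size unless the longer edge would then exceed max_size, in which case the longer edge is clamped to max_size instead. So the new else branch above, with size=1535 and max_size=1536, brings essentially every image out with its longer edge at 1536, which the LSJ padding in the next hunk depends on. A small demonstration (the input shape is arbitrary):

```python
import torch
import torchvision

resizer = torchvision.transforms.Resize(size=1535, max_size=1536, antialias=True)

img = torch.rand(1, 3, 1080, 1920)
out = resizer(img)
# Scaling the shorter edge (1080) up to 1535 would push the longer edge past
# max_size, so the longer edge is clamped to 1536 instead: (1, 3, 864, 1536).
print(out.shape)
```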
@@ -148,8 +172,12 @@ def segment_image(img, prompt_mode, categoryname, custom_category, expressiong,
     _,_, ori_height, ori_width = ori_image.shape
 
     if inference_type == 'LSJ':
-
-
+        resize_image = resizer(ori_image)
+        image_size = torch.as_tensor((resize_image.shape[-2],resize_image.shape[-1]))
+        re_size = resize_image.shape[-2:]
+        infer_image = torch.zeros(1,3,1536,1536).to(ori_image)
+        infer_image[:,:,:image_size[0],:image_size[1]] = resize_image
+        padding_size = (1536,1536)
     else:
         resize_image = resizer(ori_image)
         image_size = torch.as_tensor((resize_image.shape[-2],resize_image.shape[-1]))
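The new LSJ branch above pastes the resized image into the top-left corner of a fixed 1536×1536 zero canvas, so every input reaches the backbone at one square resolution while image_size/re_size remember the valid region. The same pattern as a stand-alone sketch (shapes assumed from the code above):

```python
import torch

def pad_to_canvas(image: torch.Tensor, canvas: int = 1536) -> torch.Tensor:
    """Zero-pad a (1, 3, H, W) image into the top-left of a square canvas."""
    _, _, h, w = image.shape
    out = torch.zeros(1, 3, canvas, canvas, dtype=image.dtype, device=image.device)
    out[:, :, :h, :w] = image  # image occupies the top-left corner
    return out

resized = torch.rand(1, 3, 864, 1536)   # longer edge already clamped to 1536
padded = pad_to_canvas(resized)         # (1, 3, 1536, 1536)
```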
@@ -309,8 +337,9 @@ def segment_image(img, prompt_mode, categoryname, custom_category, expressiong,
 
     fakemask = torch.from_numpy(fakemask).unsqueeze(0).to(ori_image)
     if inference_type == 'LSJ':
-
-        infer_visual_prompt
+        resize_fakemask = resizer(fakemask)
+        infer_visual_prompt = torch.zeros(1,1536,1536).to(resize_fakemask)
+        infer_visual_prompt[:,:image_size[0],:image_size[1]] = resize_fakemask
     else:
         resize_fakemask = resizer(fakemask)
         if size_divisibility > 1:
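In the non-LSJ branch the mask prompt still passes through the size_divisibility padding that the backbone stride requires. A hedged sketch of that convention, assuming detectron2-style right/bottom zero padding (the helper name is hypothetical):

```python
import torch
import torch.nn.functional as F

def pad_to_divisible(x: torch.Tensor, size_divisibility: int = 32) -> torch.Tensor:
    """Right/bottom zero-pad so H and W become multiples of size_divisibility."""
    h, w = x.shape[-2:]
    pad_h = (size_divisibility - h % size_divisibility) % size_divisibility
    pad_w = (size_divisibility - w % size_divisibility) % size_divisibility
    return F.pad(x, (0, pad_w, 0, pad_h))  # pad order: (left, right, top, bottom)

mask = torch.rand(1, 1, 863, 1530)
print(pad_to_divisible(mask).shape)  # torch.Size([1, 1, 864, 1536])
```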
@@ -377,8 +406,12 @@ def process_frames(frame_list):
     _,_, ori_height, ori_width = ori_image.shape
 
     if inference_type == 'LSJ':
-
-
+        resize_image = resizer(ori_image)
+        image_size = torch.as_tensor((resize_image.shape[-2],resize_image.shape[-1]))
+        re_size = resize_image.shape[-2:]
+        infer_image = torch.zeros(1,3,1536,1536).to(ori_image)
+        infer_image[:,:,:image_size[0],:image_size[1]] = resize_image
+        padding_size = (1536,1536)
     else:
         resize_image = videoresizer(ori_image)
         image_size = torch.as_tensor((resize_image.shape[-2],resize_image.shape[-1]))
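Because every frame in process_frames lands on the same square canvas, a whole clip can be stacked into one batch regardless of the source aspect ratio. A minimal sketch of that consequence (the frame list is hypothetical):

```python
import torch

frames = [torch.rand(3, 864, 1536) for _ in range(4)]  # already-resized video frames
batch = torch.zeros(len(frames), 3, 1536, 1536)
for i, f in enumerate(frames):
    batch[i, :, :f.shape[-2], :f.shape[-1]] = f  # same top-left paste per frame
```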
@@ -414,14 +447,23 @@ def match_from_embds(tgt_embds, cur_embds):
 def segment_video(video, prompt_mode, categoryname, custom_category, expressiong, results_select, num_inst_select, threshold_select, mask_image_mix_ration, model_selection,video_frames_select, prompter):
     torch.cuda.empty_cache()
     ### model selection
+
+
     if model_selection == 'GLEE-Plus (SwinL)':
         GLEEmodel = GLEEmodel_swin
+        inference_type = 'resize_shot'
         print('use GLEE-Plus')
         clip_length = 2 #batchsize
-    else:
+    elif model_selection == 'GLEE-Lite (R50)':
+        inference_type = 'resize_shot'
         GLEEmodel = GLEEmodel_r50
         print('use GLEE-Lite')
         clip_length = 4 #batchsize
+    else:
+        GLEEmodel = GLEEmodel_eva02
+        print('use GLEE-Pro')
+        inference_type = 'LSJ'
+        clip_length = 1 #batchsize
 
     # read video and get sparse frames
     cap = cv2.VideoCapture(video)
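clip_length acts as the per-forward batch size: two frames per clip for Swin-L, four for R50, and one for the heavier EVA02 model. A sketch of splitting a frame list into clips of that length:

```python
def chunk_frames(frame_list, clip_length):
    """Split frames into consecutive clips of at most clip_length frames."""
    return [frame_list[i:i + clip_length]
            for i in range(0, len(frame_list), clip_length)]

print(chunk_frames(list(range(10)), 4))  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
```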
@@ -678,8 +720,9 @@ def segment_video(video, prompt_mode, categoryname, custom_category, expressiong
 
     fakemask = torch.from_numpy(fakemask).unsqueeze(0).to(ori_image)
     if inference_type == 'LSJ':
-
-        infer_visual_prompt
+        resize_fakemask = resizer(fakemask)
+        infer_visual_prompt = torch.zeros(1,1536,1536).to(resize_fakemask)
+        infer_visual_prompt[:,:image_size[0],:image_size[1]] = resize_fakemask
     else:
         resize_fakemask = videoresizer(fakemask)
         if size_divisibility > 1:
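Downstream of these hunks the padded predictions presumably have to be cropped back to the valid region (re_size) and rescaled to the original frame size; that inverse step is not shown in the diff, so the following is only a sketch under that assumption, with names borrowed from the code above:

```python
import torch
import torch.nn.functional as F

def unpad_and_restore(pred, re_size, ori_size):
    """Crop padded predictions to re_size, then upsample to the original size."""
    h, w = re_size
    cropped = pred[..., :h, :w]
    return F.interpolate(cropped, size=ori_size, mode='bilinear', align_corners=False)

pred = torch.rand(1, 1, 1536, 1536)   # model output on the padded 1536x1536 canvas
restored = unpad_and_restore(pred, (864, 1536), (1080, 1920))
print(restored.shape)                 # torch.Size([1, 1, 1080, 1920])
```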