Spaces:
Runtime error
Runtime error
update demo
Browse files- .gitignore +2 -1
- .log/log.txt +5 -5
- SegFormer +1 -1
- mask.png +0 -0
- output.png +0 -0
- streamlit_test.py +3 -0
- test.png +0 -0
- test.py +168 -242
.gitignore
CHANGED
@@ -2,4 +2,5 @@ __pycache__
|
|
2 |
*.pyc
|
3 |
checkpoints/
|
4 |
I2SB/
|
5 |
-
*.pth
|
|
|
|
2 |
*.pyc
|
3 |
checkpoints/
|
4 |
I2SB/
|
5 |
+
*.pth
|
6 |
+
SegFormer/
|
.log/log.txt
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
-
[19:
|
2 |
INFO (0:00:00) [Diffusion] Built I2SB diffusion: steps=1000!
|
3 |
-
[19:
|
4 |
-
[19:02
|
5 |
-
[19:
|
6 |
-
[19:
|
|
|
1 |
+
[19:58:55] INFO (0:00:00) Loaded options from opt_pkl_path=PosixPath('I2SB/results/inpaint-freeform2030/options.pkl')!
|
2 |
INFO (0:00:00) [Diffusion] Built I2SB diffusion: steps=1000!
|
3 |
+
[19:58:58] INFO (0:00:03) [Net] Initialized network from ckpt_pkl='I2SB/data/256x256_diffusion_uncond_fixedsigma.pkl'! Size=552807171!
|
4 |
+
[19:59:02] INFO (0:00:07) [Net] Loaded pretrained adm ckpt_pt='I2SB/data/256x256_diffusion_uncond_fixedsigma.pt'!
|
5 |
+
[19:59:06] INFO (0:00:11) [Net] Loaded network ckpt: I2SB/results/inpaint-freeform2030/latest.pt!
|
6 |
+
[19:59:08] INFO (0:00:13) [Ema] Loaded ema ckpt: I2SB/results/inpaint-freeform2030/latest.pt!
|
SegFormer
CHANGED
@@ -1 +1 @@
|
|
1 |
-
Subproject commit
|
|
|
1 |
+
Subproject commit ccc3dd500c4091a583b4b2749e35da501e670aca
|
mask.png
ADDED
output.png
CHANGED
streamlit_test.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
|
3 |
+
st.write("Hello")
|
test.png
CHANGED
test.py
CHANGED
@@ -40,6 +40,7 @@ from GroundingDINO.groundingdino.util.utils import clean_state_dict, get_phrases
|
|
40 |
import sys
|
41 |
|
42 |
sys.path.insert(0, "/home/ubuntu/Thesis-Demo/I2SB")
|
|
|
43 |
|
44 |
import numpy as np
|
45 |
import torch
|
@@ -62,6 +63,18 @@ from I2SB.i2sb import Runner, ckpt_util, download_ckpt
|
|
62 |
from I2SB.logger import Logger
|
63 |
from I2SB.sample import *
|
64 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
65 |
|
66 |
|
67 |
import cv2
|
@@ -89,13 +102,6 @@ if os.environ.get('IS_MY_DEBUG') is not None:
|
|
89 |
inpainting_enable = False
|
90 |
kosmos_enable = False
|
91 |
|
92 |
-
if lama_cleaner_enable:
|
93 |
-
try:
|
94 |
-
from lama_cleaner.model_manager import ModelManager
|
95 |
-
from lama_cleaner.schema import Config as lama_Config
|
96 |
-
except Exception as e:
|
97 |
-
lama_cleaner_enable = False
|
98 |
-
|
99 |
# segment anything
|
100 |
from segment_anything import build_sam, SamPredictor, SamAutomaticMaskGenerator
|
101 |
|
@@ -191,13 +197,16 @@ def get_point(img, sel_pix, evt: gr.SelectData):
|
|
191 |
|
192 |
|
193 |
def undo_button(orig_img, sel_pix):
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
sel_pix
|
198 |
-
|
199 |
-
|
200 |
-
|
|
|
|
|
|
|
201 |
|
202 |
def clear_button(orig_img):
|
203 |
|
@@ -256,10 +265,22 @@ def load_i2sb_model():
|
|
256 |
runner.ema = ExponentialMovingAverage(
|
257 |
runner.net.parameters(), decay=0.99) # re-init ema with fp16 weight
|
258 |
|
|
|
259 |
print("Loading time:", (time.time()-s)*1e3, "ms.")
|
260 |
i2sb_model = runner
|
261 |
return runner
|
262 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
263 |
def plot_boxes_to_image(image_pil, tgt):
|
264 |
H, W = tgt["size"]
|
265 |
boxes = tgt["boxes"]
|
@@ -326,42 +347,6 @@ def load_image(image_path):
|
|
326 |
return image_pil, image
|
327 |
|
328 |
|
329 |
-
|
330 |
-
def get_grounding_output(model, image, caption, box_threshold, text_threshold, with_logits=True, device="cpu"):
|
331 |
-
caption = caption.lower()
|
332 |
-
caption = caption.strip()
|
333 |
-
if not caption.endswith("."):
|
334 |
-
caption = caption + "."
|
335 |
-
model = model.to(device)
|
336 |
-
image = image.to(device)
|
337 |
-
with torch.no_grad():
|
338 |
-
outputs = model(image[None], captions=[caption])
|
339 |
-
logits = outputs["pred_logits"].cpu().sigmoid()[0] # (nq, 256)
|
340 |
-
boxes = outputs["pred_boxes"].cpu()[0] # (nq, 4)
|
341 |
-
logits.shape[0]
|
342 |
-
|
343 |
-
# filter output
|
344 |
-
logits_filt = logits.clone()
|
345 |
-
boxes_filt = boxes.clone()
|
346 |
-
filt_mask = logits_filt.max(dim=1)[0] > box_threshold
|
347 |
-
logits_filt = logits_filt[filt_mask] # num_filt, 256
|
348 |
-
boxes_filt = boxes_filt[filt_mask] # num_filt, 4
|
349 |
-
logits_filt.shape[0]
|
350 |
-
|
351 |
-
# get phrase
|
352 |
-
tokenlizer = model.tokenizer
|
353 |
-
tokenized = tokenlizer(caption)
|
354 |
-
# build pred
|
355 |
-
pred_phrases = []
|
356 |
-
for logit, box in zip(logits_filt, boxes_filt):
|
357 |
-
pred_phrase = get_phrases_from_posmap(logit > text_threshold, tokenized, tokenlizer)
|
358 |
-
if with_logits:
|
359 |
-
pred_phrases.append(pred_phrase + f"({str(logit.max().item())[:4]})")
|
360 |
-
else:
|
361 |
-
pred_phrases.append(pred_phrase)
|
362 |
-
|
363 |
-
return boxes_filt, pred_phrases
|
364 |
-
|
365 |
def show_mask(mask, ax, random_color=False):
|
366 |
if random_color:
|
367 |
color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
|
@@ -447,99 +432,45 @@ def load_sd_model(device):
|
|
447 |
)
|
448 |
sd_model = sd_model.to(device)
|
449 |
|
450 |
-
def forward_i2sb(img, mask):
|
451 |
-
|
|
|
|
|
452 |
mask = np.where(mask > 0, 1, 0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
453 |
img_tensor = i2sb_transforms(img).to(
|
454 |
i2sb_opt.device).unsqueeze(0)
|
455 |
|
456 |
mask_tensor = torch.from_numpy(np.resize(np.array(mask), (256,256))).to(
|
457 |
i2sb_opt.device).unsqueeze(0).unsqueeze(0)
|
458 |
-
print("POST PROCESSING\t", torch.unique(img_tensor))
|
459 |
-
|
|
|
|
|
|
|
460 |
f = time.time()
|
461 |
xs, _ = i2sb_model.ddpm_sampling(
|
462 |
ckpt_opt, img_tensor, mask=mask_tensor, cond=None, clip_denoise=i2sb_opt.clip_denoise, nfe=nfe, verbose=i2sb_opt.n_gpu_per_node == 1)
|
463 |
recon_img = xs[:, 0, ...].to(i2sb_opt.device)
|
464 |
-
tu.save_image((recon_img+1)/2, "output.png")
|
|
|
465 |
print(recon_img.shape)
|
466 |
-
return transforms.ToPILImage()(((recon_img+1)/2)[0])
|
467 |
|
468 |
-
def
|
469 |
-
|
470 |
-
|
471 |
-
ori_image = image
|
472 |
-
if mask.shape[0] == image.shape[1] and mask.shape[1] == image.shape[0] and mask.shape[0] != mask.shape[1]:
|
473 |
-
# rotate image
|
474 |
-
logger.info(f'_______lama_cleaner_process_______2____')
|
475 |
-
ori_image = np.transpose(image[::-1, ...][:, ::-1], axes=(1, 0, 2))[::-1, ...]
|
476 |
-
logger.info(f'_______lama_cleaner_process_______3____')
|
477 |
-
image = ori_image
|
478 |
-
|
479 |
-
logger.info(f'_______lama_cleaner_process_______4____')
|
480 |
-
original_shape = ori_image.shape
|
481 |
-
logger.info(f'_______lama_cleaner_process_______5____')
|
482 |
-
interpolation = cv2.INTER_CUBIC
|
483 |
-
|
484 |
-
size_limit = cleaner_size_limit
|
485 |
-
if size_limit == -1:
|
486 |
-
logger.info(f'_______lama_cleaner_process_______6____')
|
487 |
-
size_limit = max(image.shape)
|
488 |
-
else:
|
489 |
-
logger.info(f'_______lama_cleaner_process_______7____')
|
490 |
-
size_limit = int(size_limit)
|
491 |
-
|
492 |
-
logger.info(f'_______lama_cleaner_process_______8____')
|
493 |
-
config = lama_Config(
|
494 |
-
ldm_steps=25,
|
495 |
-
ldm_sampler='plms',
|
496 |
-
zits_wireframe=True,
|
497 |
-
hd_strategy='Original',
|
498 |
-
hd_strategy_crop_margin=196,
|
499 |
-
hd_strategy_crop_trigger_size=1280,
|
500 |
-
hd_strategy_resize_limit=2048,
|
501 |
-
prompt='',
|
502 |
-
use_croper=False,
|
503 |
-
croper_x=0,
|
504 |
-
croper_y=0,
|
505 |
-
croper_height=512,
|
506 |
-
croper_width=512,
|
507 |
-
sd_mask_blur=5,
|
508 |
-
sd_strength=0.75,
|
509 |
-
sd_steps=50,
|
510 |
-
sd_guidance_scale=7.5,
|
511 |
-
sd_sampler='ddim',
|
512 |
-
sd_seed=42,
|
513 |
-
cv2_flag='INPAINT_NS',
|
514 |
-
cv2_radius=5,
|
515 |
-
)
|
516 |
-
|
517 |
-
logger.info(f'_______lama_cleaner_process_______9____')
|
518 |
-
if config.sd_seed == -1:
|
519 |
-
config.sd_seed = random.randint(1, 999999999)
|
520 |
-
|
521 |
-
# logger.info(f"Origin image shape_0_: {original_shape} / {size_limit}")
|
522 |
-
logger.info(f'_______lama_cleaner_process_______10____')
|
523 |
-
image = resize_max_size(image, size_limit=size_limit, interpolation=interpolation)
|
524 |
-
# logger.info(f"Resized image shape_1_: {image.shape}")
|
525 |
-
|
526 |
-
# logger.info(f"mask image shape_0_: {mask.shape} / {type(mask)}")
|
527 |
-
logger.info(f'_______lama_cleaner_process_______11____')
|
528 |
-
mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation)
|
529 |
-
# logger.info(f"mask image shape_1_: {mask.shape} / {type(mask)}")
|
530 |
-
|
531 |
-
logger.info(f'_______lama_cleaner_process_______12____')
|
532 |
-
res_np_img = lama_cleaner_model(image, mask, config)
|
533 |
-
logger.info(f'_______lama_cleaner_process_______13____')
|
534 |
-
torch.cuda.empty_cache()
|
535 |
|
536 |
-
|
537 |
-
|
538 |
-
|
539 |
-
except Exception as e:
|
540 |
-
logger.info(f'lama_cleaner_process[Error]:' + str(e))
|
541 |
-
image = None
|
542 |
-
return image
|
543 |
|
544 |
# visualization
|
545 |
def draw_selected_mask(mask, draw):
|
@@ -632,27 +563,15 @@ def get_time_cost(run_task_time, time_cost_str):
|
|
632 |
return run_task_time, time_cost_str
|
633 |
|
634 |
def run_anything_task(input_image, input_points, origin_image, task_type,
|
635 |
-
mask_source_radio,
|
636 |
|
637 |
run_task_time = 0
|
638 |
time_cost_str = ''
|
639 |
run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
|
640 |
print("HERE................", task_type)
|
641 |
-
|
642 |
-
global kosmos_model, kosmos_processor
|
643 |
-
if isinstance(input_image, dict):
|
644 |
-
image_pil, image = load_image(input_image['image'].convert("RGB"))
|
645 |
-
input_img = input_image['image']
|
646 |
-
else:
|
647 |
-
image_pil, image = load_image(input_image.convert("RGB"))
|
648 |
-
input_img = input_image
|
649 |
-
|
650 |
-
kosmos_image, kosmos_text, kosmos_entities = kosmos_generate_predictions(image_pil, kosmos_model, kosmos_processor)
|
651 |
-
run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
|
652 |
-
return None, None, time_cost_str, kosmos_image, gr.Textbox.update(visible=(time_cost_str !='')), kosmos_text, kosmos_entities
|
653 |
-
|
654 |
if input_image is None:
|
655 |
-
return [], gr.Gallery.update(label='Please upload a image!😂😂😂😂'), time_cost_str, gr.Textbox.update(visible=(time_cost_str !=''))
|
656 |
|
657 |
file_temp = int(time.time())
|
658 |
logger.info(f'run_anything_task_002/{device}_[{file_temp}]_{task_type}/[{mask_source_radio}]_1_')
|
@@ -682,92 +601,119 @@ def run_anything_task(input_image, input_points, origin_image, task_type,
|
|
682 |
groundingdino_device = 'cpu'
|
683 |
|
684 |
logger.info(f'run_anything_task_[{file_temp}]_{task_type}_2_')
|
685 |
-
if task_type == 'segment' or
|
686 |
-
image = np.array(
|
687 |
-
if
|
688 |
-
sam_predictor
|
689 |
-
|
690 |
-
|
691 |
-
|
692 |
-
|
693 |
-
|
694 |
-
|
695 |
-
|
696 |
-
|
697 |
-
|
698 |
-
|
699 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
700 |
else:
|
701 |
-
|
702 |
-
|
703 |
-
|
704 |
-
|
705 |
-
plt.imshow(origin_image)
|
706 |
-
for mask in masks:
|
707 |
-
show_mask(mask, plt.gca(), random_color=True)
|
708 |
-
# for box, label in zip(boxes_filt, pred_phrases):
|
709 |
-
# show_box(box.cpu().numpy(), plt.gca(), label)
|
710 |
-
plt.axis('off')
|
711 |
-
image_path = os.path.join(output_dir, f"grounding_seg_output_{file_temp}.jpg")
|
712 |
-
plt.savefig(image_path, bbox_inches="tight")
|
713 |
-
plt.clf()
|
714 |
-
plt.close('all')
|
715 |
-
segment_image_result = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
|
716 |
-
os.remove(image_path)
|
717 |
output_images.append(Image.fromarray(segment_image_result))
|
718 |
-
run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
|
|
|
719 |
|
720 |
logger.info(f'run_anything_task_[{file_temp}]_{task_type}_3_')
|
721 |
if task_type == 'detection' or task_type == 'segment':
|
722 |
logger.info(f'run_anything_task_[{file_temp}]_{task_type}_9_')
|
723 |
-
return output_images, gr.Gallery.update(label='result images'), time_cost_str, gr.Textbox.update(visible=(time_cost_str !=''))
|
724 |
-
elif task_type in ['inpainting', 'outpainting'] or task_type == '
|
725 |
-
if mask_source_radio == mask_source_segment:
|
726 |
-
task_type = 'remove'
|
727 |
|
728 |
logger.info(f'run_anything_task_[{file_temp}]_{task_type}_4_')
|
729 |
-
if
|
730 |
-
|
731 |
-
|
732 |
-
|
733 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
734 |
else:
|
735 |
-
|
736 |
-
|
737 |
-
|
738 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
739 |
output_images.append(mask_pil.convert("RGB"))
|
740 |
run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
|
741 |
|
742 |
-
if task_type in ['inpainting', '
|
743 |
# image_inpainting = sd_model(prompt = "", image=image_source_for_inpaint, mask_image=image_mask_for_inpaint).images[0]
|
744 |
-
input_img.save("test.png")
|
745 |
-
|
746 |
-
|
747 |
-
|
|
|
|
|
|
|
|
|
748 |
else:
|
749 |
# remove from mask
|
750 |
aasds = 1
|
751 |
|
752 |
logger.info(f'run_anything_task_[{file_temp}]_{task_type}_6_')
|
753 |
-
image_inpainting = lama_cleaner_process(np.array(image_pil), np.array(mask_pil.convert("L")), cleaner_size_limit)
|
754 |
if image_inpainting is None:
|
755 |
logger.info(f'run_anything_task_failed_')
|
756 |
-
return None, None, None, None
|
757 |
|
758 |
# output_images.append(image_inpainting)
|
759 |
# run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
|
760 |
|
761 |
logger.info(f'run_anything_task_[{file_temp}]_{task_type}_7_')
|
762 |
image_inpainting = image_inpainting.resize((image_pil.size[0], image_pil.size[1]))
|
|
|
763 |
output_images.append(image_inpainting)
|
764 |
run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
|
765 |
logger.info(f'run_anything_task_[{file_temp}]_{task_type}_9_')
|
766 |
-
return output_images, gr.Gallery.update(label='result images'), time_cost_str, gr.Textbox.update(visible=(time_cost_str !=''))
|
767 |
else:
|
768 |
logger.info(f"task_type:{task_type} error!")
|
769 |
logger.info(f'run_anything_task_[{file_temp}]_9_9_')
|
770 |
-
return output_images, gr.Gallery.update(label='result images'), time_cost_str, gr.Textbox.update(visible=(time_cost_str !=''))
|
771 |
|
772 |
def change_radio_display(task_type, mask_source_radio, orig_img):
|
773 |
mask_source_radio_visible = False
|
@@ -789,20 +735,19 @@ def change_radio_display(task_type, mask_source_radio, orig_img):
|
|
789 |
mask_source_radio_visible = True
|
790 |
if task_type == "relate anything":
|
791 |
num_relation_visible = True
|
792 |
-
if task_type == "
|
793 |
-
ret = gr.Image(value= orig_img, elem_id="image_upload", type='pil', label="Upload", height=512, tool = "editor")# tool = "sketch", brush_color='#00FFFF', mask_opacity=0.6)
|
794 |
-
elif task_type == "inpainting":
|
795 |
ret = gr.Image(value = orig_img, elem_id="image_upload", type='pil', label="Upload", height=512, tool = "sketch", brush_color='#00FFFF', mask_opacity=0.6)
|
|
|
|
|
796 |
|
797 |
return (gr.Radio.update(visible=mask_source_radio_visible),
|
798 |
gr.Slider.update(visible=num_relation_visible),
|
799 |
gr.Gallery.update(visible=image_gallery_visible),
|
800 |
-
gr.Radio
|
801 |
-
gr.
|
802 |
-
gr.HighlightedText.update(visible=kosmos_text_output_visible),
|
803 |
ret, [],
|
804 |
-
gr.Button("Undo point", visible = task_type
|
805 |
-
gr.Button("Clear point", visible = task_type
|
806 |
|
807 |
def get_model_device(module):
|
808 |
try:
|
@@ -832,10 +777,11 @@ def main_gradio(args):
|
|
832 |
with gr.Row():
|
833 |
with gr.Column():
|
834 |
selected_points = gr.State([])
|
835 |
-
original_image = gr.State()
|
836 |
task_types = ["segment"]
|
837 |
if inpainting_enable:
|
838 |
task_types.append("inpainting")
|
|
|
839 |
|
840 |
|
841 |
input_image = gr.Image(elem_id="image_upload", type='pil', label="Upload", height=512)
|
@@ -854,7 +800,7 @@ def main_gradio(args):
|
|
854 |
with gr.Row():
|
855 |
with gr.Column():
|
856 |
|
857 |
-
undo_point_button = gr.Button("Undo point")
|
858 |
undo_point_button.click(
|
859 |
fn= undo_button,
|
860 |
inputs=[original_image, selected_points],
|
@@ -863,7 +809,7 @@ def main_gradio(args):
|
|
863 |
|
864 |
with gr.Column():
|
865 |
|
866 |
-
clear_point_button = gr.Button("Clear point")
|
867 |
clear_point_button.click(
|
868 |
fn= clear_button,
|
869 |
inputs=[original_image],
|
@@ -876,10 +822,15 @@ def main_gradio(args):
|
|
876 |
mask_source_radio = gr.Radio([mask_source_draw, mask_source_segment],
|
877 |
value=mask_source_draw, label="Mask from",
|
878 |
visible=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
879 |
num_relation = gr.Slider(label="How many relations do you want to see", minimum=1, maximum=20, value=5, step=1, visible=False)
|
880 |
|
881 |
-
kosmos_input = gr.Radio(["Brief", "Detailed"], label="Kosmos Description Type", value="Brief", visible=False)
|
882 |
-
|
883 |
run_button = gr.Button(label="Run", visible=True)
|
884 |
# with gr.Accordion("Advanced options", open=False) as advanced_options:
|
885 |
# box_threshold = gr.Slider(
|
@@ -900,47 +851,21 @@ def main_gradio(args):
|
|
900 |
|
901 |
with gr.Column():
|
902 |
image_gallery = gr.Gallery(label="result images", show_label=True, elem_id="gallery", height=512, visible=True
|
903 |
-
).style(preview=True, columns=[5], object_fit="scale-down", height=
|
904 |
time_cost = gr.Textbox(label="Time cost by step (ms):", visible=False, interactive=False)
|
905 |
|
906 |
-
kosmos_output = gr.Image(type="pil", label="result images", visible=False)
|
907 |
-
kosmos_text_output = gr.HighlightedText(
|
908 |
-
label="Generated Description",
|
909 |
-
combine_adjacent=False,
|
910 |
-
show_legend=True,
|
911 |
-
visible=False,
|
912 |
-
).style(color_map=color_map)
|
913 |
-
# record which text span (label) is selected
|
914 |
-
selected = gr.Number(-1, show_label=False, placeholder="Selected", visible=False)
|
915 |
-
|
916 |
-
# record the current `entities`
|
917 |
-
entity_output = gr.Textbox(visible=False)
|
918 |
-
|
919 |
-
# get the current selected span label
|
920 |
-
def get_text_span_label(evt: gr.SelectData):
|
921 |
-
if evt.value[-1] is None:
|
922 |
-
return -1
|
923 |
-
return int(evt.value[-1])
|
924 |
-
# and set this information to `selected`
|
925 |
-
kosmos_text_output.select(get_text_span_label, None, selected)
|
926 |
|
927 |
-
# update output image when we change the span (enity) selection
|
928 |
-
def update_output_image(img_input, image_output, entities, idx):
|
929 |
-
entities = ast.literal_eval(entities)
|
930 |
-
updated_image = draw_entity_boxes_on_image(img_input, entities, entity_index=idx)
|
931 |
-
return updated_image
|
932 |
-
selected.change(update_output_image, [kosmos_output, kosmos_output, entity_output, selected], [kosmos_output])
|
933 |
|
934 |
run_button.click(fn=run_anything_task, inputs=[
|
935 |
input_image, selected_points, original_image, task_type,
|
936 |
-
mask_source_radio],
|
937 |
-
outputs=[image_gallery, image_gallery, time_cost, time_cost
|
938 |
|
939 |
mask_source_radio.change(fn=change_radio_display, inputs=[task_type, mask_source_radio, original_image],
|
940 |
outputs=[mask_source_radio, num_relation])
|
941 |
task_type.change(fn=change_radio_display, inputs=[task_type, mask_source_radio, original_image],
|
942 |
outputs=[mask_source_radio, num_relation,
|
943 |
-
image_gallery,
|
944 |
])
|
945 |
|
946 |
# DESCRIPTION = f'### This demo from [Grounded-Segment-Anything](https://github.com/IDEA-Research/Grounded-Segment-Anything). <br>'
|
@@ -985,6 +910,7 @@ if __name__ == "__main__":
|
|
985 |
|
986 |
if sam_enable:
|
987 |
load_sam_model(device)
|
|
|
988 |
|
989 |
if inpainting_enable:
|
990 |
load_sd_model(device)
|
|
|
40 |
import sys
|
41 |
|
42 |
sys.path.insert(0, "/home/ubuntu/Thesis-Demo/I2SB")
|
43 |
+
sys.path.insert(0, "/home/ubuntu/Thesis-Demo/SegFormer")
|
44 |
|
45 |
import numpy as np
|
46 |
import torch
|
|
|
63 |
from I2SB.logger import Logger
|
64 |
from I2SB.sample import *
|
65 |
|
66 |
+
from pathlib import Path
|
67 |
+
|
68 |
+
inpaint_checkpoint = Path("/home/ubuntu/Thesis-Demo/I2SB/results")
|
69 |
+
|
70 |
+
if not inpaint_checkpoint.exists():
|
71 |
+
os.system("pip install transformers==4.32.0")
|
72 |
+
|
73 |
+
# SegFormer
|
74 |
+
from PIL import Image
|
75 |
+
|
76 |
+
from SegFormer.mmseg.apis import inference_segmentor, init_segmentor, visualize_result_pyplot
|
77 |
+
from SegFormer.mmseg.core.evaluation import get_palette
|
78 |
|
79 |
|
80 |
import cv2
|
|
|
102 |
inpainting_enable = False
|
103 |
kosmos_enable = False
|
104 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
105 |
# segment anything
|
106 |
from segment_anything import build_sam, SamPredictor, SamAutomaticMaskGenerator
|
107 |
|
|
|
197 |
|
198 |
|
199 |
def undo_button(orig_img, sel_pix):
    """Drop the most recently selected point and redraw the remaining markers.

    Args:
        orig_img: The unannotated image (must support ``.copy()`` and be
            convertible with ``np.array``), or ``None`` when no image has
            been stored yet.
        sel_pix: List of selected points. The last entry is removed in place.

    Returns:
        An RGB PIL image carrying one marker per remaining point, or
        ``orig_img`` unchanged when it is ``None``.
    """
    # Explicit None check: truth-testing an ndarray raises ValueError, so the
    # previous ``if orig_img:`` was only safe for PIL images.
    if orig_img is None:
        return orig_img

    canvas = np.array(orig_img.copy(), dtype=np.uint8)
    if sel_pix:
        sel_pix.pop()
        # Redraw from scratch on the clean copy so the removed marker vanishes.
        for point in sel_pix:
            cv2.drawMarker(canvas, point, colors[0], markerType=markers[0], markerSize=6, thickness=2)
    return Image.fromarray(canvas).convert("RGB")
|
209 |
+
|
210 |
|
211 |
def clear_button(orig_img):
|
212 |
|
|
|
265 |
runner.ema = ExponentialMovingAverage(
|
266 |
runner.net.parameters(), decay=0.99) # re-init ema with fp16 weight
|
267 |
|
268 |
+
logger.info(f"I2SB Loading time:\t {(time.time()-s)*1e3} ms.")
|
269 |
print("Loading time:", (time.time()-s)*1e3, "ms.")
|
270 |
i2sb_model = runner
|
271 |
return runner
|
272 |
|
273 |
+
def load_segformer(device):
    """Build the SegFormer segmentor and publish it as ``segformer_model``.

    Args:
        device: Device argument forwarded to ``init_segmentor``.

    Returns:
        The model object returned by ``init_segmentor``.
    """
    global segformer_model

    s = time.time()
    model = init_segmentor(
        "SegFormer/local_configs/segformer/B3/segformer.b3.256x256.wtm.160k.py",
        "SegFormer/work_dirs/segformer.b3.256x256.wtm.160k/iter_160000.pth",
        device=device,
    )
    logger.info(f"SegFormer Loading time:\t {(time.time()-s)*1e3} ms.")

    # Publish only after a successful load so the global never holds a
    # half-initialised model.
    segformer_model = model
    return model
|
283 |
+
|
284 |
def plot_boxes_to_image(image_pil, tgt):
|
285 |
H, W = tgt["size"]
|
286 |
boxes = tgt["boxes"]
|
|
|
347 |
return image_pil, image
|
348 |
|
349 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
350 |
def show_mask(mask, ax, random_color=False):
|
351 |
if random_color:
|
352 |
color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
|
|
|
432 |
)
|
433 |
sd_model = sd_model.to(device)
|
434 |
|
435 |
+
def forward_i2sb(img, mask, dilation_mask_extend):
    """Inpaint the masked region of ``img`` with the loaded I2SB model.

    Args:
        img: Image accepted by ``i2sb_transforms``.
        mask: NumPy array; every value > 0 marks a pixel to inpaint.
            Assumed to be a single-channel (H, W) array -- TODO confirm
            against callers.
        dilation_mask_extend: Dilation kernel size, normally a digit string
            from the UI textbox. Non-numeric or zero values skip dilation.

    Returns:
        A ``(reconstruction, corrupted_input)`` pair of PIL images, both
        mapped from ``[-1, 1]`` to ``[0, 1]`` before conversion.
    """
    print(np.unique(mask),mask.shape)
    mask = np.where(mask > 0, 1, 0)
    print(np.unique(mask),mask.shape)
    mask = mask.astype(np.uint8)

    # ``str()`` tolerates an int being passed; a zero-sized kernel is treated
    # as "no dilation" instead of being handed to OpenCV.
    kernel_text = str(dilation_mask_extend).strip()
    if kernel_text.isdigit() and int(kernel_text) > 0:
        kernel_size = int(kernel_text)
        kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (kernel_size, kernel_size))
        mask = cv2.dilate(mask, kernel, iterations = 1)

    img_tensor = i2sb_transforms(img).to(
        i2sb_opt.device).unsqueeze(0)

    # np.resize() only truncates/repeats the flattened data, which scrambles a
    # mask whose resolution differs from the image. Resample spatially with
    # nearest-neighbour so the mask stays binary and aligned with the image.
    # Assumes the last two dims of img_tensor are (H, W) -- TODO confirm.
    height, width = (int(d) for d in img_tensor.shape[-2:])
    if tuple(mask.shape[:2]) != (height, width):
        mask = cv2.resize(mask, (width, height), interpolation=cv2.INTER_NEAREST)
    mask_tensor = torch.from_numpy(np.ascontiguousarray(mask)).to(
        i2sb_opt.device).unsqueeze(0).unsqueeze(0)

    # Preview of what the model sees: masked pixels are set to 1.
    corrupt_tensor = img_tensor * (1. - mask_tensor) + mask_tensor
    print("DOUBLE CHECK:\t", corrupt_tensor.shape)
    print("DOUBLE CHECK:\t", img_tensor.shape)
    print("DOUBLE CHECK:\t", mask_tensor.shape)

    xs, _ = i2sb_model.ddpm_sampling(
        ckpt_opt, img_tensor, mask=mask_tensor, cond=None, clip_denoise=i2sb_opt.clip_denoise, nfe=nfe, verbose=i2sb_opt.n_gpu_per_node == 1)
    recon_img = xs[:, 0, ...].to(i2sb_opt.device)
    print(recon_img.shape)

    to_pil = transforms.ToPILImage()
    return to_pil(((recon_img+1)/2)[0]), to_pil(((corrupt_tensor+1)/2)[0])
|
466 |
|
467 |
+
def forward_segformer(img):
    """Segment ``img`` with the global ``segformer_model``.

    Args:
        img: Image convertible with ``np.array``; it is treated as RGB and
            converted to BGR before inference.

    Returns:
        The first element of the ``inference_segmentor`` result as a
        ``uint8`` NumPy array.
    """
    bgr_frame = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
    prediction = inference_segmentor(segformer_model, bgr_frame)
    first_map = prediction[0]
    return np.asarray(first_map, dtype=np.uint8)
|
|
|
|
|
|
|
|
|
474 |
|
475 |
# visualization
|
476 |
def draw_selected_mask(mask, draw):
|
|
|
563 |
return run_task_time, time_cost_str
|
564 |
|
565 |
def run_anything_task(input_image, input_points, origin_image, task_type,
|
566 |
+
mask_source_radio, segmentation_radio, dilation_mask_extend):
|
567 |
|
568 |
run_task_time = 0
|
569 |
time_cost_str = ''
|
570 |
run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
|
571 |
print("HERE................", task_type)
|
572 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
573 |
if input_image is None:
|
574 |
+
return [], gr.Gallery.update(label='Please upload a image!😂😂😂😂'), time_cost_str, gr.Textbox.update(visible=(time_cost_str !=''))
|
575 |
|
576 |
file_temp = int(time.time())
|
577 |
logger.info(f'run_anything_task_002/{device}_[{file_temp}]_{task_type}/[{mask_source_radio}]_1_')
|
|
|
601 |
groundingdino_device = 'cpu'
|
602 |
|
603 |
logger.info(f'run_anything_task_[{file_temp}]_{task_type}_2_')
|
604 |
+
if task_type == 'segment' or task_type == 'pipeline':
|
605 |
+
image = np.array(origin_image)
|
606 |
+
if segmentation_radio == "SAM":
|
607 |
+
if sam_predictor:
|
608 |
+
sam_predictor.set_image(image)
|
609 |
+
|
610 |
+
if sam_predictor:
|
611 |
+
logger.info(f"Forward with: {input_points}")
|
612 |
+
masks, _, _, _ = sam_predictor.predict(
|
613 |
+
point_coords = np.array(input_points),
|
614 |
+
point_labels = np.array([1 for _ in range(len(input_points))]),
|
615 |
+
# boxes = transformed_boxes,
|
616 |
+
multimask_output = False,
|
617 |
+
)
|
618 |
+
# masks: [9, 1, 512, 512]
|
619 |
+
assert sam_checkpoint, 'sam_checkpoint is not found!'
|
620 |
+
else:
|
621 |
+
run_mode = "rectangle"
|
622 |
+
|
623 |
+
# draw output image
|
624 |
+
plt.figure(figsize=(10, 10))
|
625 |
+
plt.imshow(origin_image)
|
626 |
+
for mask in masks:
|
627 |
+
show_mask(mask, plt.gca(), random_color=True)
|
628 |
+
# for box, label in zip(boxes_filt, pred_phrases):
|
629 |
+
# show_box(box.cpu().numpy(), plt.gca(), label)
|
630 |
+
plt.axis('off')
|
631 |
+
image_path = os.path.join(output_dir, f"grounding_seg_output_{file_temp}.jpg")
|
632 |
+
plt.savefig(image_path, bbox_inches="tight")
|
633 |
+
plt.clf()
|
634 |
+
plt.close('all')
|
635 |
+
segment_image_result = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
|
636 |
+
os.remove(image_path)
|
637 |
+
|
638 |
else:
|
639 |
+
masks = forward_segformer(image)
|
640 |
+
|
641 |
+
segment_image_result = visualize_result_pyplot(segformer_model, image, masks, get_palette("wtm"), dilation=dilation_mask_extend)# if task_type == "pipeline" else None)
|
642 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
643 |
output_images.append(Image.fromarray(segment_image_result))
|
644 |
+
run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
|
645 |
+
|
646 |
|
647 |
logger.info(f'run_anything_task_[{file_temp}]_{task_type}_3_')
|
648 |
if task_type == 'detection' or task_type == 'segment':
|
649 |
logger.info(f'run_anything_task_[{file_temp}]_{task_type}_9_')
|
650 |
+
return output_images, gr.Gallery.update(label='result images'), time_cost_str, gr.Textbox.update(visible=(time_cost_str !=''))
|
651 |
+
elif task_type in ['inpainting', 'outpainting'] or task_type == 'pipeline':
|
|
|
|
|
652 |
|
653 |
logger.info(f'run_anything_task_[{file_temp}]_{task_type}_4_')
|
654 |
+
if task_type == "pipeline":
|
655 |
+
if segmentation_radio == "SAM":
|
656 |
+
masks_ori = copy.deepcopy(masks)
|
657 |
+
print(masks.shape)
|
658 |
+
# masks = torch.where(masks > 0, True, False)
|
659 |
+
mask = masks[0]
|
660 |
+
mask_pil = Image.fromarray(mask)
|
661 |
+
mask = np.where(mask == True, 1, 0)
|
662 |
+
else:
|
663 |
+
mask = masks
|
664 |
+
save_mask = copy.deepcopy(mask)
|
665 |
+
save_mask = np.where(mask > 0, 255, 0).astype(np.uint8)
|
666 |
+
print((save_mask.dtype))
|
667 |
+
mask_pil = Image.fromarray(save_mask)
|
668 |
+
|
669 |
else:
|
670 |
+
if mask_source_radio == mask_source_draw:
|
671 |
+
input_mask_pil = input_image['mask']
|
672 |
+
input_mask = np.array(input_mask_pil.convert("L"))
|
673 |
+
mask_pil = input_mask_pil
|
674 |
+
mask = input_mask
|
675 |
+
else:
|
676 |
+
pass
|
677 |
+
# masks_ori = copy.deepcopy(masks)
|
678 |
+
# masks = torch.where(masks > 0, True, False)
|
679 |
+
# mask = masks[0][0].cpu().numpy()
|
680 |
+
# mask_pil = Image.fromarray(mask)
|
681 |
output_images.append(mask_pil.convert("RGB"))
|
682 |
run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
|
683 |
|
684 |
+
if task_type in ['inpainting', 'pipeline']:
|
685 |
# image_inpainting = sd_model(prompt = "", image=image_source_for_inpaint, mask_image=image_mask_for_inpaint).images[0]
|
686 |
+
# input_img.save("test.png")
|
687 |
+
w, h = input_img.size
|
688 |
+
input_img = input_img.resize((256,256))
|
689 |
+
image_inpainting, corrupted = forward_i2sb(input_img, mask, dilation_mask_extend)
|
690 |
+
input_img = input_img.resize((w,h))
|
691 |
+
corrupted = corrupted.resize((w,h))
|
692 |
+
image_inpainting = image_inpainting.resize((w,h))
|
693 |
+
# print("RESULT\t", np.array(image_inpainting))
|
694 |
else:
|
695 |
# remove from mask
|
696 |
aasds = 1
|
697 |
|
698 |
logger.info(f'run_anything_task_[{file_temp}]_{task_type}_6_')
|
|
|
699 |
if image_inpainting is None:
|
700 |
logger.info(f'run_anything_task_failed_')
|
701 |
+
return None, None, None, None
|
702 |
|
703 |
# output_images.append(image_inpainting)
|
704 |
# run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
|
705 |
|
706 |
logger.info(f'run_anything_task_[{file_temp}]_{task_type}_7_')
|
707 |
image_inpainting = image_inpainting.resize((image_pil.size[0], image_pil.size[1]))
|
708 |
+
output_images.append(corrupted)
|
709 |
output_images.append(image_inpainting)
|
710 |
run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
|
711 |
logger.info(f'run_anything_task_[{file_temp}]_{task_type}_9_')
|
712 |
+
return output_images, gr.Gallery.update(label='result images'), time_cost_str, gr.Textbox.update(visible=(time_cost_str !=''))
|
713 |
else:
|
714 |
logger.info(f"task_type:{task_type} error!")
|
715 |
logger.info(f'run_anything_task_[{file_temp}]_9_9_')
|
716 |
+
return output_images, gr.Gallery.update(label='result images'), time_cost_str, gr.Textbox.update(visible=(time_cost_str !=''))
|
717 |
|
718 |
def change_radio_display(task_type, mask_source_radio, orig_img):
|
719 |
mask_source_radio_visible = False
|
|
|
735 |
mask_source_radio_visible = True
|
736 |
if task_type == "relate anything":
|
737 |
num_relation_visible = True
|
738 |
+
if task_type == "inpainting":
|
|
|
|
|
739 |
ret = gr.Image(value = orig_img, elem_id="image_upload", type='pil', label="Upload", height=512, tool = "sketch", brush_color='#00FFFF', mask_opacity=0.6)
|
740 |
+
elif task_type in ["segment", "pipeline"]:
|
741 |
+
ret = gr.Image(value= orig_img, elem_id="image_upload", type='pil', label="Upload", height=512, tool = "editor")# tool = "sketch", brush_color='#00FFFF', mask_opacity=0.6)
|
742 |
|
743 |
return (gr.Radio.update(visible=mask_source_radio_visible),
|
744 |
gr.Slider.update(visible=num_relation_visible),
|
745 |
gr.Gallery.update(visible=image_gallery_visible),
|
746 |
+
gr.Radio(["SegFormer", "SAM"], value="SAM", label="Segementation Model", visible= task_type != "inpainting"),
|
747 |
+
gr.Textbox(label="Dilation kernel size", value='7', visible= task_type == "pipeline"),
|
|
|
748 |
ret, [],
|
749 |
+
gr.Button("Undo point", visible = task_type != "inpainting"),
|
750 |
+
gr.Button("Clear point", visible = task_type != "inpainting"),)
|
751 |
|
752 |
def get_model_device(module):
|
753 |
try:
|
|
|
777 |
with gr.Row():
|
778 |
with gr.Column():
|
779 |
selected_points = gr.State([])
|
780 |
+
original_image = gr.State(None)
|
781 |
task_types = ["segment"]
|
782 |
if inpainting_enable:
|
783 |
task_types.append("inpainting")
|
784 |
+
task_types.append("pipeline")
|
785 |
|
786 |
|
787 |
input_image = gr.Image(elem_id="image_upload", type='pil', label="Upload", height=512)
|
|
|
800 |
with gr.Row():
|
801 |
with gr.Column():
|
802 |
|
803 |
+
undo_point_button = gr.Button("Undo point", visible= True if original_image is not None else False)
|
804 |
undo_point_button.click(
|
805 |
fn= undo_button,
|
806 |
inputs=[original_image, selected_points],
|
|
|
809 |
|
810 |
with gr.Column():
|
811 |
|
812 |
+
clear_point_button = gr.Button("Clear point", visible= True if original_image is not None else False)
|
813 |
clear_point_button.click(
|
814 |
fn= clear_button,
|
815 |
inputs=[original_image],
|
|
|
822 |
mask_source_radio = gr.Radio([mask_source_draw, mask_source_segment],
|
823 |
value=mask_source_draw, label="Mask from",
|
824 |
visible=False)
|
825 |
+
|
826 |
+
segmentation_radio = gr.Radio(["SegFormer", "SAM"],
|
827 |
+
value="SAM", label="Segementation Model",
|
828 |
+
visible=True)
|
829 |
+
|
830 |
+
dilation_mask_extend = gr.Textbox(label="Dilation kernel size", value='5', visible=False)
|
831 |
+
|
832 |
num_relation = gr.Slider(label="How many relations do you want to see", minimum=1, maximum=20, value=5, step=1, visible=False)
|
833 |
|
|
|
|
|
834 |
run_button = gr.Button(label="Run", visible=True)
|
835 |
# with gr.Accordion("Advanced options", open=False) as advanced_options:
|
836 |
# box_threshold = gr.Slider(
|
|
|
851 |
|
852 |
with gr.Column():
|
853 |
image_gallery = gr.Gallery(label="result images", show_label=True, elem_id="gallery", height=512, visible=True
|
854 |
+
).style(preview=True, columns=[5], object_fit="scale-down", height=512)
|
855 |
time_cost = gr.Textbox(label="Time cost by step (ms):", visible=False, interactive=False)
|
856 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
857 |
|
|
|
|
|
|
|
|
|
|
|
|
|
858 |
|
859 |
run_button.click(fn=run_anything_task, inputs=[
|
860 |
input_image, selected_points, original_image, task_type,
|
861 |
+
mask_source_radio, segmentation_radio, dilation_mask_extend],
|
862 |
+
outputs=[image_gallery, image_gallery, time_cost, time_cost], show_progress=True, queue=True)
|
863 |
|
864 |
mask_source_radio.change(fn=change_radio_display, inputs=[task_type, mask_source_radio, original_image],
|
865 |
outputs=[mask_source_radio, num_relation])
|
866 |
task_type.change(fn=change_radio_display, inputs=[task_type, mask_source_radio, original_image],
|
867 |
outputs=[mask_source_radio, num_relation,
|
868 |
+
image_gallery, segmentation_radio, dilation_mask_extend, input_image, selected_points, undo_point_button, clear_point_button
|
869 |
])
|
870 |
|
871 |
# DESCRIPTION = f'### This demo from [Grounded-Segment-Anything](https://github.com/IDEA-Research/Grounded-Segment-Anything). <br>'
|
|
|
910 |
|
911 |
if sam_enable:
|
912 |
load_sam_model(device)
|
913 |
+
load_segformer(device)
|
914 |
|
915 |
if inpainting_enable:
|
916 |
load_sd_model(device)
|