"""Gradio demo: anime instance segmentation with per-instance similarity ranking.

For every detected instance the bounding-box crop is compared against each
image of the "svjack/Genshin-Impact-Illustration" dataset with
``ccip_difference`` and the five closest names are drawn on a palette strip.
"""
import os
import pathlib

import gradio as gr
import cv2
from PIL import Image
import numpy as np
from animeinsseg import AnimeInsSeg, AnimeInstances
from animeinsseg.anime_instances import get_color
from ccip import _VALID_MODEL_NAMES, _DEFAULT_MODEL_NAMES, ccip_difference, ccip_default_threshold
from datasets import load_dataset

# Install the required libraries.
# NOTE(review): these installs run after `animeinsseg` has already been
# imported above; if that package needs mmdet/mmcv at import time the installs
# come too late — confirm the intended order.
os.system("mim install mmengine")
os.system('mim install mmcv==2.1.0')
os.system("mim install mmdet==3.2.0")

# Load the model. The clone is guarded on the repository directory itself so
# that a pre-existing `models/` folder does not skip the download and an
# existing clone is not attempted again.
_MODEL_REPO_DIR = "models/AnimeInstanceSegmentation"
os.makedirs("models", exist_ok=True)
if not os.path.exists(_MODEL_REPO_DIR):
    os.system("huggingface-cli lfs-enable-largefiles .")
    os.system("git clone https://huggingface.co./dreMaz/AnimeInstanceSegmentation models/AnimeInstanceSegmentation")

ckpt = r'models/AnimeInstanceSegmentation/rtmdetl_e60.ckpt'
mask_thres = 0.3
instance_thres = 0.3
refine_kwargs = {'refine_method': 'refinenet_isnet'}  # set to None if not using refinenet
# refine_kwargs = None
net = AnimeInsSeg(ckpt, mask_thr=mask_thres, refine_kwargs=refine_kwargs)

# Load the reference dataset and index its images by name.
Genshin_Impact_Illustration_ds = load_dataset("svjack/Genshin-Impact-Illustration")["train"]
ds_size = len(Genshin_Impact_Illustration_ds)
name_image_dict = {row["name"]: row["image"] for row in Genshin_Impact_Illustration_ds}

# Use every PNG found below the working directory as an example input.
example_images = [str(path) for path in pathlib.Path(".").rglob("*.png")]

_PALETTE_HEIGHT = 100  # height in pixels of the result strip
_TOP_K = 5             # number of closest references kept per instance
_MASK_ALPHA = 0.5      # opacity of the mask overlay


def _draw_palette_text(palette, text, origin):
    """Draw one line of small white text on `palette` at `origin` (in place)."""
    cv2.putText(palette, text, origin, cv2.FONT_HERSHEY_SIMPLEX, 0.4, (255, 255, 255), 1, cv2.LINE_AA)


def _message_palette(width, text):
    """Return a black palette strip of the given width carrying `text`."""
    palette = np.zeros((_PALETTE_HEIGHT, max(width, 1), 3), dtype=np.uint8)
    _draw_palette_text(palette, text, (10, 20))
    return Image.fromarray(palette)


def _rank_references(crop, threshold):
    """Return up to `_TOP_K` `(diff, verdict, name)` tuples, smallest diff first.

    `crop` is the bounding-box region of the input image; every reference
    image is resized to the crop's size before being compared.
    """
    crop_h, crop_w = crop.shape[:2]
    scored = []
    for name, reference in name_image_dict.items():
        # cv2.resize needs an array; np.asarray is a no-op for arrays and
        # converts the entry if the dataset yields PIL images.
        resized = cv2.resize(np.asarray(reference), (crop_w, crop_h))
        # NOTE(review): the threshold is model-specific, but `ccip_difference`
        # is called without `model_name`, so it uses its own default — confirm.
        # NOTE(review): `crop` is BGR (taken after RGB2BGR) — confirm that is
        # the channel order `ccip_difference` expects.
        diff = ccip_difference(crop, resized)
        scored.append((diff, 'Same' if diff <= threshold else 'Not Same', name))
    scored.sort(key=lambda item: item[0])
    return scored[:_TOP_K]


def fn(image: np.ndarray, model_name: str):
    """Segment `image` and rank dataset references for each detected instance.

    Args:
        image: RGB image array supplied by the Gradio image input.
        model_name: name passed to `ccip_default_threshold` to pick the
            "Same" / "Not Same" cut-off.

    Returns:
        A pair of PIL images: the input with boxes and masks drawn on it, and
        a palette strip listing the closest references per instance (or the
        text "No instances detected").

    Raises:
        gr.Error: if no image was provided.
    """
    if image is None:
        raise gr.Error("Please provide an input image")

    img = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    instances: AnimeInstances = net.infer(
        img,
        output_type='numpy',
        pred_score_thr=instance_thres
    )

    drawed = img.copy()
    im_h, im_w = img.shape[:2]

    # instances.bboxes, instances.masks will be None, None if no obj is detected.
    # The second output is an image component, so the message is rendered onto
    # a palette instead of being returned as a string.
    if instances.bboxes is None or len(instances.bboxes) == 0:
        return Image.fromarray(drawed[..., ::-1]), _message_palette(im_w, "No instances detected")

    # Both values are the same for every instance, so compute them once.
    threshold = ccip_default_threshold(model_name)
    linewidth = max(round(sum(img.shape) / 2 * 0.003), 2)

    # One entry (a list of up to `_TOP_K` results) per bounding box.
    top5_results = []

    for ii, (xywh, mask) in enumerate(zip(instances.bboxes, instances.masks)):
        color = get_color(ii)

        x1, y1, w, h = map(int, xywh)
        x2, y2 = x1 + w, y1 + h

        # Clamp the box to the image so that the crop is never empty or
        # wrapped around by a negative index; degenerate boxes get no results
        # but keep their palette column so colours stay aligned.
        cx1, cy1 = max(x1, 0), max(y1, 0)
        cx2, cy2 = min(x2, im_w), min(y2, im_h)
        if cx2 > cx1 and cy2 > cy1:
            top5_results.append(_rank_references(img[cy1:cy2, cx1:cx2], threshold))
        else:
            top5_results.append([])

        # Draw the bounding box.
        cv2.rectangle(drawed, (x1, y1), (x2, y2), color, thickness=linewidth, lineType=cv2.LINE_AA)

        # Blend the instance mask over the image.
        p = mask.astype(np.float32)
        blend_mask = np.full((im_h, im_w, 3), color, dtype=np.float32)
        alpha_msk = (_MASK_ALPHA * p)[..., None]
        alpha_ori = 1 - alpha_msk
        # Convert back to uint8 each iteration so the next rectangle is drawn
        # on an 8-bit image.
        drawed = (drawed * alpha_ori + alpha_msk * blend_mask).astype(np.uint8)

    # Build the palette: one equally wide column per instance.
    palette = np.zeros((_PALETTE_HEIGHT, im_w, 3), dtype=np.uint8)
    column_width = im_w // len(top5_results)

    for idx, results in enumerate(top5_results):
        x_start = idx * column_width
        x_end = x_start + column_width

        # Fill the column with the instance colour.
        palette[:, x_start:x_end] = get_color(idx)

        # Write the ranked results into the column, one per text row.
        for i, (diff, pred, name) in enumerate(results):
            text = f"{name}: {diff:.2f} ({pred})"
            y_pos = 20 + i * 15
            _draw_palette_text(palette, text, (x_start + 10, y_pos))

    return Image.fromarray(drawed[..., ::-1]), Image.fromarray(palette)


# Build the Gradio interface.
iface = gr.Interface(
    # design titles and text descriptions
    title="Anime Subject Instance Segmentation with Similarity Comparison",
    description="Segment image subjects with the proposed model in the paper [*Instance-guided Cartoon Editing with a Large-scale Dataset*](https://cartoonsegmentation.github.io/).",
    fn=fn,
    inputs=[gr.Image(type="numpy"), gr.Dropdown(_VALID_MODEL_NAMES, value=_DEFAULT_MODEL_NAMES, label='Model')],
    outputs=[gr.Image(type="pil", label="Segmentation Result"), gr.Image(type="pil", label="Top5 Results Palette")],
    examples=example_images
)

iface.launch(share=True)