import os

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import cv2
import matplotlib

import clip
from utils import *

matplotlib.get_cachedir()
plt.rc('font', family="Times New Roman")

from sklearn import metrics
import torch
from torchvision import transforms
from tqdm import tqdm

from models.submodular_vit_efficient_plus import MultiModalSubModularExplanationEfficientPlus

# Image preprocessing for the CLIP vision encoder: bicubic resize to 224x224,
# conversion to a [C,H,W] float tensor, then per-channel normalisation
# (these mean/std constants match OpenAI CLIP's own preprocessing).
data_transform = transforms.Compose(
    [
        transforms.Resize(
            (224, 224), interpolation=transforms.InterpolationMode.BICUBIC
        ),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=(0.48145466, 0.4578275, 0.40821073),
            std=(0.26862954, 0.26130258, 0.27577711),
        ),
    ]
)


class CLIPModel_Super(torch.nn.Module):
    """Wrap a CLIP model so that calling it returns L2-normalised image embeddings."""

    def __init__(self, type="ViT-L/14", download_root=None, device="cuda"):
        """
        Args:
            type: CLIP architecture name passed to ``clip.load``. The name shadows
                the builtin but is kept for caller compatibility.
            download_root: Directory ``clip.load`` uses for the weights.
            device: Device the model is loaded onto.
        """
        super().__init__()
        self.device = device
        self.model, _ = clip.load(type, device=self.device, download_root=download_root)
        # Force float32 weights (clip.load may hand back half precision on GPU).
        self.model = self.model.type(torch.float32)

    def forward(self, vision_inputs):
        """
        Input:
            vision_inputs: torch.size([B,C,W,H])
        Output:
            embeddings: a d-dimensional vector torch.size([B,d]), unit L2 norm per row
        """
        vision_inputs = vision_inputs.type(torch.float32)

        with torch.no_grad():
            image_features = self.model.encode_image(vision_inputs)
            image_features /= image_features.norm(dim=-1, keepdim=True)

        return image_features


def transform_vision_data(image):
    """
    Input:
        image: an image array accepted by ``PIL.Image.fromarray`` (e.g. [H,W,C] uint8)
    Output:
        image: the preprocessed tensor [C,224,224] produced by ``data_transform``
    """
    image = Image.fromarray(image)
    image = data_transform(image)
    return image


def zeroshot_classifier(model, classnames, templates, device):
    """Build CLIP zero-shot classifier weights from prompt templates.

    For every class name, each template is formatted with the name, the prompts
    are encoded with ``model.model.encode_text``, L2-normalised, averaged and
    re-normalised into one embedding per class.

    Args:
        model: A ``CLIPModel_Super`` (its ``.model`` is the raw CLIP model).
        classnames: Iterable of class-name strings.
        templates: Format strings with one ``{}`` placeholder for the class name.
        device: Device the tokens and the returned tensor live on.

    Returns:
        Tensor [num_classes, d] of unit class embeddings scaled by 100 (the same
        scale used for custom text embeddings elsewhere in this script).
    """
    with torch.no_grad():
        zeroshot_weights = []
        for classname in tqdm(classnames):
            texts = [template.format(classname) for template in templates]
            texts = clip.tokenize(texts).to(device)
            class_embeddings = model.model.encode_text(texts)
            class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
            class_embedding = class_embeddings.mean(dim=0)
            class_embedding /= class_embedding.norm()
            zeroshot_weights.append(class_embedding)
        # Use the requested device: a hard-coded .cuda() crashes on CPU-only hosts.
        zeroshot_weights = torch.stack(zeroshot_weights).to(device)
    return zeroshot_weights * 100


device = "cuda" if torch.cuda.is_available() else "cpu"

# Instantiate model
vis_model = CLIPModel_Super("ViT-B/16", device=device, download_root="./ckpt")
vis_model.eval()
vis_model.to(device)
print("load clip model")

# The ImageNet zero-shot text embeddings are cached on disk and only computed
# (from imagenet_classes / imagenet_templates, presumably provided by the
# `utils` star import) when the cache file is missing.
semantic_path = "./clip_vitb_imagenet_zeroweights.pt"
if os.path.exists(semantic_path):
    semantic_feature = torch.load(semantic_path, map_location="cpu")
    semantic_feature = semantic_feature.to(device)
    semantic_feature = semantic_feature.type(torch.float32)
else:
    semantic_feature = zeroshot_classifier(vis_model, imagenet_classes, imagenet_templates, device)
    torch.save(semantic_feature, semantic_path)

explainer = MultiModalSubModularExplanationEfficientPlus(
    vis_model, semantic_feature, transform_vision_data, device=device,
    lambda1=0.01,
    lambda2=0.05,
    lambda3=20.,
    lambda4=5.)
# Keep the unmodified ImageNet text embeddings; interpret_image temporarily
# appends a custom text embedding to explainer.semantic_feature and restores it.
explainer.org_semantic_feature = semantic_feature


def add_value_decrease(smdl_mask, json_file):
    """Turn per-step submodular scores into a pixel-level attribution map.

    The gain of step i is (consistency_i + collaboration_i) minus the same sum
    for step i-1, where "step -1" is (baseline_score, 1 - org_score). A running
    value is decreased by |gain| at every step and written into the non-zero
    pixels of that step's region. Regions selected earlier therefore end up
    with higher attribution.

    Args:
        smdl_mask: Sequence of per-region images [H,W,C] in selection order.
        json_file: Dict with "consistency_score" and "collaboration_score"
            (lists, one entry per step), "baseline_score" and "org_score".

    Returns:
        (attribution_map, values): the map min-max normalised to [0, 1] (all
        zeros if it is constant) and the running value after each step.
    """
    single_mask = np.zeros_like(smdl_mask[0].mean(-1))

    value_list_1 = np.array(json_file["consistency_score"]) + np.array(json_file["collaboration_score"])
    value_list_2 = (
        np.array([json_file["baseline_score"]] + json_file["consistency_score"][:-1])
        + np.array([1 - json_file["org_score"]] + json_file["collaboration_score"][:-1])
    )
    value_list = value_list_1 - value_list_2

    values = []
    value = 0
    for smdl_single_mask, smdl_value in zip(smdl_mask, value_list):
        value = value - abs(smdl_value)
        single_mask[smdl_single_mask.sum(-1) > 0] = value
        values.append(value)

    attribution_map = single_mask - single_mask.min()
    peak = attribution_map.max()
    if peak > 0:  # avoid 0/0 -> NaN when every pixel has the same value
        attribution_map /= peak
    return attribution_map, np.array(values)


def visualization(image, submodular_image_set, saved_json_file, index=None, compute_params=True):
    """Render the attribution map, searched region and insertion curve side by side.

    Args:
        image: The resized input image [H,W,3]. The callers in this file pass
            RGB arrays (PIL / ``gr.Image(type="numpy")``).
        submodular_image_set: Per-region images in selection order, as returned
            by the explainer.
        saved_json_file: Score dict returned by the explainer (see
            ``add_value_decrease``).
        index: Number of selected regions to highlight. ``None`` selects the step
            with the highest recognition score. Other values are converted to
            int and clamped to the valid range.
        compute_params: When True, also return the best score, the insertion
            AUC and the highlighted index.

    Returns:
        ``img_curve`` (uint8 [H,W,3] rendering of the figure) when
        ``compute_params`` is False, otherwise
        ``(img_curve, highest_score, insertion_auc, best_index)``.
    """
    attribution_map, _ = add_value_decrease(submodular_image_set, saved_json_file)
    # gen_cam blends with an OpenCV (BGR) colormap, so pass it a BGR view of the
    # RGB image; the [..., ::-1] at display time turns the blend back into RGB.
    vis_image, _ = gen_cam(image[..., ::-1], norm_image(attribution_map))

    # Cumulative insertion images: step 0 is blank, step i reveals the first i regions.
    insertion_image = submodular_image_set[0] - submodular_image_set[0]
    insertion_ours_images = [insertion_image]
    for smdl_sub_mask in submodular_image_set:
        insertion_image = insertion_image.copy() + smdl_sub_mask
        insertion_ours_images.append(insertion_image)

    insertion_ours_images_input_results = np.array(
        [1 - saved_json_file["collaboration_score"][-1]] + saved_json_file["consistency_score"]
    )

    # Fraction of pixels revealed at each insertion step.
    total_pixels = image.shape[0] * image.shape[1]
    x = [(step_image.sum(-1) != 0).sum() / total_pixels for step_image in insertion_ours_images]

    if index is None:
        ours_best_index = int(np.argmax(insertion_ours_images_input_results))
    else:
        # The UI slider has a fixed maximum and may deliver floats: clamp to a
        # valid integer step so every indexing below is safe.
        last_step = min(len(x), len(insertion_ours_images_input_results)) - 1
        ours_best_index = min(max(int(index), 0), last_step)

    fig, [ax1, ax2, ax3] = plt.subplots(1, 3, gridspec_kw={'width_ratios': [1, 1, 1.5]}, figsize=(30, 8))

    # Panel 1: attribution heatmap blended onto the image.
    for side in ("left", "right", "top", "bottom"):
        ax1.spines[side].set_visible(False)
    ax1.xaxis.set_visible(False)
    ax1.yaxis.set_visible(False)
    ax1.set_title('Attribution Map', fontsize=54)
    ax1.set_facecolor('white')
    ax1.imshow(vis_image[..., ::-1].astype(np.uint8))

    # Panel 2: the searched (revealed) region at the highlighted step.
    for side in ("left", "right", "top", "bottom"):
        ax2.spines[side].set_visible(False)
    ax2.xaxis.set_visible(True)
    ax2.yaxis.set_visible(False)
    ax2.set_title('Searched Region', fontsize=54)
    ax2.set_facecolor('white')
    ax2.set_xlabel("Confidence {:.4f}".format(insertion_ours_images_input_results[ours_best_index]), fontsize=44)
    ax2.tick_params(axis='x', which='both', bottom=False, top=False, labelbottom=False)

    # Panel 3: insertion curve (score vs. fraction of image revealed).
    ax3.set_xlim((0, 1))
    ax3.set_ylim((0, 1))
    ax3.set_ylabel('Recognition Score', fontsize=44)
    ax3.set_xlabel('Percentage of image revealed', fontsize=44)
    ax3.tick_params(axis='both', which='major', labelsize=36)

    x_ = x
    ours_y = insertion_ours_images_input_results[:len(x)]
    ax3.plot(x_, ours_y, color='dodgerblue', linewidth=3.5)
    ax3.set_facecolor('white')
    ax3.spines['bottom'].set_color('black')
    ax3.spines['bottom'].set_linewidth(2.0)
    ax3.spines['top'].set_color('none')
    ax3.spines['left'].set_color('black')
    ax3.spines['left'].set_linewidth(2.0)
    ax3.spines['right'].set_color('none')

    ax3.scatter(x_[-1], ours_y[-1], color='dodgerblue', s=54)  # mark the final point
    ax3.fill_between(x_, ours_y, color='dodgerblue', alpha=0.1)  # light-blue area under the curve
    ax3.axvline(x=x_[ours_best_index], color='red', linewidth=3.5)  # red line at the highlighted step

    # Darken the pixels not yet revealed and outline the revealed region in red.
    mask = (image - insertion_ours_images[ours_best_index]).mean(-1)
    mask[mask > 0] = 1

    image_debug = image.copy()
    image_debug[mask > 0] = image_debug[mask > 0] * 0.5
    if ours_best_index != 0:
        kernel = np.ones((3, 3), dtype=np.uint8)
        # cv2.dilate's third positional parameter is `dst`, not `iterations`;
        # a single pass (1-px outline) is requested explicitly by keyword.
        dilate = cv2.dilate(mask, kernel, iterations=1)
        edge = dilate - mask
        image_debug[edge > 0] = np.array([255, 0, 0])
    ax2.imshow(image_debug)

    if compute_params:
        auc = metrics.auc(x, insertion_ours_images_input_results)
    ax3.set_title('Insertion Curve', fontsize=54)

    fig.tight_layout()
    fig.canvas.draw()
    # tostring_rgb() is deprecated since Matplotlib 3.8; buffer_rgba() is the
    # supported accessor and is already shaped (height, width, 4).
    img_curve = np.asarray(fig.canvas.buffer_rgba())[..., :3].copy()
    plt.close(fig)  # free the figure's resources

    if compute_params:
        return img_curve, insertion_ours_images_input_results.max(), auc, ours_best_index
    return img_curve


def gen_cam(image, mask):
    """
    Generate heatmap
        :param image: [H,W,C] image; should be in BGR order so it matches the
            BGR heatmap produced by cv2.applyColorMap
        :param mask: [H,W] attribution values in the range 0-255
        :return: tuple(cam, heatmap) -- float32 blend and the uint8 BGR heatmap
    """
    heatmap = cv2.applyColorMap(np.uint8(mask), cv2.COLORMAP_COOL)
    heatmap = np.float32(heatmap)

    # Equal-weight blend of heatmap and image.
    cam = 0.5 * heatmap + 0.5 * np.float32(image)
    return cam, heatmap.astype(np.uint8)


def norm_image(image):
    """
    Min-max normalise an array into uint8 values in [0, 255].
        :param image: numeric array (e.g. an [H,W] attribution map)
        :return: uint8 array of the same shape; a constant input maps to zeros
    """
    image = np.array(image, dtype=np.float64)
    image -= image.min()
    peak = image.max()
    if peak > 0:  # avoid dividing by zero on a constant input
        image /= peak
    image *= 255.
    return np.uint8(image)


def read_image(file_path):
    """Load an image file as a 512x512 RGB uint8 array."""
    with Image.open(file_path) as source:
        image = source.convert("RGB").resize((512, 512))
    return np.array(image)


# Example images offered as clickable thumbnails in the UI.
default_images = {
    "Example: tiger shark": read_image("images/shark.png"),
    "Example: quail": read_image("images/bird.png"),
    "Example: tabby cat or lion": read_image("images/cat_lion.jpeg"),
    "Example: rabbit or duck": read_image("images/rabbit-duck.jpg"),
}


def interpret_image(uploaded_image, slider, text_input):
    """Run the submodular attribution search on the uploaded image.

    An empty text input explains the predicted ImageNet class. Text matching an
    ImageNet class name explains that class. Any other text is encoded with
    CLIP and explained as a free-form description.

    The search results are stored in the module globals ``submodular_image_set``
    and ``saved_json_file`` so that ``visualization_slider`` can re-render them.

    Returns:
        (curve image, highest confidence, insertion AUC, explanation text,
        prediction text, new value for the text input) -- one value per output
        wired to the "Interpreting Model" button.
    """
    global submodular_image_set
    global saved_json_file

    if uploaded_image is None:
        return None, 0, 0, "", "", text_input

    image = cv2.resize(np.array(uploaded_image), (224, 224))

    element_sets_V = SubRegionDivision(image, mode="slico", region_size=40)
    explainer.k = len(element_sets_V)

    # Predict against the original ImageNet embeddings only.
    image_input = explainer.preproccessing_function(image).unsqueeze(0)
    predicted_class = (
        explainer.model(image_input.to(explainer.device)) @ explainer.org_semantic_feature.T
    ).argmax().cpu().item()

    query = (text_input or "").strip()
    try:
        if not query:
            target_id = predicted_class
        elif query in imagenet_classes:
            target_id = imagenet_classes.index(query)
        else:
            # Free-form description: append its scaled text embedding as an
            # extra last "class" and explain that class (id -1).
            target_id = -1
            with torch.no_grad():
                tokens = clip.tokenize([query]).to(device)
                class_embeddings = vis_model.model.encode_text(tokens)
                class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
            class_embeddings = class_embeddings.to(device) * 100
            explainer.semantic_feature = torch.cat((explainer.org_semantic_feature, class_embeddings), dim=0)

        _, submodular_image_set, saved_json_file = explainer(element_sets_V, id=target_id)
    finally:
        # Always restore the ImageNet-only features, even if the search fails.
        explainer.semantic_feature = explainer.org_semantic_feature

    image_curve, highest_confidence, insertion_auc_score, _ = visualization(
        image, submodular_image_set, saved_json_file, index=None)

    if target_id == -1:
        text_output_class = "This method explains that CLIP is interested in describing \"{}\".".format(query)
    else:
        text_output_class = "The method explains why the CLIP (ViT-B/16) model identifies an image as {}.".format(imagenet_classes[explainer.target_label])

    text_output_predict = "The image is predicted as {}".format(imagenet_classes[predicted_class])

    return image_curve, highest_confidence, insertion_auc_score, text_output_class, text_output_predict, None


def predict_image(uploaded_image):
    """Return the zero-shot ImageNet prediction text for the uploaded image (None if no image)."""
    if uploaded_image is None:
        return None

    image = cv2.resize(np.array(uploaded_image), (224, 224))

    image_input = explainer.preproccessing_function(image).unsqueeze(0)
    predicted_class = (
        explainer.model(image_input.to(explainer.device)) @ explainer.org_semantic_feature.T
    ).argmax().cpu().item()

    return "The image is predicted as {}".format(imagenet_classes[predicted_class])


def visualization_slider(uploaded_image, slider):
    """Re-render the last search result highlighting ``slider`` selected regions.

    Returns None when there is no image or no search has been run yet. Note the
    stored result is from the most recent interpret_image call, which may be a
    different image than the one currently uploaded.
    """
    if uploaded_image is None or "submodular_image_set" not in globals():
        return None

    image = cv2.resize(np.array(uploaded_image), (224, 224))
    return visualization(image, submodular_image_set, saved_json_file, index=slider, compute_params=False)


def update_image(thumbnail_name):
    """Return the image data for the example thumbnail named ``thumbnail_name``."""
    return default_images[thumbnail_name]


# Build the Gradio interface.
with gr.Blocks() as demo:
    gr.Markdown("# Semantic Region Attribution and Mistake Discovery via Submodular Subset Selection")
    gr.Markdown("Since huggingface only has ordinary CPUs available, our sub-region division is relatively coarse-grained, which may affect the attribution performance. The inference time is about 5 minutes (GPU is about 4s). If you are interested, you can try our source code. We have written many scripts to facilitate visualization.")

    with gr.Row():
        with gr.Column():
            # Top row: image upload next to the instructions and text query.
            with gr.Row():
                image_input = gr.Image(label="Upload Image", type="numpy")

                with gr.Column():
                    gr.Textbox("Thank you for using our interpretable attribution method, which originates from the ICLR 2024 Oral paper titled \"Less is More: Fewer Interpretable Regions via Submodular Subset Selection.\" We have now implemented this method on the multimodal ViT model and achieved promising results in explaining model predictions. A key feature of our approach is its ability to clarify the reasons behind the model's prediction errors. We invite you to try out this demo and explore its capabilities. The source code is available at https://github.com/RuoyuChen10/SMDL-Attribution.\nYou can upload an image yourself or select one from the following, then click the button Interpreting Model to get the result. The demo currently does not support selecting categories or descriptions by yourself. If you are interested, you can try it from the source code.", label="Instructions for use", interactive=False)

                    text_input = gr.Textbox(label="Text Input", placeholder="You can choose what you want to explain. You can enter a word (e.g., 'tabby cat') or a description (e.g., 'A photo of a tabby cat'). If the input is empty, the model will explain the predicted category.")

            # Second row: example thumbnails; clicking a button loads that image.
            with gr.Row():
                for key in default_images:
                    with gr.Column():
                        gr.Image(value=default_images[key], type="numpy")
                        button = gr.Button(value=key)
                        # k=key binds the current key (avoids late-binding closures).
                        button.click(
                            fn=lambda k=key: update_image(k),
                            inputs=[],
                            outputs=image_input
                        )

        with gr.Column():
            # Output image and controls.
            image_output = gr.Image(label="Output Image")
            slider = gr.Slider(minimum=0, maximum=34, step=1, label="Number of Sub-regions")
            text_output_predict = gr.Textbox(label="Predicted Category")
            text_output_class = gr.Textbox(label="Explaining Category")
            with gr.Row():
                # Highest confidence and insertion AUC side by side.
                text_output_confidence = gr.Textbox(label="Highest Confidence")
                text_output_auc = gr.Textbox(label="Insertion AUC Score")
            with gr.Row():
                predict_button = gr.Button("Model Inference")
                interpret_button = gr.Button("Interpreting Model")

    interpret_button.click(
        fn=interpret_image,
        inputs=[image_input, slider, text_input],
        outputs=[image_output, text_output_confidence, text_output_auc, text_output_class, text_output_predict, text_input]
    )

    predict_button.click(
        fn=predict_image,
        inputs=[image_input],
        outputs=[text_output_predict]
    )

    # Re-render the stored result whenever the slider moves.
    slider.change(
        fn=visualization_slider,
        inputs=[image_input, slider],
        outputs=[image_output]
    )

# Launch the Gradio app.
demo.launch()