# cute/app.py
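# Gradio demo for Foreground Feature Averaging (FFA), a baseline for intrinsic object
# similarity: CarveKit removes the background, DINOv2 extracts patch features, and the
# features are averaged over the foreground mask before computing cosine similarity.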
import torch
import torchvision
from PIL import Image
from pathlib import Path
import os
import numpy as np
from carvekit.api.high import HiInterface
import gradio as gr


class PlatonicDistanceModel(torch.nn.Module):
    def __init__(self, device, carvekit_object_type="object"):
        """
        :param device: string or torch.device object to run the model on.
        :param carvekit_object_type: object type for foreground segmentation. Can be "object" or "hairs-like".
            We find that "object" works well for most images in the CUTE dataset as well as for vehicle ReID.
        """
        super().__init__()
        self.device = device
        # DINOv2 ViT-B/14 backbone used to extract patch features.
        self.encoder = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14')
        self.encoder.to(self.device)
        # CarveKit background-removal pipeline used to isolate the foreground object.
        self.interface = HiInterface(object_type=carvekit_object_type,  # Can be "object" or "hairs-like".
                                     batch_size_seg=5,
                                     batch_size_matting=1,
                                     device=str(self.device),  # HiInterface requires a string device.
                                     seg_mask_size=640,  # Use 640 for Tracer B7 and 320 for U2Net.
                                     matting_mask_size=2048,
                                     trimap_prob_threshold=231,
                                     trimap_dilation=30,
                                     trimap_erosion_iters=5,
                                     fp16=False)
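    # Illustrative alternative (hypothetical usage): for fuzzy or hair-like foregrounds one
    # could instead construct PlatonicDistanceModel(device, carvekit_object_type="hairs-like").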
    def preprocess(self, x_list):
        preprocessed_images = []
        for x in x_list:
            # Resize to 336x336 so the ViT-B/14 encoder produces a 24x24 grid of patch tokens.
            new_width = 336
            new_height = 336

            def _to_rgb(x):
                if x.mode != "RGB":
                    x = x.convert("RGB")
                return x

            preprocessed_image = torchvision.transforms.Compose([
                _to_rgb,
                torchvision.transforms.Resize((new_height, new_width), interpolation=Image.BICUBIC),
                torchvision.transforms.ToTensor(),
                torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ])(x)
            preprocessed_images.append(preprocessed_image)
        return torch.stack(preprocessed_images, dim=0).to(self.device)
    def get_foreground_mask(self, tensor_imgs):
        # Recover a binary foreground mask from the background-removed images: background
        # pixels were set to black, so they equal the per-image minimum of the channel sum.
        masks = []
        for tensor_img in tensor_imgs:
            tensor_img = tensor_img.detach().cpu()
            numpy_img_sum = tensor_img.sum(dim=0).numpy()
            min_value = np.min(numpy_img_sum)
            mask = ~(numpy_img_sum == min_value)
            mask = mask.astype(np.uint8)
            mask = Image.fromarray(mask * 255)
            # Downsample the mask to the 24x24 DINOv2 patch grid.
            resized_mask = mask.resize((24, 24), Image.BILINEAR)
            resized_mask_numpy = np.array(resized_mask)
            resized_mask_numpy = resized_mask_numpy / 255.0
            tensor_mask = torch.from_numpy(resized_mask_numpy.astype(np.float32))
            tensor_mask[tensor_mask > 0.5] = 1.0
            tensor_mask = tensor_mask.unsqueeze(0).long().to(self.device)
            # Fall back to an all-ones mask if segmentation removed everything.
            if tensor_mask.sum() == 0:
                tensor_mask = torch.ones_like(tensor_mask)
            masks.append(tensor_mask)
        return torch.stack(masks, dim=0)
    def forward(self, variant, *x):
        # Accepts either a single image or batch (returns foreground-averaged embeddings)
        # or two images/batches (returns their cosine similarity as a Python float).
        if len(x) == 1 and (isinstance(x[0], list) or isinstance(x[0], torch.Tensor)):
            return self.forward_single(x[0], variant)
        elif len(x) == 1:
            return self.forward_single([x[0]], variant)
        elif len(x) == 2:
            return torch.cosine_similarity(self.forward_single(x[0], variant)[0],
                                           self.forward_single(x[1], variant)[0],
                                           dim=0).cpu().item()
        else:
            raise ValueError("Invalid number of inputs, only 1 or 2 inputs are supported.")
    def forward_single(self, x_list, variant):
        with torch.no_grad():
            original_sizes = [(x.size[1], x.size[0]) for x in x_list]
            # Remove the background with CarveKit, then zero out fully transparent pixels.
            img_list = [np.array(self.interface([x])[0]) for x in x_list]
            for img in img_list:
                img[img[..., 3] == 0] = [0, 0, 0, 0]
            img_list = [Image.fromarray(img) for img in img_list]
            preprocessed_imgs = self.preprocess(img_list)
            masks = self.get_foreground_mask(preprocessed_imgs)
            if variant == "Crop-Feat":
                # Encode the background-removed images.
                emb = self.encoder.forward_features(preprocessed_imgs)
            elif variant == "Crop-Img":
                # Encode the original images; the foreground mask is still used for pooling.
                emb = self.encoder.forward_features(self.preprocess(x_list))
            else:
                raise ValueError("Invalid variant, only Crop-Feat and Crop-Img are supported.")
            # Average the 24x24 patch tokens over the foreground mask.
            grid = emb["x_norm_patchtokens"].view(len(x_list), 24, 24, -1)
            return (grid * masks.permute(0, 2, 3, 1)).sum(dim=(1, 2)) / masks.sum(dim=(1, 2, 3)).unsqueeze(-1)
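# Programmatic usage sketch (illustrative file paths; `model` is instantiated below):
#   score = model("Crop-Feat", [Image.open("img_a.jpg")], [Image.open("img_b.jpg")])
# `score` is the cosine similarity between the two foreground-averaged embeddings.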
def compare(image_1, image_2, variant):
    similarity_score = model(variant, [image_1], [image_2])
    return f"The similarity score is: {similarity_score:.2f}"
device = "cuda" if torch.cuda.is_available() else "cpu"
model = PlatonicDistanceModel(device)

demo = gr.Interface(title="Foreground Feature Averaging (FFA) Intrinsic Object Similarity Demo",
                    description="Compare two images using the foreground feature averaging metric, a strong baseline for intrinsic object similarity. Please see our project website at https://s-tian.github.io/projects/cute/ for more information.",
                    fn=compare,
                    inputs=[gr.Image(type="pil", label="Image 1"),
                            gr.Image(type="pil", label="Image 2"),
                            gr.Radio(choices=["Crop-Feat", "Crop-Img"], value="Crop-Feat", label="Variant (use Crop-Feat if not sure)")],
                    outputs="text")
if __name__ == "__main__":
    demo.launch()
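    # Optional: demo.launch(share=True) would additionally create a temporary public link
    # when running locally (not needed on Hugging Face Spaces).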