"""Singleton wrapper around CLIP for computing multi-scale GradCAM relevancy (saliency) maps."""
from .clip import *
from .clip_gradcam import ClipGradcam
import torch
import numpy as np
from PIL import Image
import torchvision
from functools import reduce


def factors(n):
    """Return the set of all positive divisors of n."""
    return set(
        reduce(
            list.__add__,
            ([i, n // i] for i in range(1, int(n**0.5) + 1) if n % i == 0),
        )
    )
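
# Example: factors(12) == {1, 2, 3, 4, 6, 12}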


# Saliency configurations, keyed by method name. Each entry maps an input image
# dimension to the augmentation and multi-scale cropping settings used when
# computing relevancy maps.
saliency_configs = {
"ours": lambda img_dim: {
"distractor_labels": {},
"horizontal_flipping": True,
"augmentations": 5,
"imagenet_prompt_ensemble": False,
"positive_attn_only": True,
"cropping_augmentations": [
{"tile_size": img_dim, "stride": img_dim // 4},
{"tile_size": int(img_dim * 2 / 3), "stride": int(img_dim * 2 / 3) // 4},
{"tile_size": img_dim // 2, "stride": (img_dim // 2) // 4},
{"tile_size": img_dim // 4, "stride": (img_dim // 4) // 4},
],
},
"ours_fast": lambda img_dim: {
"distractor_labels": {},
"horizontal_flipping": True,
"augmentations": 2,
"imagenet_prompt_ensemble": False,
"positive_attn_only": True,
"cropping_augmentations": [
{"tile_size": img_dim, "stride": img_dim // 4},
{"tile_size": int(img_dim * 2 / 3), "stride": int(img_dim * 2 / 3) // 4},
{"tile_size": img_dim // 2, "stride": (img_dim // 2) // 4},
],
},
"chefer_et_al": lambda img_dim: {
"distractor_labels": {},
"horizontal_flipping": False,
"augmentations": 0,
"imagenet_prompt_ensemble": False,
"positive_attn_only": True,
"cropping_augmentations": [{"tile_size": img_dim, "stride": img_dim // 4}],
},
}
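
# For example, resolving a config for a 224 x 224 input yields concrete tile
# sizes and strides:
#   saliency_configs["ours_fast"](224)["cropping_augmentations"]
#   # -> [{"tile_size": 224, "stride": 56},
#   #     {"tile_size": 149, "stride": 37},
#   #     {"tile_size": 112, "stride": 28}]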


class ClipWrapper:
    """Singleton wrapper: the CLIP model, preprocessing transform, and GradCAM
    helper are stored as class attributes and shared by all callers."""

    clip_model = None
    clip_preprocess = None
    clip_gradcam = None
    # Optional LAVT localizer; never assigned in this module, so it must be set
    # externally before calling get_clip_saliency(use_lavt=True).
    lavt = None
    device = None
    jittering_transforms = None

    def __init__(self, clip_model_type, device, **kwargs):
ClipWrapper.device = device
ClipWrapper.jittering_transforms = torchvision.transforms.ColorJitter(
brightness=0.6, contrast=0.6, saturation=0.6, hue=0.1
)
ClipWrapper.clip_model, ClipWrapper.clip_preprocess = load(
clip_model_type, ClipWrapper.device, **kwargs
)
ClipWrapper.clip_gradcam = ClipGradcam(
clip_model_name=clip_model_type,
classes=[""],
templates=["{}"],
device=ClipWrapper.device,
**kwargs
)

    @classmethod
    def check_initialized(cls, clip_model_type="ViT-B/32", **kwargs):
        """Lazily construct the singleton on first use, preferring CUDA when available."""
if cls.clip_gradcam is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
ClipWrapper(clip_model_type=clip_model_type, device=device, **kwargs)
print("using", device)

    @classmethod
    def get_clip_text_feature(cls, string):
        """Encode `string` with the CLIP text encoder and return the feature as a NumPy array."""
ClipWrapper.check_initialized()
with torch.no_grad():
return (
cls.clip_model.encode_text(
tokenize(string, context_length=77).to(cls.device)
)
.squeeze()
.cpu()
.numpy()
)
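
    # Example (a sketch; with the default "ViT-B/32" model the feature is
    # 512-dimensional):
    #   feat = ClipWrapper.get_clip_text_feature("a photo of a mug.")
    #   feat.shape  # (512,)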

    @classmethod
    def get_visual_feature(cls, rgb, tile_attn_mask, device=None):
        """Encode an RGB uint8 image with the CLIP image encoder; `tile_attn_mask` is forwarded to `encode_image`."""
if device is None:
device = ClipWrapper.device
ClipWrapper.check_initialized()
rgb = ClipWrapper.clip_preprocess(Image.fromarray(rgb)).unsqueeze(0)
with torch.no_grad():
clip_feature = ClipWrapper.clip_model.encode_image(
rgb.to(ClipWrapper.device), tile_attn_mask=tile_attn_mask
).squeeze()
return clip_feature.to(device)

    @classmethod
    def get_clip_saliency(
        cls,
        img,
        text_labels,
        prompts,
        distractor_labels=set(),
        use_lavt=False,
        **kwargs
    ):
        """Compute one relevancy map per entry of `text_labels`.

        If `distractor_labels` are given, the mean of their maps is subtracted.
        Returns (per-label saliency maps, CLIP text features of the labels).
        """
        cls.check_initialized()
if use_lavt:
return cls.lavt.localize(img=img, prompts=text_labels)
cls.clip_gradcam.templates = prompts
cls.clip_gradcam.set_classes(text_labels)
text_label_features = torch.stack(
list(cls.clip_gradcam.class_to_language_feature.values()), dim=0
)
text_label_features = text_label_features.squeeze(dim=-1).cpu()
text_maps = cls.get_clip_saliency_convolve(
img=img, text_labels=text_labels, **kwargs
)
if len(distractor_labels) > 0:
distractor_labels = set(distractor_labels) - set(text_labels)
cls.clip_gradcam.set_classes(list(distractor_labels))
distractor_maps = cls.get_clip_saliency_convolve(
img=img, text_labels=list(distractor_labels), **kwargs
)
text_maps -= distractor_maps.mean(dim=0)
text_maps = text_maps.cpu()
return text_maps, text_label_features.squeeze(dim=-1)

    @classmethod
    def get_clip_saliency_convolve(
        cls,
        text_labels,
        horizontal_flipping=False,
        positive_attn_only: bool = False,
        tile_batch_size=32,
        prompt_batch_size=32,
        tile_interpolate_batch_size=16,
        **kwargs
    ):
        """Run CLIP GradCAM over overlapping multi-scale tiles (and, optionally,
        their horizontal flips), then upsample and average the per-tile maps into
        full-resolution saliency maps, one per text label."""
        cls.clip_gradcam.positive_attn_only = positive_attn_only
tiles, tile_imgs, counts, tile_sizes = cls.create_tiles(**kwargs)
outputs = {
k: torch.zeros(
[len(text_labels)] + list(count.shape), device=cls.device
).half()
for k, count in counts.items()
}
tile_gradcams = torch.cat(
[
torch.cat(
[
cls.clip_gradcam(
x=tile_imgs[tile_idx : tile_idx + tile_batch_size],
o=text_labels[prompt_idx : prompt_idx + prompt_batch_size],
)
for tile_idx in np.arange(0, len(tile_imgs), tile_batch_size)
],
dim=1,
)
for prompt_idx in np.arange(0, len(text_labels), prompt_batch_size)
],
dim=0,
)
if horizontal_flipping:
flipped_tile_imgs = tile_imgs[
..., torch.flip(torch.arange(0, tile_imgs.shape[-1]), dims=[0])
]
flipped_tile_gradcams = torch.cat(
[
torch.cat(
[
cls.clip_gradcam(
x=flipped_tile_imgs[
tile_idx : tile_idx + tile_batch_size
],
o=text_labels[
prompt_idx : prompt_idx + prompt_batch_size
],
)
for tile_idx in np.arange(
0, len(tile_imgs), tile_batch_size
)
],
dim=1,
)
for prompt_idx in np.arange(0, len(text_labels), prompt_batch_size)
],
dim=0,
)
with torch.no_grad():
flipped_tile_gradcams = flipped_tile_gradcams[
...,
torch.flip(
torch.arange(0, flipped_tile_gradcams.shape[-1]), dims=[0]
),
]
tile_gradcams = (tile_gradcams + flipped_tile_gradcams) / 2
del flipped_tile_gradcams
with torch.no_grad():
torch.cuda.empty_cache()
for tile_size in np.unique(tile_sizes):
tile_size_mask = tile_sizes == tile_size
curr_size_grads = tile_gradcams[:, tile_size_mask]
curr_size_tiles = tiles[tile_size_mask]
for tile_idx in np.arange(
0, curr_size_grads.shape[1], tile_interpolate_batch_size
):
resized_tiles = torch.nn.functional.interpolate(
curr_size_grads[
:, tile_idx : tile_idx + tile_interpolate_batch_size
],
size=tile_size,
mode="bilinear",
align_corners=False,
)
                    # NOTE: `batch_offset` indexes within the current interpolation
                    # batch, while the enclosing `tile_idx` marks the batch start.
                    for batch_offset, tile_slice in enumerate(
                        curr_size_tiles[
                            tile_idx : tile_idx + tile_interpolate_batch_size
                        ]
                    ):
                        outputs[tile_size][tile_slice] += resized_tiles[
                            :, batch_offset, ...
                        ]
output = sum(
output.float() / count
for output, count in zip(outputs.values(), counts.values())
) / len(counts)
del outputs, counts, tile_gradcams
output = output.cpu()
return output

    @classmethod
    def create_tiles(cls, img, augmentations, cropping_augmentations, **kwargs):
        """Crop the image (plus color-jittered copies) into overlapping tiles at each configured scale."""
        assert isinstance(img, np.ndarray)
images = []
cls.check_initialized()
        # base image plus color-jittered copies
        img_pil = Image.fromarray(img)
        images.append(np.array(img_pil))
        for _ in range(augmentations):
            images.append(np.array(cls.jittering_transforms(img_pil)))
        # per-pixel tile counts at each scale, used to average overlapping tiles
counts = {
crop_aug["tile_size"]: torch.zeros(img.shape[:2], device=cls.device).float()
+ 1e-5
for crop_aug in cropping_augmentations
}
tiles = []
tile_imgs = []
tile_sizes = []
for img in images:
for crop_aug in cropping_augmentations:
tile_size = crop_aug["tile_size"]
stride = crop_aug["stride"]
for y in np.arange(0, img.shape[1] - tile_size + 1, stride):
if y >= img.shape[0]:
continue
for x in np.arange(0, img.shape[0] - tile_size + 1, stride):
if x >= img.shape[1]:
continue
tile = (
slice(None, None),
slice(x, x + tile_size),
slice(y, y + tile_size),
)
tiles.append(tile)
counts[tile_size][tile[1:]] += 1
tile_sizes.append(tile_size)
                        # preprocessing each tile here is currently the biggest bottleneck
tile_imgs.append(
cls.clip_gradcam.preprocess(
Image.fromarray(img[tiles[-1][1:]])
)
)
tile_imgs = torch.stack(tile_imgs).to(cls.device)
return np.array(tiles), tile_imgs, counts, np.array(tile_sizes)
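

# Example usage (a sketch; assumes `img` is a square H x H x 3 uint8 numpy array
# and the label / prompt strings are placeholders):
#
#   saliency_maps, label_features = ClipWrapper.get_clip_saliency(
#       img=img,
#       text_labels=["mug", "table"],
#       prompts=["a photograph of a {} in a scene."],
#       **saliency_configs["ours_fast"](img.shape[0]),
#   )
#   # `saliency_maps`: one H x H relevancy map per label (CPU tensor);
#   # `label_features`: the corresponding CLIP text features.
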
imagenet_templates = [
"a bad photo of a {}.",
"a photo of many {}.",
"a sculpture of a {}.",
"a photo of the hard to see {}.",
"a low resolution photo of the {}.",
"a rendering of a {}.",
"graffiti of a {}.",
"a bad photo of the {}.",
"a cropped photo of the {}.",
"a tattoo of a {}.",
"the embroidered {}.",
"a photo of a hard to see {}.",
"a bright photo of a {}.",
"a photo of a clean {}.",
"a photo of a dirty {}.",
"a dark photo of the {}.",
"a drawing of a {}.",
"a photo of my {}.",
"the plastic {}.",
"a photo of the cool {}.",
"a close-up photo of a {}.",
"a black and white photo of the {}.",
"a painting of the {}.",
"a painting of a {}.",
"a pixelated photo of the {}.",
"a sculpture of the {}.",
"a bright photo of the {}.",
"a cropped photo of a {}.",
"a plastic {}.",
"a photo of the dirty {}.",
"a jpeg corrupted photo of a {}.",
"a blurry photo of the {}.",
"a photo of the {}.",
"a good photo of the {}.",
"a rendering of the {}.",
"a {} in a video game.",
"a photo of one {}.",
"a doodle of a {}.",
"a close-up photo of the {}.",
"a photo of a {}.",
"the origami {}.",
"the {} in a video game.",
"a sketch of a {}.",
"a doodle of the {}.",
"a origami {}.",
"a low resolution photo of a {}.",
"the toy {}.",
"a rendition of the {}.",
"a photo of the clean {}.",
"a photo of a large {}.",
"a rendition of a {}.",
"a photo of a nice {}.",
"a photo of a weird {}.",
"a blurry photo of a {}.",
"a cartoon {}.",
"art of a {}.",
"a sketch of the {}.",
"a embroidered {}.",
"a pixelated photo of a {}.",
"itap of the {}.",
"a jpeg corrupted photo of the {}.",
"a good photo of a {}.",
"a plushie {}.",
"a photo of the nice {}.",
"a photo of the small {}.",
"a photo of the weird {}.",
"the cartoon {}.",
"art of the {}.",
"a drawing of the {}.",
"a photo of the large {}.",
"a black and white photo of a {}.",
"the plushie {}.",
"a dark photo of a {}.",
"itap of a {}.",
"graffiti of the {}.",
"a toy {}.",
"itap of my {}.",
"a photo of a cool {}.",
"a photo of a small {}.",
"a tattoo of the {}.",
]
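
# `imagenet_templates` can be passed as the `prompts` argument of
# ClipWrapper.get_clip_saliency to ensemble over ImageNet-style prompt templates
# (a sketch: prompts=imagenet_templates).
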
__all__ = ["ClipWrapper", "imagenet_templates"]