from .clip import *
from .clip_gradcam import ClipGradcam
import torch
import numpy as np
from PIL import Image
import torchvision
from functools import reduce


def factors(n):
    """Return the set of all positive divisors of n."""
    return set(
        reduce(
            list.__add__,
            ([i, n // i] for i in range(1, int(n**0.5) + 1) if n % i == 0),
        )
    )
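
# Example: factors(12) == {1, 2, 3, 4, 6, 12}; reduce(list.__add__, ...)
# flattens the divisor pairs [i, n // i] gathered up to sqrt(n).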


# Saliency configurations, keyed by name. Each entry maps the input image
# dimension to the augmentation and multi-scale tiling settings used when
# computing relevancy maps:
#   "ours"          - 5 color-jitter augmentations, horizontal flipping, 4 tile scales
#   "ours_fast"     - 2 augmentations, horizontal flipping, 3 tile scales
#   "chefer_et_al"  - single-scale baseline, no flipping or augmentations
saliency_configs = {
    "ours": lambda img_dim: {
        "distractor_labels": {},
        "horizontal_flipping": True,
        "augmentations": 5,
        "imagenet_prompt_ensemble": False,
        "positive_attn_only": True,
        "cropping_augmentations": [
            {"tile_size": img_dim, "stride": img_dim // 4},
            {"tile_size": int(img_dim * 2 / 3), "stride": int(img_dim * 2 / 3) // 4},
            {"tile_size": img_dim // 2, "stride": (img_dim // 2) // 4},
            {"tile_size": img_dim // 4, "stride": (img_dim // 4) // 4},
        ],
    },
    "ours_fast": lambda img_dim: {
        "distractor_labels": {},
        "horizontal_flipping": True,
        "augmentations": 2,
        "imagenet_prompt_ensemble": False,
        "positive_attn_only": True,
        "cropping_augmentations": [
            {"tile_size": img_dim, "stride": img_dim // 4},
            {"tile_size": int(img_dim * 2 / 3), "stride": int(img_dim * 2 / 3) // 4},
            {"tile_size": img_dim // 2, "stride": (img_dim // 2) // 4},
        ],
    },
    "chefer_et_al": lambda img_dim: {
        "distractor_labels": {},
        "horizontal_flipping": False,
        "augmentations": 0,
        "imagenet_prompt_ensemble": False,
        "positive_attn_only": True,
        "cropping_augmentations": [{"tile_size": img_dim, "stride": img_dim // 4}],
    },
}
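
# A quick sanity check of the tiling schedule (plain arithmetic on the config
# above): for a 224 px image, saliency_configs["ours"](224) yields
#   {"tile_size": 224, "stride": 56}, {"tile_size": 149, "stride": 37},
#   {"tile_size": 112, "stride": 28}, {"tile_size": 56, "stride": 14}.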


class ClipWrapper:
    """Singleton-style wrapper: all state lives on the class itself, so the
    CLIP model and preprocessor are loaded at most once per process."""

    clip_model = None
    clip_preprocess = None
    clip_gradcam = None
    lavt = None
    device = None
    jittering_transforms = None

    def __init__(self, clip_model_type, device, **kwargs):
        ClipWrapper.device = device
        ClipWrapper.jittering_transforms = torchvision.transforms.ColorJitter(
            brightness=0.6, contrast=0.6, saturation=0.6, hue=0.1
        )
        # `load` is provided by the star import from .clip
        ClipWrapper.clip_model, ClipWrapper.clip_preprocess = load(
            clip_model_type, ClipWrapper.device, **kwargs
        )
        ClipWrapper.clip_gradcam = ClipGradcam(
            clip_model_name=clip_model_type,
            classes=[""],
            templates=["{}"],
            device=ClipWrapper.device,
            **kwargs
        )

    @classmethod
    def check_initialized(cls, clip_model_type="ViT-B/32", **kwargs):
        # Lazily construct the singleton on first use.
        if cls.clip_gradcam is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
            ClipWrapper(clip_model_type=clip_model_type, device=device, **kwargs)
            print("using", device)

    @classmethod
    def get_clip_text_feature(cls, string):
        ClipWrapper.check_initialized()
        with torch.no_grad():
            # `tokenize` is provided by the star import from .clip
            return (
                cls.clip_model.encode_text(
                    tokenize(string, context_length=77).to(cls.device)
                )
                .squeeze()
                .cpu()
                .numpy()
            )

    @classmethod
    def get_visual_feature(cls, rgb, tile_attn_mask, device=None):
        if device is None:
            device = ClipWrapper.device
        ClipWrapper.check_initialized()
        rgb = ClipWrapper.clip_preprocess(Image.fromarray(rgb)).unsqueeze(0)
        with torch.no_grad():
            clip_feature = ClipWrapper.clip_model.encode_image(
                rgb.to(ClipWrapper.device), tile_attn_mask=tile_attn_mask
            ).squeeze()
        return clip_feature.to(device)

    @classmethod
    def get_clip_saliency(
        cls,
        img,
        text_labels,
        prompts,
        distractor_labels=set(),
        use_lavt=False,
        **kwargs
    ):
        cls.check_initialized()
        if use_lavt:
            return cls.lavt.localize(img=img, prompts=text_labels)
        cls.clip_gradcam.templates = prompts
        cls.clip_gradcam.set_classes(text_labels)
        text_label_features = torch.stack(
            list(cls.clip_gradcam.class_to_language_feature.values()), dim=0
        )
        text_label_features = text_label_features.squeeze(dim=-1).cpu()
        text_maps = cls.get_clip_saliency_convolve(
            img=img, text_labels=text_labels, **kwargs
        )
        if len(distractor_labels) > 0:
            # Subtract the mean saliency of distractor labels that are not
            # already among the target labels.
            distractor_labels = set(distractor_labels) - set(text_labels)
            cls.clip_gradcam.set_classes(list(distractor_labels))
            distractor_maps = cls.get_clip_saliency_convolve(
                img=img, text_labels=list(distractor_labels), **kwargs
            )
            text_maps -= distractor_maps.mean(dim=0)
        text_maps = text_maps.cpu()
        return text_maps, text_label_features.squeeze(dim=-1)

    @classmethod
    def get_clip_saliency_convolve(
        cls,
        text_labels,
        horizontal_flipping=False,
        positive_attn_only: bool = False,
        tile_batch_size=32,
        prompt_batch_size=32,
        tile_interpolate_batch_size=16,
        **kwargs
    ):
        cls.clip_gradcam.positive_attn_only = positive_attn_only
        tiles, tile_imgs, counts, tile_sizes = cls.create_tiles(**kwargs)
        # One accumulator per tile size, shaped [num_labels, H, W].
        outputs = {
            k: torch.zeros(
                [len(text_labels)] + list(count.shape), device=cls.device
            ).half()
            for k, count in counts.items()
        }
        # Compute relevancy maps batched over tiles (dim 1) and prompts (dim 0).
        tile_gradcams = torch.cat(
            [
                torch.cat(
                    [
                        cls.clip_gradcam(
                            x=tile_imgs[tile_idx : tile_idx + tile_batch_size],
                            o=text_labels[prompt_idx : prompt_idx + prompt_batch_size],
                        )
                        for tile_idx in np.arange(0, len(tile_imgs), tile_batch_size)
                    ],
                    dim=1,
                )
                for prompt_idx in np.arange(0, len(text_labels), prompt_batch_size)
            ],
            dim=0,
        )
        if horizontal_flipping:
            # Run the same tiles mirrored left-right, un-mirror the resulting
            # maps, and average them with the unflipped maps.
            flipped_tile_imgs = tile_imgs[
                ..., torch.flip(torch.arange(0, tile_imgs.shape[-1]), dims=[0])
            ]
            flipped_tile_gradcams = torch.cat(
                [
                    torch.cat(
                        [
                            cls.clip_gradcam(
                                x=flipped_tile_imgs[
                                    tile_idx : tile_idx + tile_batch_size
                                ],
                                o=text_labels[
                                    prompt_idx : prompt_idx + prompt_batch_size
                                ],
                            )
                            for tile_idx in np.arange(
                                0, len(tile_imgs), tile_batch_size
                            )
                        ],
                        dim=1,
                    )
                    for prompt_idx in np.arange(0, len(text_labels), prompt_batch_size)
                ],
                dim=0,
            )
            with torch.no_grad():
                flipped_tile_gradcams = flipped_tile_gradcams[
                    ...,
                    torch.flip(
                        torch.arange(0, flipped_tile_gradcams.shape[-1]), dims=[0]
                    ),
                ]
                tile_gradcams = (tile_gradcams + flipped_tile_gradcams) / 2
                del flipped_tile_gradcams
        with torch.no_grad():
            torch.cuda.empty_cache()
            # Paste each tile's (bilinearly upsampled) relevancy map back into
            # the full-resolution accumulator for its tile size.
            for tile_size in np.unique(tile_sizes):
                tile_size_mask = tile_sizes == tile_size
                curr_size_grads = tile_gradcams[:, tile_size_mask]
                curr_size_tiles = tiles[tile_size_mask]
                for tile_idx in np.arange(
                    0, curr_size_grads.shape[1], tile_interpolate_batch_size
                ):
                    resized_tiles = torch.nn.functional.interpolate(
                        curr_size_grads[
                            :, tile_idx : tile_idx + tile_interpolate_batch_size
                        ],
                        size=tile_size,
                        mode="bilinear",
                        align_corners=False,
                    )
                    # `batch_idx` indexes within the current interpolation batch.
                    for batch_idx, tile_slice in enumerate(
                        curr_size_tiles[
                            tile_idx : tile_idx + tile_interpolate_batch_size
                        ]
                    ):
                        outputs[tile_size][tile_slice] += resized_tiles[
                            :, batch_idx, ...
                        ]
            # Average across tile sizes, normalizing each accumulator by the
            # per-pixel tile coverage counts.
            output = sum(
                output.float() / count
                for output, count in zip(outputs.values(), counts.values())
            ) / len(counts)
            del outputs, counts, tile_gradcams
            output = output.cpu()
        return output

    @classmethod
    def create_tiles(cls, img, augmentations, cropping_augmentations, **kwargs):
        assert isinstance(img, np.ndarray)
        images = []
        cls.check_initialized()
        # Compute crops of the original image plus color-jittered copies.
        img_pil = Image.fromarray(img)
        images.append(np.array(img_pil))
        for _ in range(augmentations):
            images.append(np.array(cls.jittering_transforms(img_pil)))
        # Per-pixel tile coverage counts, used later for averaging.
        counts = {
            crop_aug["tile_size"]: torch.zeros(img.shape[:2], device=cls.device).float()
            + 1e-5
            for crop_aug in cropping_augmentations
        }
        tiles = []
        tile_imgs = []
        tile_sizes = []
        for aug_img in images:
            for crop_aug in cropping_augmentations:
                tile_size = crop_aug["tile_size"]
                stride = crop_aug["stride"]
                # Note: `x` strides over rows (axis 0) and `y` over columns
                # (axis 1); each tile is (label slice, row slice, column slice).
                for y in np.arange(0, aug_img.shape[1] - tile_size + 1, stride):
                    if y >= aug_img.shape[0]:
                        continue
                    for x in np.arange(0, aug_img.shape[0] - tile_size + 1, stride):
                        if x >= aug_img.shape[1]:
                            continue
                        tile = (
                            slice(None, None),
                            slice(x, x + tile_size),
                            slice(y, y + tile_size),
                        )
                        tiles.append(tile)
                        counts[tile_size][tile[1:]] += 1
                        tile_sizes.append(tile_size)
                        # This is currently the biggest bottleneck.
                        tile_imgs.append(
                            cls.clip_gradcam.preprocess(
                                Image.fromarray(aug_img[tiles[-1][1:]])
                            )
                        )
        tile_imgs = torch.stack(tile_imgs).to(cls.device)
        return np.array(tiles), tile_imgs, counts, np.array(tile_sizes)


# The 80-prompt ImageNet template ensemble from OpenAI's CLIP repository.
imagenet_templates = [
    "a bad photo of a {}.",
    "a photo of many {}.",
    "a sculpture of a {}.",
    "a photo of the hard to see {}.",
    "a low resolution photo of the {}.",
    "a rendering of a {}.",
    "graffiti of a {}.",
    "a bad photo of the {}.",
    "a cropped photo of the {}.",
    "a tattoo of a {}.",
    "the embroidered {}.",
    "a photo of a hard to see {}.",
    "a bright photo of a {}.",
    "a photo of a clean {}.",
    "a photo of a dirty {}.",
    "a dark photo of the {}.",
    "a drawing of a {}.",
    "a photo of my {}.",
    "the plastic {}.",
    "a photo of the cool {}.",
    "a close-up photo of a {}.",
    "a black and white photo of the {}.",
    "a painting of the {}.",
    "a painting of a {}.",
    "a pixelated photo of the {}.",
    "a sculpture of the {}.",
    "a bright photo of the {}.",
    "a cropped photo of a {}.",
    "a plastic {}.",
    "a photo of the dirty {}.",
    "a jpeg corrupted photo of a {}.",
    "a blurry photo of the {}.",
    "a photo of the {}.",
    "a good photo of the {}.",
    "a rendering of the {}.",
    "a {} in a video game.",
    "a photo of one {}.",
    "a doodle of a {}.",
    "a close-up photo of the {}.",
    "a photo of a {}.",
    "the origami {}.",
    "the {} in a video game.",
    "a sketch of a {}.",
    "a doodle of the {}.",
    "a origami {}.",
    "a low resolution photo of a {}.",
    "the toy {}.",
    "a rendition of the {}.",
    "a photo of the clean {}.",
    "a photo of a large {}.",
    "a rendition of a {}.",
    "a photo of a nice {}.",
    "a photo of a weird {}.",
    "a blurry photo of a {}.",
    "a cartoon {}.",
    "art of a {}.",
    "a sketch of the {}.",
    "a embroidered {}.",
    "a pixelated photo of a {}.",
    "itap of the {}.",
    "a jpeg corrupted photo of the {}.",
    "a good photo of a {}.",
    "a plushie {}.",
    "a photo of the nice {}.",
    "a photo of the small {}.",
    "a photo of the weird {}.",
    "the cartoon {}.",
    "art of the {}.",
    "a drawing of the {}.",
    "a photo of the large {}.",
    "a black and white photo of a {}.",
    "the plushie {}.",
    "a dark photo of a {}.",
    "itap of a {}.",
    "graffiti of the {}.",
    "a toy {}.",
    "itap of my {}.",
    "a photo of a cool {}.",
    "a photo of a small {}.",
    "a tattoo of the {}.",
]

__all__ = ["ClipWrapper", "imagenet_templates"]
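

# A minimal usage sketch, not part of the module's public API. It assumes a
# roughly square RGB image at the hypothetical path "example.png", hypothetical
# query labels, and an environment where the default ViT-B/32 weights can load.
if __name__ == "__main__":
    image = np.array(Image.open("example.png").convert("RGB"))
    relevancy_maps, label_features = ClipWrapper.get_clip_saliency(
        img=image,
        text_labels=["mug", "table"],  # hypothetical labels to localize
        prompts=["a photograph of a {} in a scene."],
        **saliency_configs["ours_fast"](image.shape[0]),
    )
    # One saliency map per label at the input's spatial resolution,
    # plus one CLIP text feature per label.
    print(relevancy_maps.shape, label_features.shape)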