"""clip_gradcam.py: CLIP relevancy extractor (from semabs-relevancy/CLIP)."""

from typing import List
import itertools

import numpy as np
import torch
import torch.nn as nn
from torch import device

from .clip import tokenize
from .clip_explainability import load

def zeroshot_classifier(clip_model, classnames, templates, device):
    """Build one text embedding per class by averaging over prompt templates."""
    with torch.no_grad():
        # Fill every template with every class name, e.g. "a photo of a {}".
        texts = list(
            itertools.chain(
                *[
                    [template.format(classname) for template in templates]
                    for classname in classnames
                ]
            )
        )  # format with class
        texts = tokenize(texts).to(device)  # tokenize
        class_embeddings = clip_model.encode_text(texts)
        class_embeddings = class_embeddings.view(len(classnames), len(templates), -1)
        class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
        # Average the normalized per-template embeddings into one vector per class.
        zeroshot_weights = class_embeddings.mean(dim=1)
    return zeroshot_weights.T  # shape: [dim, n classes]
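
# Illustrative shape note (not in the original source): with classnames
# ["dog", "cat"] and two templates, zeroshot_classifier returns a tensor of
# shape [embed_dim, 2]; each column is the mean of that class's normalized
# per-template embeddings, i.e. standard CLIP prompt ensembling.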


class ClipGradcam(nn.Module):
    def __init__(
        self,
        clip_model_name: str,
        classes: List[str],
        templates: List[str],
        device: device,
        num_layers=10,
        positive_attn_only=False,
        **kwargs,
    ):
        super(ClipGradcam, self).__init__()
        self.clip_model_name = clip_model_name
        self.model, self.preprocess = load(clip_model_name, device=device, **kwargs)
        self.templates = templates
        self.device = device
        self.target_classes = None
        self.set_classes(classes)
        self.num_layers = num_layers
        self.positive_attn_only = positive_attn_only
        # Number of attention heads per residual attention block of the visual
        # transformer (12 for the ViT-B backbones, 16 for the ViT-L backbones).
        self.num_res_attn_blocks = {
            "ViT-B/32": 12,
            "ViT-B/16": 12,
            "ViT-L/14": 16,
            "ViT-L/14@336px": 16,
        }[clip_model_name]

    def forward(self, x: torch.Tensor, o: List[str]):
        """
        Non-standard hack around an nn.Module; really should be more
        principled here.
        """
        image_features = self.model.encode_image(x.to(self.device))
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        # Stack the cached text features for the requested prompts into columns.
        zeroshot_weights = torch.cat(
            [self.class_to_language_feature[prompt] for prompt in o], dim=1
        )
        logits_per_image = 100.0 * image_features @ zeroshot_weights
        return self.interpret(logits_per_image, self.model, self.device)

    def interpret(self, logits_per_image, model, device):
        # Modified from:
        # https://colab.research.google.com/github/hila-chefer/Transformer-MM-Explainability/blob/main/CLIP_explainability.ipynb#scrollTo=fWKGyu2YAeSV
        batch_size = logits_per_image.shape[0]
        num_prompts = logits_per_image.shape[1]
        # One scalar objective per prompt, summed over the batch.
        one_hot = [logit for logit in logits_per_image.sum(dim=0)]
        model.zero_grad()
        image_attn_blocks = list(
            dict(model.visual.transformer.resblocks.named_children()).values()
        )
        num_tokens = image_attn_blocks[0].attn_probs.shape[-1]
        # Relevance starts as the identity over tokens, one map per (prompt, image).
        R = torch.eye(
            num_tokens, num_tokens, dtype=image_attn_blocks[0].attn_probs.dtype
        ).to(device)
        R = R[None, None, :, :].repeat(num_prompts, batch_size, 1, 1)
        for i, block in enumerate(image_attn_blocks):
            # Only the blocks after index num_layers contribute to the map.
            if i <= self.num_layers:
                continue
            # TODO try scaling block.attn_probs by value magnitude
            # TODO actual parallelized prompt gradients
            grad = torch.stack(
                [
                    torch.autograd.grad(logit, [block.attn_probs], retain_graph=True)[
                        0
                    ].detach()
                    for logit in one_hot
                ]
            )
            grad = grad.view(
                num_prompts,
                batch_size,
                self.num_res_attn_blocks,
                num_tokens,
                num_tokens,
            )
            cam = (
                block.attn_probs.view(
                    1, batch_size, self.num_res_attn_blocks, num_tokens, num_tokens
                )
                .detach()
                .repeat(num_prompts, 1, 1, 1, 1)
            )
            cam = cam.reshape(num_prompts, batch_size, -1, cam.shape[-1], cam.shape[-1])
            grad = grad.reshape(
                num_prompts, batch_size, -1, grad.shape[-1], grad.shape[-1]
            )
            # Grad-CAM step: weight each attention map by its gradient.
            cam = grad * cam
            cam = cam.reshape(
                num_prompts * batch_size, -1, cam.shape[-1], cam.shape[-1]
            )
            if self.positive_attn_only:
                cam = cam.clamp(min=0)
            # Average over all attention heads.
            cam = cam.mean(dim=-3)
            # Accumulate relevance across layers: R <- R + cam @ R.
            R = R + torch.bmm(
                cam, R.view(num_prompts * batch_size, num_tokens, num_tokens)
            ).view(num_prompts, batch_size, num_tokens, num_tokens)
        # Relevance of each image patch w.r.t. the CLS token, as a square grid.
        image_relevance = R[:, :, 0, 1:]
        img_dim = int(np.sqrt(num_tokens - 1))
        image_relevance = image_relevance.reshape(
            num_prompts, batch_size, img_dim, img_dim
        )
        return image_relevance

    def set_classes(self, classes):
        """Precompute and cache one language feature column per class."""
        self.target_classes = classes
        language_features = zeroshot_classifier(
            self.model, self.target_classes, self.templates, self.device
        )
        self.class_to_language_feature = {}
        for i, c in enumerate(self.target_classes):
            # Keep the column dimension so forward() can torch.cat the features.
            self.class_to_language_feature[c] = language_features[:, [i]]
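

# --- Usage sketch (illustrative addition, not part of the original file). ---
# A minimal example of driving ClipGradcam end to end. It assumes this module
# is imported from within its package (the relative imports above must
# resolve) and that a local "example.jpg" exists; both are assumptions made
# purely for illustration.
if __name__ == "__main__":
    from PIL import Image

    dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    gradcam = ClipGradcam(
        clip_model_name="ViT-B/32",
        classes=["a dog", "a cat"],
        templates=["a photo of {}"],
        device=dev,
    )
    image = gradcam.preprocess(Image.open("example.jpg")).unsqueeze(0)
    # relevance has shape [num_prompts, batch, grid, grid]: one patch-level
    # relevancy map per (prompt, image) pair.
    relevance = gradcam(image, ["a dog", "a cat"])
    print(relevance.shape)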