import torch
import torch.nn as nn
import numpy as np
from functools import partial
from lib.model_zoo.common.get_model import register
import torch.nn.functional as F

symbol = 'clip'


class AbstractEncoder(nn.Module):
    def __init__(self):
        super().__init__()

    def encode(self, *args, **kwargs):
        raise NotImplementedError


from transformers import CLIPTokenizer, CLIPTextModel


def disabled_train(self, mode=True):
    """Overwrite model.train with this function to make sure train/eval mode
    does not change anymore."""
    return self


###############
# for vd next #
###############
from transformers import CLIPModel


class CLIPTextContextEncoder(AbstractEncoder):
    def __init__(self,
                 version="openai/clip-vit-large-patch14",
                 max_length=77,
                 fp16=False, ):
        super().__init__()
        self.tokenizer = CLIPTokenizer.from_pretrained(version)
        self.model = CLIPModel.from_pretrained(version)
        self.max_length = max_length
        self.fp16 = fp16
        self.freeze()

    def get_device(self):
        # A trick to get device
        return self.model.text_projection.weight.device

    def freeze(self):
        self.model = self.model.eval()
        self.train = disabled_train
        for param in self.parameters():
            param.requires_grad = False

    def encode(self, text):
        batch_encoding = self.tokenizer(
            text, truncation=True, max_length=self.max_length, return_length=True,
            return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
        tokens = batch_encoding["input_ids"].to(self.get_device())
        outputs = self.model.text_model(input_ids=tokens)
        # Project every token state through CLIP's text projection, then scale by
        # the norm of the projected pooled (EOT) embedding so the token-level
        # context keeps the magnitude of CLIP's pooled text feature.
        z = self.model.text_projection(outputs.last_hidden_state)
        z_pooled = self.model.text_projection(outputs.pooler_output)
        z = z / torch.norm(z_pooled.unsqueeze(1), dim=-1, keepdim=True)
        return z
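
# Usage sketch (illustrative, not part of the original module): a minimal,
# hedged example of how CLIPTextContextEncoder could be driven, assuming the
# "openai/clip-vit-large-patch14" weights can be downloaded. The names
# `text_encoder` and `captions` are hypothetical.
#
#   text_encoder = CLIPTextContextEncoder(max_length=77)
#   captions = ["a photo of a cat", "an oil painting of a lighthouse"]
#   with torch.no_grad():
#       ctx = text_encoder.encode(captions)
#   # For ViT-L/14 this yields a [batch, 77, 768] token-level context tensor.
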
from transformers import CLIPProcessor


class CLIPImageContextEncoder(AbstractEncoder):
    def __init__(self,
                 version="openai/clip-vit-large-patch14",
                 fp16=False, ):
        super().__init__()
        self.tokenizer = CLIPTokenizer.from_pretrained(version)
        self.processor = CLIPProcessor.from_pretrained(version)
        self.model = CLIPModel.from_pretrained(version)
        self.fp16 = fp16
        self.freeze()

    def get_device(self):
        # A trick to get device
        return self.model.text_projection.weight.device

    def freeze(self):
        self.model = self.model.eval()
        self.train = disabled_train
        for param in self.parameters():
            param.requires_grad = False

    def _encode(self, images):
        if isinstance(images, torch.Tensor):
            import torchvision.transforms as tvtrans
            images = [tvtrans.ToPILImage()(i) for i in images]
        inputs = self.processor(images=images, return_tensors="pt")
        pixels = inputs['pixel_values'].half() if self.fp16 else inputs['pixel_values']
        pixels = pixels.to(self.get_device())
        outputs = self.model.vision_model(pixel_values=pixels)
        # Project the class token and every patch token through CLIP's visual
        # projection, then scale by the norm of the projected class token so the
        # context matches the magnitude of CLIP's pooled image feature.
        z = outputs.last_hidden_state
        z = self.model.vision_model.post_layernorm(z)
        z = self.model.visual_projection(z)
        z_pooled = z[:, 0:1]
        z = z / torch.norm(z_pooled, dim=-1, keepdim=True)
        return z
    def _encode_wmask(self, images, masks):
        assert isinstance(masks, torch.Tensor)
        assert (len(masks.shape) == 4) and (masks.shape[1] == 1)
        masks = torch.clamp(masks, 0, 1)
        masks = masks.float()
        masks = F.interpolate(masks, [224, 224], mode='bilinear')
        if masks.sum() == masks.numel():
            # The mask keeps everything; fall back to the plain encoder.
            return self._encode(images)
        device = images.device
        dtype = images.dtype
        # Global scale: mean mask coverage, used as the weight of the class token.
        gscale = masks.mean(axis=[1, 2, 3], keepdim=True).flatten(2)
        # Downsample the mask to one weight per vision token by averaging it over
        # each patch of the patch embedding.
        vtoken_kernel_size = self.model.vision_model.embeddings.patch_embedding.kernel_size
        vtoken_stride = self.model.vision_model.embeddings.patch_embedding.stride
        mask_kernel = torch.ones([1, 1, *vtoken_kernel_size], device=device, requires_grad=False).float()
        vtoken_mask = torch.nn.functional.conv2d(masks, mask_kernel, stride=vtoken_stride).flatten(2).transpose(1, 2)
        vtoken_mask = vtoken_mask / np.prod(vtoken_kernel_size)
        vtoken_mask = torch.concat([gscale, vtoken_mask], axis=1)

        import types

        def customized_embedding_forward(self, pixel_values):
            batch_size = pixel_values.shape[0]
            patch_embeds = self.patch_embedding(pixel_values)  # shape = [*, width, grid, grid]
            patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
            class_embeds = self.class_embedding.expand(batch_size, 1, -1)
            embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
            embeddings = embeddings + self.position_embedding(self.position_ids)
            # Weight each token embedding by its mask coverage before the transformer.
            embeddings = embeddings * vtoken_mask.to(embeddings.dtype)
            return embeddings

        # Temporarily monkey-patch the embedding forward so the mask weights are
        # applied inside the vision tower, then restore the original forward.
        old_forward = self.model.vision_model.embeddings.forward
        self.model.vision_model.embeddings.forward = types.MethodType(
            customized_embedding_forward, self.model.vision_model.embeddings)
        z = self._encode(images)
        self.model.vision_model.embeddings.forward = old_forward
        z = z * vtoken_mask.to(dtype)
        return z
    def encode(self, images, masks=None):
        if masks is None:
            return self._encode(images)
        else:
            return self._encode_wmask(images, masks)
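

if __name__ == "__main__":
    # Illustrative smoke test, not part of the original module. It assumes the
    # "openai/clip-vit-large-patch14" weights are downloadable and torchvision
    # is installed; the variable names below are hypothetical. An all-ones mask
    # would take the plain path, while this partial mask exercises _encode_wmask.
    image_encoder = CLIPImageContextEncoder()
    dummy_images = torch.rand(2, 3, 224, 224)
    dummy_masks = torch.zeros(2, 1, 224, 224)
    dummy_masks[:, :, 64:160, 64:160] = 1.0  # keep only a central square
    with torch.no_grad():
        z_full = image_encoder.encode(dummy_images)                 # unmasked context
        z_masked = image_encoder.encode(dummy_images, dummy_masks)  # mask-weighted context
    # For ViT-L/14 both tensors should come out as [2, 257, 768].
    print(z_full.shape, z_masked.shape)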