import os
import json
import random
from functools import partial

import numpy as np
import PIL
import torch
import torch.nn as nn
from torchvision import transforms
import clip
import open_clip
class Clipper(torch.nn.Module):
    def __init__(self, clip_variant, clamp_embs=False, norm_embs=False,
                 hidden_state=False, device=torch.device('cpu')):
        super().__init__()
        assert clip_variant in ("RN50", "ViT-L/14", "ViT-B/32", "RN50x64"), \
            "clip_variant must be one of RN50, ViT-L/14, ViT-B/32, RN50x64"
        print(clip_variant, device)

        if clip_variant == "ViT-L/14" and hidden_state:
            # Use the Versatile Diffusion image encoder so that embed_image can
            # return the full sequence of hidden-state tokens rather than the
            # pooled CLIP embedding.
            from transformers import CLIPVisionModelWithProjection
            sd_cache_dir = '/fsx/proj-fmri/shared/cache/models--shi-labs--versatile-diffusion/snapshots/2926f8e11ea526b562cd592b099fcf9c2985d0b7'
            image_encoder = CLIPVisionModelWithProjection.from_pretrained(sd_cache_dir, subfolder='image_encoder').eval()
            image_encoder = image_encoder.to(device)
            for param in image_encoder.parameters():
                param.requires_grad = False  # inference only
            self.image_encoder = image_encoder
        elif hidden_state:
            raise Exception("hidden_state embeddings only work with ViT-L/14 right now")

        clip_model, preprocess = clip.load(clip_variant, device=device)
        clip_model.eval()
        for param in clip_model.parameters():
            param.requires_grad = False  # inference only

        self.clip = clip_model
        self.clip_variant = clip_variant
        if clip_variant == "RN50x64":
            self.clip_size = (448, 448)
        else:
            self.clip_size = (224, 224)

        preproc = transforms.Compose([
            transforms.Resize(size=self.clip_size[0], interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.CenterCrop(size=self.clip_size),
            transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711)),
        ])
        self.preprocess = preproc
        self.hidden_state = hidden_state
        self.mean = np.array([0.48145466, 0.4578275, 0.40821073])
        self.std = np.array([0.26862954, 0.26130258, 0.27577711])
        self.normalize = transforms.Normalize(self.mean, self.std)
        self.denormalize = transforms.Normalize((-self.mean / self.std).tolist(), (1.0 / self.std).tolist())
        self.clamp_embs = clamp_embs
        self.norm_embs = norm_embs
        self.device = device

        def versatile_normalize_embeddings(encoder_output):
            # Closure over the local image_encoder; only valid when
            # hidden_state=True (i.e. when image_encoder was created above).
            embeds = encoder_output.last_hidden_state
            embeds = image_encoder.vision_model.post_layernorm(embeds)
            embeds = image_encoder.visual_projection(embeds)
            return embeds
        self.versatile_normalize_embeddings = versatile_normalize_embeddings
    def resize_image(self, image):
        # Resize to the input resolution expected by the chosen CLIP variant.
        return transforms.Resize(self.clip_size)(image.to(self.device))

    def embed_image(self, image):
        """Expects images in -1 to 1 range"""
        if self.hidden_state:
            # Hidden-state path: preprocess, run the Versatile Diffusion image
            # encoder, then project every token into CLIP space.
            clip_emb = self.preprocess(image.to(self.device))
            clip_emb = self.image_encoder(clip_emb)
            clip_emb = self.versatile_normalize_embeddings(clip_emb)
        else:
            # Default path: pooled CLIP image embedding.
            clip_emb = self.preprocess(image.to(self.device))
            clip_emb = self.clip.encode_image(clip_emb)

        if self.clamp_embs:
            clip_emb = torch.clamp(clip_emb, -1.5, 1.5)
        if self.norm_embs:
            if self.hidden_state:
                # Normalize all tokens by the CLS token's norm.
                clip_emb = clip_emb / torch.norm(clip_emb[:, 0], dim=-1).reshape(-1, 1, 1)
            else:
                clip_emb = nn.functional.normalize(clip_emb, dim=-1)
        return clip_emb
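    # embed_image returns the pooled CLIP embedding by default (e.g. 768-dim
    # for ViT-L/14). With hidden_state=True it instead returns one embedding
    # per vision token; for the standard ViT-L/14 configuration that should be
    # a (batch, 257, 768) tensor after the visual projection, though the exact
    # shape depends on the checkpoint loaded in __init__.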
    def embed_text(self, text_samples):
        clip_text = clip.tokenize(text_samples).to(self.device)
        clip_text = self.clip.encode_text(clip_text)
        if self.clamp_embs:
            clip_text = torch.clamp(clip_text, -1.5, 1.5)
        if self.norm_embs:
            clip_text = nn.functional.normalize(clip_text, dim=-1)
        return clip_text

    def embed_curated_annotations(self, annots):
        # For each item, pick one of its (up to 5) captions at random,
        # skipping empty strings, then embed the resulting list of captions.
        for i, b in enumerate(annots):
            t = ''
            while t == '':
                rand = torch.randint(5, (1, 1))[0][0]
                t = b[0, rand]
            if i == 0:
                txt = np.array(t)
            else:
                txt = np.vstack((txt, t))
        txt = txt.flatten()
        return self.embed_text(txt)
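# Minimal usage sketch for Clipper, assuming a CUDA device and that clip.load
# can download the ViT-L/14 weights on first use; the shapes shown are for
# ViT-L/14 with hidden_state=False.
#
#     clipper = Clipper("ViT-L/14", norm_embs=True, device=torch.device("cuda"))
#     images = torch.rand(4, 3, 256, 256) * 2 - 1           # fake batch in [-1, 1]
#     img_embs = clipper.embed_image(images)                # -> (4, 768)
#     txt_embs = clipper.embed_text(["a photo of a dog"])   # -> (1, 768)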
class OpenClipper(torch.nn.Module):
    def __init__(self, clip_variant, norm_embs=False, device=torch.device('cpu')):
        super().__init__()
        print(clip_variant, device)
        assert clip_variant == 'ViT-H-14'

        clip_model, _, preprocess = open_clip.create_model_and_transforms(
            'ViT-H-14', pretrained='laion2b_s32b_b79k', device=device)
        clip_model.eval()
        for param in clip_model.parameters():
            param.requires_grad = False  # inference only

        # Overwrite the stock open_clip preprocessing (which expects PIL images)
        # with a tensor-friendly pipeline.
        preprocess = transforms.Compose([
            transforms.Resize(224, interpolation=transforms.InterpolationMode.BICUBIC, antialias=None),
            transforms.CenterCrop(224),
            transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711)),
        ])

        tokenizer = open_clip.get_tokenizer('ViT-H-14')

        self.clip = clip_model
        self.norm_embs = norm_embs
        self.preprocess = preprocess
        self.tokenizer = tokenizer
        self.device = device

    def embed_image(self, image):
        """Expects images in -1 to 1 range"""
        image = self.preprocess(image).to(self.device)
        with torch.no_grad(), torch.cuda.amp.autocast():
            image_features = self.clip.encode_image(image)
        if self.norm_embs:
            image_features = nn.functional.normalize(image_features, dim=-1)
        return image_features

    def embed_text(self, text_samples):
        text = self.tokenizer(text_samples).to(self.device)
        with torch.no_grad(), torch.cuda.amp.autocast():
            text_features = self.clip.encode_text(text)
        if self.norm_embs:
            text_features = nn.functional.normalize(text_features, dim=-1)
        return text_features

    def embed_curated_annotations(self, annots):
        # For each item, pick one of its (up to 5) captions at random,
        # skipping empty strings, then embed the resulting list of captions.
        for i, b in enumerate(annots):
            t = ''
            while t == '':
                rand = torch.randint(5, (1, 1))[0][0]
                t = b[0, rand]
            if i == 0:
                txt = np.array(t)
            else:
                txt = np.vstack((txt, t))
        txt = txt.flatten()
        return self.embed_text(txt)
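# Minimal smoke test, as a sketch: it assumes the laion2b_s32b_b79k ViT-H-14
# checkpoint can be downloaded by open_clip (several GB on first use) and that
# enough memory is available. Inputs are random tensors, so only the output
# shapes are meaningful here; Clipper can be exercised the same way.
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    fake_images = torch.rand(2, 3, 256, 256, device=device) * 2 - 1  # [-1, 1] range

    open_clipper = OpenClipper("ViT-H-14", norm_embs=True, device=device)
    print(open_clipper.embed_image(fake_images).shape)          # expected (2, 1024)
    print(open_clipper.embed_text(["a photo of a dog"]).shape)  # expected (1, 1024)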