|
""" |
|
ADOBE CONFIDENTIAL |
|
Copyright 2024 Adobe |
|
All Rights Reserved. |
|
NOTICE: All information contained herein is, and remains |
|
the property of Adobe and its suppliers, if any. The intellectual |
|
and technical concepts contained herein are proprietary to Adobe |
|
and its suppliers and are protected by all applicable intellectual |
|
property laws, including trade secret and copyright laws. |
|
Dissemination of this information or reproduction of this material |
|
is strictly forbidden unless prior written permission is obtained |
|
from Adobe. |
|
""" |
|
|
|
import torch as th |
|
from torchvision import transforms |
|
from diffusers import ModelMixin |
|
from diffusers.configuration_utils import ConfigMixin, register_to_config |
|
|
|
# Preprocessing constants consumed by AnalogyInputProcessor below.

# Side length, in pixels, of the square image produced for the DINO branch.
DINO_SIZE = 224
# Per-channel (3 values) mean / std used to normalize DINO inputs.
# NOTE(review): these match the widely used ImageNet statistics.
DINO_MEAN = [0.485, 0.456, 0.406]
DINO_STD = [0.229, 0.224, 0.225]

# Side length, in pixels, of the square image produced for the SigLIP branch.
SIGLIP_SIZE = 256
# Single-value mean / std for SigLIP inputs; the processor repeats the value
# three times when it builds its (1, 3, 1, 1) normalization buffers.
SIGLIP_MEAN = [0.5]
SIGLIP_STD = [0.5]
|
|
|
|
|
class AnalogyInputProcessor(ModelMixin, ConfigMixin):
    """Turn analogy image triplets into batched DINO and SigLIP inputs.

    An analogy prompt is a triplet ``(img_a, img_a_star, img_b)``. Every image
    is preprocessed twice: once with the DINO transform (resize to
    ``DINO_SIZE``, normalize with ``DINO_MEAN`` / ``DINO_STD``) and once with
    the SigLIP transform (resize to ``SIGLIP_SIZE``, normalize with
    ``SIGLIP_MEAN`` / ``SIGLIP_STD``).

    The normalization statistics are also registered as ``(1, 3, 1, 1)``
    buffers (``dino_mean``, ``dino_std``, ``siglip_mean``, ``siglip_std``) so
    that :meth:`get_negative` can normalize tensors directly and the values
    follow the module across devices.
    """

    @register_to_config
    def __init__(self):
        super().__init__()

        # PIL/ndarray -> normalized tensor pipelines, one per encoder.
        self.dino_transform = transforms.Compose(
            [
                transforms.Resize((DINO_SIZE, DINO_SIZE)),
                transforms.ToTensor(),
                transforms.Normalize(DINO_MEAN, DINO_STD),
            ]
        )
        self.siglip_transform = transforms.Compose(
            [
                transforms.Resize((SIGLIP_SIZE, SIGLIP_SIZE)),
                transforms.ToTensor(),
                transforms.Normalize(SIGLIP_MEAN, SIGLIP_STD),
            ]
        )

        # Tensor copies of the statistics, shaped to broadcast over
        # (..., 3, H, W) inputs. The SigLIP constants hold a single value, so
        # it is repeated for the three channels.
        dino_mean = th.tensor(DINO_MEAN).view(1, 3, 1, 1)
        dino_std = th.tensor(DINO_STD).view(1, 3, 1, 1)
        siglip_mean = th.tensor([SIGLIP_MEAN[0]] * 3).view(1, 3, 1, 1)
        siglip_std = th.tensor([SIGLIP_STD[0]] * 3).view(1, 3, 1, 1)

        # Registration order is kept stable so state_dict key order does not
        # change.
        self.register_buffer("dino_mean", dino_mean)
        self.register_buffer("dino_std", dino_std)
        self.register_buffer("siglip_mean", siglip_mean)
        self.register_buffer("siglip_std", siglip_std)

    def __call__(self, analogy_prompt):
        """Preprocess a batch of analogy triplets for both encoders.

        Args:
            analogy_prompt: Iterable of ``(img_a, img_a_star, img_b)``
                triplets. It is consumed in a single pass. Each image must be
                something the torchvision ``Resize`` / ``ToTensor`` pipeline
                accepts.

        Returns:
            A ``(dino_combined_input, siglip_combined_input)`` tuple. Each
            tensor has shape ``(3, batch, C, size, size)`` where the leading
            axis is ordered ``[img_b, img_a, img_a_star]`` and ``size`` is
            ``DINO_SIZE`` or ``SIGLIP_SIZE`` respectively.

        Raises:
            ValueError: If an entry does not unpack into exactly three images.
            RuntimeError: From ``th.stack`` if ``analogy_prompt`` is empty or
                the transformed images do not share a shape.
        """
        dino_per_sample = []
        siglip_per_sample = []

        for img_a, img_a_star, img_b in analogy_prompt:
            # Output order along the leading axis is (b, a, a_star).
            ordered = (img_b, img_a, img_a_star)
            dino_per_sample.append(
                th.stack([self.dino_transform(img) for img in ordered], 0)
            )
            siglip_per_sample.append(
                th.stack([self.siglip_transform(img) for img in ordered], 0)
            )

        # Each entry is (3, C, H, W); stacking on dim 1 inserts the batch axis
        # right after the role axis, giving (3, batch, C, H, W).
        dino_combined_input = th.stack(dino_per_sample, 1)
        siglip_combined_input = th.stack(siglip_per_sample, 1)

        return dino_combined_input, siglip_combined_input

    def get_negative(self, dino_in, siglip_in):
        """Build the "negative" inputs: a normalized constant 0.5 image.

        The result has the same shape and device as the corresponding input
        and is filled with ``(0.5 - mean) / std`` per channel, i.e. what the
        transforms would produce for an image whose pre-normalization value is
        0.5 everywhere.

        Args:
            dino_in: Tensor used only as a shape/dtype/device template for the
                DINO branch; must broadcast against a ``(1, 3, 1, 1)`` buffer.
            siglip_in: Same, for the SigLIP branch.

        Returns:
            A ``(dino_negative, siglip_negative)`` tuple of tensors.
        """
        # zeros_like (rather than `x * 0`) keeps the output independent of the
        # input's values: NaN/Inf entries would otherwise leak through, since
        # inf * 0 == nan. Dtype promotion is unchanged.
        dino_flat = th.zeros_like(dino_in) + 0.5
        siglip_flat = th.zeros_like(siglip_in) + 0.5

        dino_i = (dino_flat - self.dino_mean) / self.dino_std
        siglip_i = (siglip_flat - self.siglip_mean) / self.siglip_std
        return dino_i, siglip_i
|
|
|
|