"""
ADOBE CONFIDENTIAL
Copyright 2024 Adobe
All Rights Reserved.
NOTICE: All information contained herein is, and remains
the property of Adobe and its suppliers, if any. The intellectual
and technical concepts contained herein are proprietary to Adobe
and its suppliers and are protected by all applicable intellectual
property laws, including trade secret and copyright laws.
Dissemination of this information or reproduction of this material
is strictly forbidden unless prior written permission is obtained
from Adobe.
"""

import torch as th
from diffusers import ModelMixin
from transformers import AutoModel, SiglipVisionConfig, Dinov2Config
from transformers import SiglipVisionModel
from diffusers.configuration_utils import ConfigMixin, register_to_config


class AnalogyEncoder(ModelMixin, ConfigMixin):
    """Pair of frozen DINOv2 and SigLIP vision encoders producing token embeddings.

    Both encoders have gradients disabled and are converted to the
    ``channels_last`` memory format at construction time. ``forward`` returns
    one embedding tensor per encoder.
    """

    @register_to_config
    def __init__(self, load_pretrained=False, dino_config_dict=None, siglip_config_dict=None):
        """Build the two image encoders.

        Args:
            load_pretrained: When true, load the ``facebook/dinov2-large`` and
                ``google/siglip-large-patch16-256`` checkpoints in float16.
                When false, build both encoders from the supplied config
                dictionaries without loading checkpoint weights here.
            dino_config_dict: Dictionary passed to ``Dinov2Config.from_dict``.
                Required when ``load_pretrained`` is false.
            siglip_config_dict: Dictionary passed to
                ``SiglipVisionConfig.from_dict``. Required when
                ``load_pretrained`` is false.

        Raises:
            ValueError: If ``load_pretrained`` is false and either config
                dictionary is ``None``.
        """
        super().__init__()

        if load_pretrained:
            image_encoder_dino = AutoModel.from_pretrained(
                'facebook/dinov2-large',
                torch_dtype=th.float16,
            )
            image_encoder_siglip = SiglipVisionModel.from_pretrained(
                "google/siglip-large-patch16-256",
                torch_dtype=th.float16,
                attn_implementation="sdpa",
            )
        else:
            # Fail early with an actionable message; otherwise ``from_dict(None)``
            # raises an opaque error from inside transformers.
            missing = [
                name
                for name, value in (
                    ("dino_config_dict", dino_config_dict),
                    ("siglip_config_dict", siglip_config_dict),
                )
                if value is None
            ]
            if missing:
                raise ValueError(
                    f"{', '.join(missing)} must be provided when load_pretrained=False"
                )
            image_encoder_dino = AutoModel.from_config(
                Dinov2Config.from_dict(dino_config_dict)
            )
            image_encoder_siglip = AutoModel.from_config(
                SiglipVisionConfig.from_dict(siglip_config_dict)
            )

        # Both encoders are frozen regardless of how they were created.
        image_encoder_dino.requires_grad_(False)
        image_encoder_dino = image_encoder_dino.to(memory_format=th.channels_last)
        image_encoder_siglip.requires_grad_(False)
        image_encoder_siglip = image_encoder_siglip.to(memory_format=th.channels_last)

        self.image_encoder_dino = image_encoder_dino
        self.image_encoder_siglip = image_encoder_siglip

    @staticmethod
    def _normalize_by_first_token(embeds):
        """Divide every token in ``embeds`` by the L2 norm of the token at index 0."""
        first_token = embeds[:, 0:1]
        # keepdim=True keeps a broadcastable trailing axis for the division.
        return embeds / th.norm(first_token, dim=-1, keepdim=True)

    def dino_normalization(self, encoder_output):
        """Scale the DINO ``last_hidden_state`` by the norm of its first token.

        Args:
            encoder_output: Encoder output exposing ``last_hidden_state``.

        Returns:
            ``last_hidden_state`` divided by the L2 norm (over the last axis)
            of its token at sequence index 0 (presumably the CLS token —
            confirm against the DINOv2 model documentation).
        """
        return self._normalize_by_first_token(encoder_output.last_hidden_state)

    def siglip_normalization(self, encoder_output):
        """Prepend the SigLIP pooled output as a token, then scale by its norm.

        Args:
            encoder_output: Encoder output exposing ``pooler_output`` and
                ``last_hidden_state``.

        Returns:
            The concatenation (along axis 1) of ``pooler_output`` and
            ``last_hidden_state``, divided by the L2 norm of the prepended
            pooled token.
        """
        embeds = th.cat(
            [encoder_output.pooler_output[:, None, :], encoder_output.last_hidden_state],
            dim=1,
        )
        return self._normalize_by_first_token(embeds)

    def forward(self, dino_in, siglip_in):
        """Encode the inputs with both encoders and build the combined embeddings.

        Args:
            dino_in: Input passed to the DINO encoder.
            siglip_in: Input passed to the SigLIP encoder.

        Returns:
            Tuple ``(dino_embd, siglip_embd)``. Each is the normalized
            final-layer token sequence concatenated along the last axis with
            the encoder's ``hidden_states[0]`` (presumably the embedding-layer
            output per the Hugging Face convention — confirm).
        """
        x_1 = self.image_encoder_dino(dino_in, output_hidden_states=True)
        x_1_first = x_1.hidden_states[0]
        x_1 = self.dino_normalization(x_1)

        x_2 = self.image_encoder_siglip(siglip_in, output_hidden_states=True)
        x_2_first = x_2.hidden_states[0]
        # The normalized SigLIP sequence gains one prepended pooled token, so a
        # mean-pooled token is prepended here to keep the token counts equal
        # for the feature-axis concatenation below.
        x_2_first_pool = th.mean(x_2_first, dim=1, keepdim=True)
        x_2_first = th.cat([x_2_first_pool, x_2_first], 1)
        x_2 = self.siglip_normalization(x_2)

        dino_embd = th.cat([x_1, x_1_first], -1)
        siglip_embd = th.cat([x_2, x_2_first], -1)
        return dino_embd, siglip_embd