"""
image_proj_model.py

This module defines the ImageProjModel class, which is responsible for
projecting image embeddings into a different dimensional space. The model
applies a linear transformation followed by layer normalization to reshape
and normalize the input image embeddings for use in cross-attention
mechanisms or other downstream tasks.

Classes:
    ImageProjModel

Dependencies:
    torch
    diffusers.ModelMixin
"""

import torch
from diffusers import ModelMixin


class ImageProjModel(ModelMixin):
    """
    ImageProjModel projects image embeddings into a different dimensional
    space. It inherits from diffusers' ModelMixin, which provides model
    saving and loading utilities.

    Attributes:
        cross_attention_dim (int): The dimension of the cross-attention features.
        clip_extra_context_tokens (int): The number of extra context tokens
            produced from each image embedding.
        proj (torch.nn.Linear): Linear projection from clip_embeddings_dim to
            clip_extra_context_tokens * cross_attention_dim features.
        norm (torch.nn.LayerNorm): Layer normalization applied over the last
            dimension of the projected tokens.

    Methods:
        forward(image_embeds): Forward pass of the ImageProjModel, which takes
            image embeddings and returns the projected tokens.
    """

    def __init__(
        self,
        cross_attention_dim=1024,
        clip_embeddings_dim=1024,
        clip_extra_context_tokens=4,
    ):
        super().__init__()

        self.generator = None
        self.cross_attention_dim = cross_attention_dim
        self.clip_extra_context_tokens = clip_extra_context_tokens
        self.proj = torch.nn.Linear(
            clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim
        )
        self.norm = torch.nn.LayerNorm(cross_attention_dim)

    def forward(self, image_embeds):
        """
        Forward pass of the ImageProjModel, which takes image embeddings and
        returns the projected tokens after reshaping and normalization.

        Args:
            image_embeds (torch.Tensor): The input image embeddings, with shape
                (batch_size, clip_embeddings_dim).

        Returns:
            torch.Tensor: The projected context tokens after reshaping and
                normalization, with shape
                (batch_size, clip_extra_context_tokens, cross_attention_dim).
        """
        embeds = image_embeds
        # Project to clip_extra_context_tokens * cross_attention_dim features,
        # then split into clip_extra_context_tokens tokens per image.
        clip_extra_context_tokens = self.proj(embeds).reshape(
            -1, self.clip_extra_context_tokens, self.cross_attention_dim
        )
        clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
        return clip_extra_context_tokens
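

# A minimal usage sketch: the class name and forward signature come from this
# module; the dimensions, the dummy input, and the `__main__` guard are
# illustrative assumptions only, not part of the module's API.
if __name__ == "__main__":
    # Pooled CLIP image embeddings: (batch_size, clip_embeddings_dim).
    dummy_image_embeds = torch.randn(2, 1024)

    model = ImageProjModel(
        cross_attention_dim=1024,
        clip_embeddings_dim=1024,
        clip_extra_context_tokens=4,
    )
    tokens = model(dummy_image_embeds)

    # Expected: (batch_size, clip_extra_context_tokens, cross_attention_dim).
    print(tokens.shape)  # torch.Size([2, 4, 1024])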