|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from dataclasses import dataclass
|
|
from typing import Any, Dict, Optional
|
|
|
|
import torch
|
|
import torch.nn.functional as F
|
|
from torch import nn
|
|
|
|
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
|
from diffusers.models.embeddings import ImagePositionalEmbeddings
|
|
from diffusers.utils import BaseOutput, deprecate
|
|
from diffusers.utils.torch_utils import maybe_allow_in_graph
|
|
from diffusers.models.attention import FeedForward, AdaLayerNorm, AdaLayerNormZero, Attention
|
|
from diffusers.models.embeddings import PatchEmbed
|
|
from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
|
|
from diffusers.models.modeling_utils import ModelMixin
|
|
from diffusers.utils.import_utils import is_xformers_available
|
|
|
|
from einops import rearrange
|
|
import pdb
|
|
import random
|
|
import math
|
|
|
|
|
|
if is_xformers_available():
|
|
import xformers
|
|
import xformers.ops
|
|
else:
|
|
xformers = None
|
|
|
|
|
|
@dataclass
class TransformerMV2DModelOutput(BaseOutput):
    """
    The output of [`TransformerMV2DModel`].

    Args:
        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`TransformerMV2DModel`] is discrete):
            The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns log
            probability distributions (log-softmax over the classes) for the unnoised latent pixels.
    """

    sample: torch.FloatTensor
|
|
|
|
|
|
class TransformerMV2DModel(ModelMixin, ConfigMixin):
    """
    A 2D Transformer model for image-like data whose transformer blocks attend across several views of a sample.

    Input handling mirrors diffusers' `Transformer2DModel`. Exactly one of three input modes is selected from the
    constructor arguments: continuous (`in_channels`), discrete/vectorized (`num_vector_embeds`) or patched
    (`in_channels` + `patch_size`). The body is a stack of `num_layers` [`BasicMVTransformerBlock`] layers.

    Parameters:
        num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
        attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
        in_channels (`int`, *optional*):
            The number of channels in the input and output (specify if the input is **continuous** or **patched**).
        out_channels (`int`, *optional*):
            Number of output channels for **patched** input; defaults to `in_channels`. For continuous input the
            output projection always maps back to `in_channels`, because the input is added back as a residual.
        num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        norm_num_groups (`int`, *optional*, defaults to 32): Groups of the input `GroupNorm` (continuous input only).
        cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
        attention_bias (`bool`, *optional*):
            Configure if the `TransformerBlocks` attention should contain a bias parameter.
        sample_size (`int`, *optional*): The width of the latent images (required for **discrete** and **patched**
            input). This is fixed during training since it is used to learn a number of position embeddings.
        num_vector_embeds (`int`, *optional*):
            The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
            Includes the class for the masked latent pixel.
        patch_size (`int`, *optional*): Patch size for **patched** input (used together with `in_channels`).
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
        num_embeds_ada_norm ( `int`, *optional*):
            The number of diffusion steps used during training. Pass if at least one of the norm_layers is
            `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
            added to the hidden states.

            During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`.
        use_linear_projection (`bool`, *optional*, defaults to `False`):
            Use linear layers instead of 1x1 convolutions for the continuous-input projections.
        only_cross_attention (`bool`, *optional*, defaults to `False`): Forwarded to every block.
        upcast_attention (`bool`, *optional*, defaults to `False`): Forwarded to every block.
        norm_type (`str`, *optional*, defaults to `"layer_norm"`):
            Forwarded to every block. `"layer_norm"` together with `num_embeds_ada_norm` is rewritten to
            `"ada_norm"` after a deprecation warning.
        norm_elementwise_affine (`bool`, *optional*, defaults to `True`): Forwarded to every block.
        num_views (`int`, *optional*, defaults to 1): Number of views per sample, forwarded to every block.
        cd_attention_last (`bool`, *optional*, defaults to `False`): Forwarded to every block.
        cd_attention_mid (`bool`, *optional*, defaults to `False`): Forwarded to every block.
        multiview_attention (`bool`, *optional*, defaults to `True`): Forwarded to every block.
        sparse_mv_attention (`bool`, *optional*, defaults to `True`):
            Not used by this class' body (only recorded by `register_to_config`); presumably kept so existing
            configs still load - TODO confirm.
        mvcd_attention (`bool`, *optional*, defaults to `False`): Forwarded to every block.
    """

    @register_to_config
    def __init__(
        self,
        num_attention_heads: int = 16,
        attention_head_dim: int = 88,
        in_channels: Optional[int] = None,
        out_channels: Optional[int] = None,
        num_layers: int = 1,
        dropout: float = 0.0,
        norm_num_groups: int = 32,
        cross_attention_dim: Optional[int] = None,
        attention_bias: bool = False,
        sample_size: Optional[int] = None,
        num_vector_embeds: Optional[int] = None,
        patch_size: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        use_linear_projection: bool = False,
        only_cross_attention: bool = False,
        upcast_attention: bool = False,
        norm_type: str = "layer_norm",
        norm_elementwise_affine: bool = True,
        num_views: int = 1,
        cd_attention_last: bool=False,
        cd_attention_mid: bool=False,
        multiview_attention: bool=True,
        sparse_mv_attention: bool = True,
        mvcd_attention: bool=False
    ):
        super().__init__()
        self.use_linear_projection = use_linear_projection
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        inner_dim = num_attention_heads * attention_head_dim

        # Determine the input mode; the checks below reject ambiguous or missing combinations.
        self.is_input_continuous = (in_channels is not None) and (patch_size is None)
        self.is_input_vectorized = num_vector_embeds is not None
        self.is_input_patches = in_channels is not None and patch_size is not None

        # Legacy configs: `layer_norm` + `num_embeds_ada_norm` is upgraded to `ada_norm` with a deprecation warning.
        if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
            deprecation_message = (
                f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
                " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config."
                " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect"
                " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
                " would be very nice if you could open a Pull request for the `transformer/config.json` file"
            )
            deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False)
            norm_type = "ada_norm"

        if self.is_input_continuous and self.is_input_vectorized:
            raise ValueError(
                f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
                " sure that either `in_channels` or `num_vector_embeds` is None."
            )
        elif self.is_input_vectorized and self.is_input_patches:
            raise ValueError(
                f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
                " sure that either `num_vector_embeds` or `num_patches` is None."
            )
        elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches:
            raise ValueError(
                f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:"
                f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
            )

        # 1. Input layers for the selected mode.
        if self.is_input_continuous:
            self.in_channels = in_channels

            self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
            if use_linear_projection:
                self.proj_in = LoRACompatibleLinear(in_channels, inner_dim)
            else:
                self.proj_in = LoRACompatibleConv(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
        elif self.is_input_vectorized:
            assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
            assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"

            self.height = sample_size
            self.width = sample_size
            self.num_vector_embeds = num_vector_embeds
            self.num_latent_pixels = self.height * self.width

            self.latent_image_embedding = ImagePositionalEmbeddings(
                num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
            )
        elif self.is_input_patches:
            assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size"

            self.height = sample_size
            self.width = sample_size

            self.patch_size = patch_size
            self.pos_embed = PatchEmbed(
                height=sample_size,
                width=sample_size,
                patch_size=patch_size,
                in_channels=in_channels,
                embed_dim=inner_dim,
            )

        # 2. Multi-view transformer blocks. NOTE: `sparse_mv_attention` is not forwarded, and the blocks'
        # `rowwise_attention` argument is left at its default.
        self.transformer_blocks = nn.ModuleList(
            [
                BasicMVTransformerBlock(
                    inner_dim,
                    num_attention_heads,
                    attention_head_dim,
                    dropout=dropout,
                    cross_attention_dim=cross_attention_dim,
                    activation_fn=activation_fn,
                    num_embeds_ada_norm=num_embeds_ada_norm,
                    attention_bias=attention_bias,
                    only_cross_attention=only_cross_attention,
                    upcast_attention=upcast_attention,
                    norm_type=norm_type,
                    norm_elementwise_affine=norm_elementwise_affine,
                    num_views=num_views,
                    cd_attention_last=cd_attention_last,
                    cd_attention_mid=cd_attention_mid,
                    multiview_attention=multiview_attention,
                    mvcd_attention=mvcd_attention
                )
                for d in range(num_layers)
            ]
        )

        # 3. Output layers. For continuous input the projection maps back to `in_channels` (not `out_channels`)
        # because `forward` adds the input back as a residual.
        self.out_channels = in_channels if out_channels is None else out_channels
        if self.is_input_continuous:
            if use_linear_projection:
                self.proj_out = LoRACompatibleLinear(inner_dim, in_channels)
            else:
                self.proj_out = LoRACompatibleConv(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
        elif self.is_input_vectorized:
            self.norm_out = nn.LayerNorm(inner_dim)
            # One logit per class except the masked-pixel class.
            self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
        elif self.is_input_patches:
            self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
            # Produces the (shift, scale) pair used to modulate `norm_out` in `forward`.
            self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
            self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        class_labels: Optional[torch.LongTensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ):
        """
        The [`TransformerMV2DModel`] forward method.

        Args:
            hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
                Input `hidden_states`. For multi-view use the batch presumably holds `num_views` consecutive views
                per sample - TODO confirm against the attention processors' `(b v)` grouping.
            encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
                Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
                self-attention.
            timestep ( `torch.LongTensor`, *optional*):
                Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
            class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
                Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
                `AdaLayerZeroNorm`.
            cross_attention_kwargs (`dict`, *optional*):
                Extra keyword arguments forwarded to every block (and from there to its attention layers).
            attention_mask (`torch.Tensor`, *optional*):
                Self-attention mask, converted like `encoder_attention_mask` when 2-D. The `BasicMVTransformerBlock`
                layers in this file assert that it is `None`.
            encoder_attention_mask ( `torch.Tensor`, *optional*):
                Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:

                    * Mask `(batch, sequence_length)` True = keep, False = discard.
                    * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.

                If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
                above. This bias will be added to the cross-attention scores.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`TransformerMV2DModelOutput`] instead of a plain tuple.

        Returns:
            If `return_dict` is True, a [`TransformerMV2DModelOutput`] is returned, otherwise a `tuple` where the
            first element is the sample tensor.
        """
        # Convert 2-D keep/discard masks into additive biases of shape (batch, 1, seq).
        if attention_mask is not None and attention_mask.ndim == 2:
            attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
            attention_mask = attention_mask.unsqueeze(1)

        if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
            encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

        # 1. Input: turn the input into a (batch, tokens, channels) sequence.
        if self.is_input_continuous:
            batch, _, height, width = hidden_states.shape
            residual = hidden_states

            hidden_states = self.norm(hidden_states)
            if not self.use_linear_projection:
                hidden_states = self.proj_in(hidden_states)
                inner_dim = hidden_states.shape[1]
                hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
            else:
                # `inner_dim` is read BEFORE proj_in here, so it holds the input channel count. That is also the
                # width proj_out maps back to, which keeps the reshape in step 3 consistent.
                inner_dim = hidden_states.shape[1]
                hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
                hidden_states = self.proj_in(hidden_states)
        elif self.is_input_vectorized:
            hidden_states = self.latent_image_embedding(hidden_states)
        elif self.is_input_patches:
            hidden_states = self.pos_embed(hidden_states)

        # 2. Multi-view transformer blocks.
        for block in self.transformer_blocks:
            hidden_states = block(
                hidden_states,
                attention_mask=attention_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                timestep=timestep,
                cross_attention_kwargs=cross_attention_kwargs,
                class_labels=class_labels,
            )

        # 3. Output: map the sequence back to the input layout.
        if self.is_input_continuous:
            if not self.use_linear_projection:
                hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
                hidden_states = self.proj_out(hidden_states)
            else:
                hidden_states = self.proj_out(hidden_states)
                hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()

            output = hidden_states + residual
        elif self.is_input_vectorized:
            hidden_states = self.norm_out(hidden_states)
            logits = self.out(hidden_states)
            # (batch, pixels, classes) -> (batch, classes, pixels)
            logits = logits.permute(0, 2, 1)

            # Log-probabilities over the classes, computed in float64 and returned as float32.
            output = F.log_softmax(logits.double(), dim=1).float()
        elif self.is_input_patches:
            # Reuse the first block's AdaLayerNormZero embedding (assumes norm_type == "ada_norm_zero" - TODO confirm).
            conditioning = self.transformer_blocks[0].norm1.emb(
                timestep, class_labels, hidden_dtype=hidden_states.dtype
            )
            shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
            hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
            hidden_states = self.proj_out_2(hidden_states)

            # Unpatchify; the patch grid is assumed to be square.
            height = width = int(hidden_states.shape[1] ** 0.5)
            hidden_states = hidden_states.reshape(
                shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
            )
            hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
            output = hidden_states.reshape(
                shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
            )

        if not return_dict:
            return (output,)

        return TransformerMV2DModelOutput(sample=output)
|
|
|
|
|
|
@maybe_allow_in_graph
class BasicMVTransformerBlock(nn.Module):
    r"""
    A basic Transformer block whose self-attention (`attn1`) runs across multiple views.

    Sub-layers, each pre-normalized and added residually, in forward order:
    multi-view self-attention (`attn1`, processor `MVAttnProcessor`), optional cross-domain joint attention
    (`attn_joint_twice`, if `cd_attention_mid`), optional cross-attention (`attn2`), feed-forward (`ff`), and
    optional cross-domain joint attention (`attn_joint`, if `cd_attention_last`).

    Parameters:
        dim (`int`): The number of channels in the input and output.
        num_attention_heads (`int`): The number of heads to use for multi-head attention.
        attention_head_dim (`int`): The number of channels in each head.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
        num_embeds_ada_norm (`int`, *optional*):
            The number of diffusion steps used during training. See `Transformer2DModel`.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Configure if the attentions should contain a bias parameter.
        only_cross_attention (`bool`, *optional*):
            Whether to use only cross-attention layers. In this case `attn1` also attends to
            `encoder_hidden_states`.
        double_self_attention (`bool`, *optional*):
            Whether to use two self-attention layers. In this case no cross attention layers are used.
        upcast_attention (`bool`, *optional*): Forwarded to the attention layers.
        norm_elementwise_affine (`bool`, *optional*, defaults to `True`): Affine parameters of the LayerNorms.
        norm_type (`str`, *optional*, defaults to `"layer_norm"`):
            `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`; the adaptive variants require `num_embeds_ada_norm`.
        final_dropout (`bool`, *optional*, defaults to `False`): Forwarded to `FeedForward`.
        num_views (`int`, *optional*, defaults to 1): Number of views per sample, passed to `attn1`'s processor.
        cd_attention_last (`bool`, *optional*, defaults to `False`): Add joint attention after the feed-forward.
        cd_attention_mid (`bool`, *optional*, defaults to `False`): Add joint attention after `attn1`.
        multiview_attention (`bool`, *optional*, defaults to `True`): Passed to `attn1`'s processor.
        mvcd_attention (`bool`, *optional*, defaults to `False`): Passed to `attn1`'s processor.
        rowwise_attention (`bool`, *optional*, defaults to `True`):
            Stored as `self.rowwise_attention` (ANDed with `multiview_attention`). `attn1` is always built with
            `MVAttnProcessor` regardless of this flag.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        dropout: float = 0.0,
        cross_attention_dim: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        attention_bias: bool = False,
        only_cross_attention: bool = False,
        double_self_attention: bool = False,
        upcast_attention: bool = False,
        norm_elementwise_affine: bool = True,
        norm_type: str = "layer_norm",
        final_dropout: bool = False,
        num_views: int = 1,
        cd_attention_last: bool = False,
        cd_attention_mid: bool = False,
        multiview_attention: bool = True,
        mvcd_attention: bool = False,
        rowwise_attention: bool = True
    ):
        super().__init__()
        self.only_cross_attention = only_cross_attention

        self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
        self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"

        if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
            raise ValueError(
                f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
                f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
            )

        # 1. Multi-view self-attention.
        # NOTE: submodules are created in a fixed order; reordering changes parameter order and RNG-dependent init.
        if self.use_ada_layer_norm:
            self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
        elif self.use_ada_layer_norm_zero:
            self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
        else:
            self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)

        self.multiview_attention = multiview_attention
        self.mvcd_attention = mvcd_attention
        self.rowwise_attention = multiview_attention and rowwise_attention

        # NOTE(review): printed unconditionally for every block constructed, whatever `rowwise_attention` is.
        print('INFO: using row wise attention...')

        self.attn1 = CustomAttention(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            dropout=dropout,
            bias=attention_bias,
            cross_attention_dim=cross_attention_dim if only_cross_attention else None,
            upcast_attention=upcast_attention,
            processor=MVAttnProcessor()
        )

        # 2. Cross-attention (or a second self-attention when `double_self_attention`).
        if cross_attention_dim is not None or double_self_attention:
            # AdaLayerNormZero is only used for norm1; the other norms fall back to AdaLayerNorm / LayerNorm.
            self.norm2 = (
                AdaLayerNorm(dim, num_embeds_ada_norm)
                if self.use_ada_layer_norm
                else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
            )
            self.attn2 = Attention(
                query_dim=dim,
                cross_attention_dim=cross_attention_dim if not double_self_attention else None,
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                upcast_attention=upcast_attention,
            )
        else:
            self.norm2 = None
            self.attn2 = None

        # 3. Feed-forward.
        self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
        self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)

        # Feed-forward chunking is disabled until `set_chunk_feed_forward` is called.
        self._chunk_size = None
        self._chunk_dim = 0

        self.num_views = num_views

        self.cd_attention_last = cd_attention_last

        if self.cd_attention_last:
            # Joint attention after the feed-forward. The output projection is zero-initialised, so the sub-layer
            # initially contributes only its output bias to the residual stream.
            self.attn_joint = CustomJointAttention(
                query_dim=dim,
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                cross_attention_dim=cross_attention_dim if only_cross_attention else None,
                upcast_attention=upcast_attention,
                processor=JointAttnProcessor()
            )
            nn.init.zeros_(self.attn_joint.to_out[0].weight.data)
            self.norm_joint = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)

        self.cd_attention_mid = cd_attention_mid

        if self.cd_attention_mid:
            print("joint twice")
            # Joint attention right after `attn1`; zero-initialised output weight, as above.
            self.attn_joint_twice = CustomJointAttention(
                query_dim=dim,
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                cross_attention_dim=cross_attention_dim if only_cross_attention else None,
                upcast_attention=upcast_attention,
                processor=JointAttnProcessor()
            )
            nn.init.zeros_(self.attn_joint_twice.to_out[0].weight.data)
            self.norm_joint_twice = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)

    def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
        """Enable (or, with `chunk_size=None`, disable) chunked feed-forward along dimension `dim`."""
        self._chunk_size = chunk_size
        self._chunk_dim = dim

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        class_labels: Optional[torch.LongTensor] = None,
    ):
        """
        Run the block on `hidden_states` (assumed `(batch * views, tokens, dim)` - TODO confirm) and return the result.

        `attention_mask` must be `None`. `encoder_attention_mask` is passed to `attn2`. `cross_attention_kwargs`
        is forwarded to both `attn1` and `attn2`.
        """
        # Self-attention masks are not supported: the multi-view processors do not regroup them to match the
        # row-wise batching (presumably the reason for this check).
        assert attention_mask is None

        # 1. Multi-view self-attention.
        if self.use_ada_layer_norm:
            norm_hidden_states = self.norm1(hidden_states, timestep)
        elif self.use_ada_layer_norm_zero:
            norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
                hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
            )
        else:
            norm_hidden_states = self.norm1(hidden_states)

        cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}

        attn_output = self.attn1(
            norm_hidden_states,
            encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
            attention_mask=attention_mask,
            multiview_attention=self.multiview_attention,
            mvcd_attention=self.mvcd_attention,
            num_views=self.num_views,
            **cross_attention_kwargs,
        )

        if self.use_ada_layer_norm_zero:
            attn_output = gate_msa.unsqueeze(1) * attn_output
        hidden_states = attn_output + hidden_states

        # 1b. Optional cross-domain joint attention right after the self-attention.
        if self.cd_attention_mid:
            norm_hidden_states = (
                self.norm_joint_twice(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_joint_twice(hidden_states)
            )
            hidden_states = self.attn_joint_twice(norm_hidden_states) + hidden_states

        # 2. Cross-attention.
        if self.attn2 is not None:
            norm_hidden_states = (
                self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
            )

            attn_output = self.attn2(
                norm_hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                **cross_attention_kwargs,
            )
            hidden_states = attn_output + hidden_states

        # 3. Feed-forward.
        norm_hidden_states = self.norm3(hidden_states)

        if self.use_ada_layer_norm_zero:
            norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]

        if self._chunk_size is not None:
            # Run the feed-forward in equal slices along `_chunk_dim` to reduce peak memory.
            if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
                raise ValueError(
                    f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
                )

            num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
            ff_output = torch.cat(
                [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)],
                dim=self._chunk_dim,
            )
        else:
            ff_output = self.ff(norm_hidden_states)

        if self.use_ada_layer_norm_zero:
            ff_output = gate_mlp.unsqueeze(1) * ff_output

        hidden_states = ff_output + hidden_states

        # 4. Optional cross-domain joint attention at the end of the block.
        if self.cd_attention_last:
            norm_hidden_states = (
                self.norm_joint(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_joint(hidden_states)
            )
            hidden_states = self.attn_joint(norm_hidden_states) + hidden_states

        return hidden_states
|
|
|
|
|
|
class CustomAttention(Attention):
    """`Attention` layer for the multi-view self-attention (`attn1` of `BasicMVTransformerBlock`).

    Toggling xFormers switches between `XFormersMVAttnProcessor` and `MVAttnProcessor` instead of diffusers' stock
    processors. Both accept the multi-view keyword arguments (`num_views`, `multiview_attention`, ...).
    """

    def set_use_memory_efficient_attention_xformers(
        self, use_memory_efficient_attention_xformers: bool, *args, **kwargs
    ):
        """Install the xFormers or the plain multi-view attention processor.

        Previously the flag was ignored, so disabling xFormers still installed the xFormers processor.

        Args:
            use_memory_efficient_attention_xformers (`bool`):
                `True` installs `XFormersMVAttnProcessor`; `False` restores `MVAttnProcessor`.
            *args, **kwargs:
                Accepted for signature compatibility with `Attention.set_use_memory_efficient_attention_xformers`
                (e.g. `attention_op`); ignored.

        Raises:
            ModuleNotFoundError: If xFormers is requested but could not be imported.
        """
        if use_memory_efficient_attention_xformers:
            # Fail here rather than with an AttributeError on `None.ops` at the first forward pass.
            if xformers is None:
                raise ModuleNotFoundError(
                    "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
                    " xformers"
                )
            processor = XFormersMVAttnProcessor()
        else:
            processor = MVAttnProcessor()
        self.set_processor(processor)
|
|
|
|
|
|
|
|
class CustomJointAttention(Attention):
    """`Attention` layer for the cross-domain joint attention (`attn_joint*` of `BasicMVTransformerBlock`).

    Toggling xFormers switches between `XFormersJointAttnProcessor` and `JointAttnProcessor` instead of
    diffusers' stock processors.
    """

    def set_use_memory_efficient_attention_xformers(
        self, use_memory_efficient_attention_xformers: bool, *args, **kwargs
    ):
        """Install the xFormers or the plain joint attention processor.

        Previously the flag was ignored, so disabling xFormers still installed the xFormers processor.

        Args:
            use_memory_efficient_attention_xformers (`bool`):
                `True` installs `XFormersJointAttnProcessor`; `False` restores `JointAttnProcessor`.
            *args, **kwargs:
                Accepted for signature compatibility with `Attention.set_use_memory_efficient_attention_xformers`
                (e.g. `attention_op`); ignored.

        Raises:
            ModuleNotFoundError: If xFormers is requested but could not be imported.
        """
        if use_memory_efficient_attention_xformers:
            # Fail here rather than with an AttributeError on `None.ops` at the first forward pass.
            if xformers is None:
                raise ModuleNotFoundError(
                    "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
                    " xformers"
                )
            processor = XFormersJointAttnProcessor()
        else:
            processor = JointAttnProcessor()
        self.set_processor(processor)
|
|
|
|
|
|
class MVAttnProcessor:
    r"""
    Row-wise multi-view attention processor (plain PyTorch path).

    Assumes the batch is laid out as `(batch, views)` and that each view's tokens form a square grid, whose side is
    recovered as `int(sqrt(sequence_length))`. Query, key and value are regrouped so that attention runs over the
    same image row across all `num_views` views.

    With `mvcd_attention=True`, the batch is further assumed to stack two domains along dim 0 (first half / second
    half). The same rows of all views in the *other* domain are then appended to the attended keys and values.
    This matches `XFormersMVAttnProcessor`, so both processors accept the same keyword arguments.
    """

    def __call__(
        self,
        attn: "Attention",
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        num_views=1,
        multiview_attention=True,
        mvcd_attention=False,
    ):
        """
        Args:
            attn: The calling `Attention` module (provides projections, norms and head reshaping).
            hidden_states: Query-side tokens, `(batch * views, tokens, C)` or a 4-D `(N, C, H, W)` map.
            encoder_hidden_states: Key/value source; defaults to `hidden_states`.
            attention_mask: Passed through `attn.prepare_attention_mask`; it is not regrouped row-wise.
            temb: Only used by `attn.spatial_norm`, if present.
            num_views: Number of consecutive batch entries forming one multi-view sample.
            multiview_attention: Accepted for interface compatibility; not used.
            mvcd_attention: Also attend to the other domain's matching rows (see class docstring).

        Returns:
            Tensor with the same layout as the (possibly 4-D) input.
        """
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        # Side length of the (assumed square) per-view token grid.
        height = int(math.sqrt(sequence_length))
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key_raw = attn.to_k(encoder_hidden_states)
        value_raw = attn.to_v(encoder_hidden_states)

        # (batch*views, rows*cols, C) -> (batch*rows, views*cols, C): every row attends across all views.
        key = rearrange(key_raw, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height)
        value = rearrange(value_raw, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height)
        query = rearrange(query, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height)

        if mvcd_attention:
            # Swap the two domain halves, then regroup exactly like `key`/`value` so the batch dimensions line up.
            # Each row then also sees the same row of all views in the other domain.
            key_0, key_1 = torch.chunk(key_raw, dim=0, chunks=2)
            value_0, value_1 = torch.chunk(value_raw, dim=0, chunks=2)
            key_cross = rearrange(
                torch.cat([key_1, key_0], dim=0), "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height
            )
            value_cross = rearrange(
                torch.cat([value_1, value_0], dim=0), "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height
            )
            key = torch.cat([key, key_cross], dim=1)
            value = torch.cat([value, value_cross], dim=1)

        query = attn.head_to_batch_dim(query).contiguous()
        key = attn.head_to_batch_dim(key).contiguous()
        value = attn.head_to_batch_dim(value).contiguous()

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # Linear projection, then dropout.
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)
        # Undo the row-wise regrouping.
        hidden_states = rearrange(hidden_states, "(b h) (v w) c -> (b v) (h w) c", v=num_views, h=height)
        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
|
|
|
|
|
|
class XFormersMVAttnProcessor:
    r"""
    Row-wise multi-view attention processor using `xformers.ops.memory_efficient_attention`.

    Assumes the batch is laid out as `(batch, views)` and that each view's tokens form a square grid, whose side is
    recovered as `int(sqrt(sequence_length))`. Attention runs over the same image row across all `num_views` views.

    With `mvcd_attention=True`, the batch is further assumed to stack two domains along dim 0 (first half / second
    half). The same rows of all views in the other domain are appended to the attended keys and values.
    """

    def __call__(
        self,
        attn: "Attention",
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        num_views=1,
        multiview_attention=True,
        mvcd_attention=False,
    ):
        """
        Args:
            attn: The calling `Attention` module (provides projections, norms and head reshaping).
            hidden_states: Query-side tokens, `(batch * views, tokens, C)` or a 4-D `(N, C, H, W)` map.
            encoder_hidden_states: Key/value source; defaults to `hidden_states`.
            attention_mask: Prepared and expanded to the query length; it is not regrouped row-wise.
            temb: Only used by `attn.spatial_norm`, if present.
            num_views: Number of consecutive batch entries forming one multi-view sample.
            multiview_attention: Accepted for interface compatibility; not used.
            mvcd_attention: Also attend to the other domain's matching rows (see class docstring).

        Returns:
            Tensor with the same layout as the (possibly 4-D) input.
        """
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        # Side length of the (assumed square) per-view token grid.
        height = int(math.sqrt(sequence_length))
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attention_mask is not None:
            # xFormers expects the bias broadcast explicitly over the query tokens.
            _, query_tokens, _ = hidden_states.shape
            attention_mask = attention_mask.expand(-1, query_tokens, -1)

        if attn.group_norm is not None:
            print('Warning: using group norm, pay attention to use it in row-wise attention')
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key_raw = attn.to_k(encoder_hidden_states)
        value_raw = attn.to_v(encoder_hidden_states)

        # (batch*views, rows*cols, C) -> (batch*rows, views*cols, C): every row attends across all views.
        key = rearrange(key_raw, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height)
        value = rearrange(value_raw, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height)
        query = rearrange(query, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height)

        if mvcd_attention:
            # Swap the two domain halves, then regroup like `key`/`value`. Without this regrouping the batch
            # dimensions ((b v) vs (b h)) only match when height == num_views, and the concatenation fails otherwise.
            key_0, key_1 = torch.chunk(key_raw, dim=0, chunks=2)
            value_0, value_1 = torch.chunk(value_raw, dim=0, chunks=2)
            key_cross = rearrange(
                torch.cat([key_1, key_0], dim=0), "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height
            )
            value_cross = rearrange(
                torch.cat([value_1, value_0], dim=0), "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height
            )
            key = torch.cat([key, key_cross], dim=1)
            value = torch.cat([value, value_cross], dim=1)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # Linear projection, then dropout.
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)

        # Undo the row-wise regrouping.
        hidden_states = rearrange(hidden_states, "(b h) (v w) c -> (b v) (h w) c", v=num_views, h=height)
        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
|
|
|
|
|
|
class XFormersJointAttnProcessor:
    r"""
    Cross-domain ("joint") attention processor using `xformers.ops.memory_efficient_attention`.

    Assumes the batch stacks `num_tasks` domains along dim 0 in equal, aligned chunks, so that entry `j` of every
    chunk belongs together. Every query token attends to the tokens of its aligned entries in all domains. This is
    the xFormers counterpart of `JointAttnProcessor`.
    """

    @staticmethod
    def _join_tasks(tensor, num_tasks):
        """Concatenate the `num_tasks` batch chunks of `tensor` along the token axis, then tile them per task.

        Shape `(num_tasks * n, L, C)` -> `(num_tasks * n, num_tasks * L, C)`. Row `t * n + j` holds the joined
        tokens of entry `j`. For `num_tasks == 2` this equals the former hard-coded implementation.

        Raises:
            ValueError: If `num_tasks < 1` or the batch size is not divisible by `num_tasks`.
        """
        if num_tasks < 1 or tensor.shape[0] % num_tasks != 0:
            raise ValueError(
                f"Joint attention needs a batch size divisible by `num_tasks`; got batch size {tensor.shape[0]}"
                f" and num_tasks={num_tasks}."
            )
        joined = torch.cat(torch.chunk(tensor, num_tasks, dim=0), dim=1)
        return torch.cat([joined] * num_tasks, dim=0)

    def __call__(
        self,
        attn: "Attention",
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        num_tasks=2
    ):
        """
        Args:
            attn: The calling `Attention` module (provides projections, norms and head reshaping).
            hidden_states: Tokens `(num_tasks * n, tokens, C)` or a 4-D `(N, C, H, W)` map.
            encoder_hidden_states: Key/value source; defaults to `hidden_states`.
            attention_mask: Prepared for the un-joined key length. It is not adapted to the joined keys, and the
                callers in this file pass none.
            temb: Only used by `attn.spatial_norm`, if present.
            num_tasks: Number of domains stacked along the batch (any positive divisor of the batch size).

        Returns:
            Tensor with the same layout as the (possibly 4-D) input.
        """
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attention_mask is not None:
            # xFormers expects the bias broadcast explicitly over the query tokens.
            _, query_tokens, _ = hidden_states.shape
            attention_mask = attention_mask.expand(-1, query_tokens, -1)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        # Keys/values span the aligned entries of every domain.
        key = self._join_tasks(attn.to_k(encoder_hidden_states), num_tasks)
        value = self._join_tasks(attn.to_v(encoder_hidden_states), num_tasks)

        query = attn.head_to_batch_dim(query).contiguous()
        key = attn.head_to_batch_dim(key).contiguous()
        value = attn.head_to_batch_dim(value).contiguous()

        hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # Linear projection, then dropout.
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
|
|
|
|
|
|
class JointAttnProcessor:
    r"""
    Cross-domain ("joint") attention processor (plain PyTorch path).

    Assumes the batch stacks `num_tasks` domains along dim 0 in equal, aligned chunks, so that entry `j` of every
    chunk belongs together. Every query token attends to the tokens of its aligned entries in all domains.
    """

    @staticmethod
    def _join_tasks(tensor, num_tasks):
        """Concatenate the `num_tasks` batch chunks of `tensor` along the token axis, then tile them per task.

        Shape `(num_tasks * n, L, C)` -> `(num_tasks * n, num_tasks * L, C)`. Row `t * n + j` holds the joined
        tokens of entry `j`. For `num_tasks == 2` this equals the former hard-coded implementation.

        Raises:
            ValueError: If `num_tasks < 1` or the batch size is not divisible by `num_tasks`.
        """
        if num_tasks < 1 or tensor.shape[0] % num_tasks != 0:
            raise ValueError(
                f"Joint attention needs a batch size divisible by `num_tasks`; got batch size {tensor.shape[0]}"
                f" and num_tasks={num_tasks}."
            )
        joined = torch.cat(torch.chunk(tensor, num_tasks, dim=0), dim=1)
        return torch.cat([joined] * num_tasks, dim=0)

    def __call__(
        self,
        attn: "Attention",
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        num_tasks=2
    ):
        """
        Args:
            attn: The calling `Attention` module (provides projections, norms and head reshaping).
            hidden_states: Tokens `(num_tasks * n, tokens, C)` or a 4-D `(N, C, H, W)` map.
            encoder_hidden_states: Key/value source; defaults to `hidden_states`.
            attention_mask: Prepared for the un-joined key length. It is not adapted to the joined keys, and the
                callers in this file pass none.
            temb: Only used by `attn.spatial_norm`, if present.
            num_tasks: Number of domains stacked along the batch (any positive divisor of the batch size).

        Returns:
            Tensor with the same layout as the (possibly 4-D) input.
        """
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        # Keys/values span the aligned entries of every domain.
        key = self._join_tasks(attn.to_k(encoder_hidden_states), num_tasks)
        value = self._join_tasks(attn.to_v(encoder_hidden_states), num_tasks)

        query = attn.head_to_batch_dim(query).contiguous()
        key = attn.head_to_batch_dim(key).contiguous()
        value = attn.head_to_batch_dim(value).contiguous()

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # Linear projection, then dropout.
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states