# modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py import torch import torch.nn as nn import torch.nn.functional as F class AdaLayerNorm(nn.Module): def __init__(self, embedding_dim: int, time_embedding_dim: int = None): super().__init__() if time_embedding_dim is None: time_embedding_dim = embedding_dim self.silu = nn.SiLU() self.linear = nn.Linear(time_embedding_dim, 2 * embedding_dim, bias=True) nn.init.zeros_(self.linear.weight) nn.init.zeros_(self.linear.bias) self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6) def forward( self, x: torch.Tensor, timestep_embedding: torch.Tensor ): emb = self.linear(self.silu(timestep_embedding)) shift, scale = emb.view(len(x), 1, -1).chunk(2, dim=-1) x = self.norm(x) * (1 + scale) + shift return x class AttnProcessor(nn.Module): r""" Default processor for performing attention-related computations. """ def __init__( self, hidden_size=None, cross_attention_dim=None, ): super().__init__() def __call__( self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None, ): residual = hidden_states if attn.spatial_norm is not None: hidden_states = attn.spatial_norm(hidden_states, temb) input_ndim = hidden_states.ndim if input_ndim == 4: batch_size, channel, height, width = hidden_states.shape hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) batch_size, sequence_length, _ = ( hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape ) attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) query = attn.to_q(hidden_states) if encoder_hidden_states is None: encoder_hidden_states = hidden_states elif attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) key = 
attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) query = attn.head_to_batch_dim(query) key = attn.head_to_batch_dim(key) value = attn.head_to_batch_dim(value) attention_probs = attn.get_attention_scores(query, key, attention_mask) hidden_states = torch.bmm(attention_probs, value) hidden_states = attn.batch_to_head_dim(hidden_states) # linear proj hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) if input_ndim == 4: hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) if attn.residual_connection: hidden_states = hidden_states + residual hidden_states = hidden_states / attn.rescale_output_factor return hidden_states class IPAttnProcessor(nn.Module): r""" Attention processor for IP-Adapater. Args: hidden_size (`int`): The hidden size of the attention layer. cross_attention_dim (`int`): The number of channels in the `encoder_hidden_states`. scale (`float`, defaults to 1.0): the weight scale of image prompt. num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16): The context length of the image features. 
""" def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4): super().__init__() self.hidden_size = hidden_size self.cross_attention_dim = cross_attention_dim self.scale = scale self.num_tokens = num_tokens self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) def __call__( self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None, ): residual = hidden_states if attn.spatial_norm is not None: hidden_states = attn.spatial_norm(hidden_states, temb) input_ndim = hidden_states.ndim if input_ndim == 4: batch_size, channel, height, width = hidden_states.shape hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) batch_size, sequence_length, _ = ( hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape ) attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) query = attn.to_q(hidden_states) if encoder_hidden_states is None: encoder_hidden_states = hidden_states else: # get encoder_hidden_states, ip_hidden_states end_pos = encoder_hidden_states.shape[1] - self.num_tokens encoder_hidden_states, ip_hidden_states = ( encoder_hidden_states[:, :end_pos, :], encoder_hidden_states[:, end_pos:, :], ) if attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) query = attn.head_to_batch_dim(query) key = attn.head_to_batch_dim(key) value = attn.head_to_batch_dim(value) attention_probs = attn.get_attention_scores(query, key, attention_mask) hidden_states = torch.bmm(attention_probs, value) hidden_states = attn.batch_to_head_dim(hidden_states) # for ip-adapter ip_key = 
self.to_k_ip(ip_hidden_states) ip_value = self.to_v_ip(ip_hidden_states) ip_key = attn.head_to_batch_dim(ip_key) ip_value = attn.head_to_batch_dim(ip_value) ip_attention_probs = attn.get_attention_scores(query, ip_key, None) ip_hidden_states = torch.bmm(ip_attention_probs, ip_value) ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states) hidden_states = hidden_states + self.scale * ip_hidden_states # linear proj hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) if input_ndim == 4: hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) if attn.residual_connection: hidden_states = hidden_states + residual hidden_states = hidden_states / attn.rescale_output_factor return hidden_states class TA_IPAttnProcessor(nn.Module): r""" Attention processor for IP-Adapater. Args: hidden_size (`int`): The hidden size of the attention layer. cross_attention_dim (`int`): The number of channels in the `encoder_hidden_states`. scale (`float`, defaults to 1.0): the weight scale of image prompt. num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16): The context length of the image features. """ def __init__(self, hidden_size, cross_attention_dim=None, time_embedding_dim: int = None, scale=1.0, num_tokens=4): super().__init__() self.hidden_size = hidden_size self.cross_attention_dim = cross_attention_dim self.scale = scale self.num_tokens = num_tokens self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) self.ln_k_ip = AdaLayerNorm(hidden_size, time_embedding_dim) self.ln_v_ip = AdaLayerNorm(hidden_size, time_embedding_dim) def __call__( self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None, ): assert temb is not None, "Timestep embedding is needed for a time-aware attention processor." 
residual = hidden_states if attn.spatial_norm is not None: hidden_states = attn.spatial_norm(hidden_states, temb) input_ndim = hidden_states.ndim if input_ndim == 4: batch_size, channel, height, width = hidden_states.shape hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) batch_size, sequence_length, _ = ( hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape ) attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) query = attn.to_q(hidden_states) if encoder_hidden_states is None: encoder_hidden_states = hidden_states else: # get encoder_hidden_states, ip_hidden_states end_pos = encoder_hidden_states.shape[1] - self.num_tokens encoder_hidden_states, ip_hidden_states = ( encoder_hidden_states[:, :end_pos, :], encoder_hidden_states[:, end_pos:, :], ) if attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) query = attn.head_to_batch_dim(query) key = attn.head_to_batch_dim(key) value = attn.head_to_batch_dim(value) attention_probs = attn.get_attention_scores(query, key, attention_mask) hidden_states = torch.bmm(attention_probs, value) hidden_states = attn.batch_to_head_dim(hidden_states) # for ip-adapter ip_key = self.to_k_ip(ip_hidden_states) ip_value = self.to_v_ip(ip_hidden_states) # time-dependent adaLN ip_key = self.ln_k_ip(ip_key, temb) ip_value = self.ln_v_ip(ip_value, temb) ip_key = attn.head_to_batch_dim(ip_key) ip_value = attn.head_to_batch_dim(ip_value) ip_attention_probs = attn.get_attention_scores(query, ip_key, None) ip_hidden_states = torch.bmm(ip_attention_probs, ip_value) ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states) hidden_states = hidden_states + self.scale * ip_hidden_states # linear proj hidden_states 
= attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) if input_ndim == 4: hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) if attn.residual_connection: hidden_states = hidden_states + residual hidden_states = hidden_states / attn.rescale_output_factor return hidden_states class AttnProcessor2_0(torch.nn.Module): r""" Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). """ def __init__( self, hidden_size=None, cross_attention_dim=None, ): super().__init__() if not hasattr(F, "scaled_dot_product_attention"): raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") def __call__( self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, external_kv=None, temb=None, ): residual = hidden_states if attn.spatial_norm is not None: hidden_states = attn.spatial_norm(hidden_states, temb) input_ndim = hidden_states.ndim if input_ndim == 4: batch_size, channel, height, width = hidden_states.shape hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) batch_size, sequence_length, _ = ( hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape ) if attention_mask is not None: attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) # scaled_dot_product_attention expects attention_mask shape to be # (batch, heads, source_length, target_length) attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) query = attn.to_q(hidden_states) if encoder_hidden_states is None: encoder_hidden_states = hidden_states elif attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) key = attn.to_k(encoder_hidden_states) value = 
attn.to_v(encoder_hidden_states) if external_kv: key = torch.cat([key, external_kv.k], axis=1) value = torch.cat([value, external_kv.v], axis=1) inner_dim = key.shape[-1] head_dim = inner_dim // attn.heads query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) # the output of sdp = (batch, num_heads, seq_len, head_dim) # TODO: add support for attn.scale when we move to Torch 2.1 hidden_states = F.scaled_dot_product_attention( query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False ) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype) # linear proj hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) if input_ndim == 4: hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) if attn.residual_connection: hidden_states = hidden_states + residual hidden_states = hidden_states / attn.rescale_output_factor return hidden_states class split_AttnProcessor2_0(torch.nn.Module): r""" Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). """ def __init__( self, hidden_size=None, cross_attention_dim=None, time_embedding_dim=None, ): super().__init__() if not hasattr(F, "scaled_dot_product_attention"): raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") def __call__( self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, external_kv=None, temb=None, cat_dim=-2, original_shape=None, ): residual = hidden_states if attn.spatial_norm is not None: hidden_states = attn.spatial_norm(hidden_states, temb) input_ndim = hidden_states.ndim if input_ndim == 4: # 2d to sequence. 
height, width = hidden_states.shape[-2:] if cat_dim==-2 or cat_dim==2: hidden_states_0 = hidden_states[:, :, :height//2, :] hidden_states_1 = hidden_states[:, :, -(height//2):, :] elif cat_dim==-1 or cat_dim==3: hidden_states_0 = hidden_states[:, :, :, :width//2] hidden_states_1 = hidden_states[:, :, :, -(width//2):] batch_size, channel, height, width = hidden_states_0.shape hidden_states_0 = hidden_states_0.view(batch_size, channel, height * width).transpose(1, 2) hidden_states_1 = hidden_states_1.view(batch_size, channel, height * width).transpose(1, 2) else: # directly split sqeuence according to concat dim. single_dim = original_shape[2] if cat_dim==-2 or cat_dim==2 else original_shape[1] hidden_states_0 = hidden_states[:, :single_dim*single_dim,:] hidden_states_1 = hidden_states[:, single_dim*(single_dim+1):,:] hidden_states = torch.cat([hidden_states_0, hidden_states_1], dim=1) batch_size, sequence_length, _ = ( hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape ) if attention_mask is not None: attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) # scaled_dot_product_attention expects attention_mask shape to be # (batch, heads, source_length, target_length) attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) query = attn.to_q(hidden_states) key = attn.to_k(hidden_states) value = attn.to_v(hidden_states) if external_kv: key = torch.cat([key, external_kv.k], dim=1) value = torch.cat([value, external_kv.v], dim=1) inner_dim = key.shape[-1] head_dim = inner_dim // attn.heads query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) # the output of sdp = (batch, num_heads, seq_len, head_dim) 
# TODO: add support for attn.scale when we move to Torch 2.1 hidden_states = F.scaled_dot_product_attention( query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False ) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype) # linear proj hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) # spatially split. hidden_states_0, hidden_states_1 = hidden_states.chunk(2, dim=1) if input_ndim == 4: hidden_states_0 = hidden_states_0.transpose(-1, -2).reshape(batch_size, channel, height, width) hidden_states_1 = hidden_states_1.transpose(-1, -2).reshape(batch_size, channel, height, width) if cat_dim==-2 or cat_dim==2: hidden_states_pad = torch.zeros(batch_size, channel, 1, width) elif cat_dim==-1 or cat_dim==3: hidden_states_pad = torch.zeros(batch_size, channel, height, 1) hidden_states_pad = hidden_states_pad.to(hidden_states_0.device, dtype=hidden_states_0.dtype) hidden_states = torch.cat([hidden_states_0, hidden_states_pad, hidden_states_1], dim=cat_dim) assert hidden_states.shape == residual.shape, f"{hidden_states.shape} != {residual.shape}" else: batch_size, sequence_length, inner_dim = hidden_states.shape hidden_states_pad = torch.zeros(batch_size, single_dim, inner_dim) hidden_states_pad = hidden_states_pad.to(hidden_states_0.device, dtype=hidden_states_0.dtype) hidden_states = torch.cat([hidden_states_0, hidden_states_pad, hidden_states_1], dim=1) assert hidden_states.shape == residual.shape, f"{hidden_states.shape} != {residual.shape}" if attn.residual_connection: hidden_states = hidden_states + residual hidden_states = hidden_states / attn.rescale_output_factor return hidden_states class sep_split_AttnProcessor2_0(torch.nn.Module): r""" Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). 
""" def __init__( self, hidden_size=None, cross_attention_dim=None, time_embedding_dim=None, ): super().__init__() if not hasattr(F, "scaled_dot_product_attention"): raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") self.ln_k_ref = AdaLayerNorm(hidden_size, time_embedding_dim) self.ln_v_ref = AdaLayerNorm(hidden_size, time_embedding_dim) # self.hidden_size = hidden_size # self.cross_attention_dim = cross_attention_dim # self.scale = scale # self.num_tokens = num_tokens # self.to_q_ref = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) # self.to_k_ref = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) # self.to_v_ref = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) def __call__( self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, external_kv=None, temb=None, cat_dim=-2, original_shape=None, ref_scale=1.0, ): residual = hidden_states if attn.spatial_norm is not None: hidden_states = attn.spatial_norm(hidden_states, temb) input_ndim = hidden_states.ndim if input_ndim == 4: # 2d to sequence. height, width = hidden_states.shape[-2:] if cat_dim==-2 or cat_dim==2: hidden_states_0 = hidden_states[:, :, :height//2, :] hidden_states_1 = hidden_states[:, :, -(height//2):, :] elif cat_dim==-1 or cat_dim==3: hidden_states_0 = hidden_states[:, :, :, :width//2] hidden_states_1 = hidden_states[:, :, :, -(width//2):] batch_size, channel, height, width = hidden_states_0.shape hidden_states_0 = hidden_states_0.view(batch_size, channel, height * width).transpose(1, 2) hidden_states_1 = hidden_states_1.view(batch_size, channel, height * width).transpose(1, 2) else: # directly split sqeuence according to concat dim. 
single_dim = original_shape[2] if cat_dim==-2 or cat_dim==2 else original_shape[1] hidden_states_0 = hidden_states[:, :single_dim*single_dim,:] hidden_states_1 = hidden_states[:, single_dim*(single_dim+1):,:] batch_size, sequence_length, _ = ( hidden_states_0.shape if encoder_hidden_states is None else encoder_hidden_states.shape ) if attention_mask is not None: attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) # scaled_dot_product_attention expects attention_mask shape to be # (batch, heads, source_length, target_length) attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) if attn.group_norm is not None: hidden_states_0 = attn.group_norm(hidden_states_0.transpose(1, 2)).transpose(1, 2) hidden_states_1 = attn.group_norm(hidden_states_1.transpose(1, 2)).transpose(1, 2) query_0 = attn.to_q(hidden_states_0) query_1 = attn.to_q(hidden_states_1) key_0 = attn.to_k(hidden_states_0) key_1 = attn.to_k(hidden_states_1) value_0 = attn.to_v(hidden_states_0) value_1 = attn.to_v(hidden_states_1) # time-dependent adaLN key_1 = self.ln_k_ref(key_1, temb) value_1 = self.ln_v_ref(value_1, temb) if external_kv: key_1 = torch.cat([key_1, external_kv.k], dim=1) value_1 = torch.cat([value_1, external_kv.v], dim=1) inner_dim = key_0.shape[-1] head_dim = inner_dim // attn.heads query_0 = query_0.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) query_1 = query_1.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) key_0 = key_0.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) key_1 = key_1.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) value_0 = value_0.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) value_1 = value_1.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) # the output of sdp = (batch, num_heads, seq_len, head_dim) # TODO: add support for attn.scale when we move to Torch 2.1 hidden_states_0 = F.scaled_dot_product_attention( query_0, key_0, 
value_0, attn_mask=attention_mask, dropout_p=0.0, is_causal=False ) hidden_states_1 = F.scaled_dot_product_attention( query_1, key_1, value_1, attn_mask=attention_mask, dropout_p=0.0, is_causal=False ) # cross-attn _hidden_states_0 = F.scaled_dot_product_attention( query_0, key_1, value_1, attn_mask=attention_mask, dropout_p=0.0, is_causal=False ) hidden_states_0 = hidden_states_0 + ref_scale * _hidden_states_0 * 10 # TODO: drop this cross-attn _hidden_states_1 = F.scaled_dot_product_attention( query_1, key_0, value_0, attn_mask=attention_mask, dropout_p=0.0, is_causal=False ) hidden_states_1 = hidden_states_1 + ref_scale * _hidden_states_1 hidden_states_0 = hidden_states_0.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states_1 = hidden_states_1.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states_0 = hidden_states_0.to(query_0.dtype) hidden_states_1 = hidden_states_1.to(query_1.dtype) # linear proj hidden_states_0 = attn.to_out[0](hidden_states_0) hidden_states_1 = attn.to_out[0](hidden_states_1) # dropout hidden_states_0 = attn.to_out[1](hidden_states_0) hidden_states_1 = attn.to_out[1](hidden_states_1) if input_ndim == 4: hidden_states_0 = hidden_states_0.transpose(-1, -2).reshape(batch_size, channel, height, width) hidden_states_1 = hidden_states_1.transpose(-1, -2).reshape(batch_size, channel, height, width) if cat_dim==-2 or cat_dim==2: hidden_states_pad = torch.zeros(batch_size, channel, 1, width) elif cat_dim==-1 or cat_dim==3: hidden_states_pad = torch.zeros(batch_size, channel, height, 1) hidden_states_pad = hidden_states_pad.to(hidden_states_0.device, dtype=hidden_states_0.dtype) hidden_states = torch.cat([hidden_states_0, hidden_states_pad, hidden_states_1], dim=cat_dim) assert hidden_states.shape == residual.shape, f"{hidden_states.shape} != {residual.shape}" else: batch_size, sequence_length, inner_dim = hidden_states.shape hidden_states_pad = torch.zeros(batch_size, single_dim, inner_dim) 
hidden_states_pad = hidden_states_pad.to(hidden_states_0.device, dtype=hidden_states_0.dtype) hidden_states = torch.cat([hidden_states_0, hidden_states_pad, hidden_states_1], dim=1) assert hidden_states.shape == residual.shape, f"{hidden_states.shape} != {residual.shape}" if attn.residual_connection: hidden_states = hidden_states + residual hidden_states = hidden_states / attn.rescale_output_factor return hidden_states class AdditiveKV_AttnProcessor2_0(torch.nn.Module): r""" Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). """ def __init__( self, hidden_size: int = None, cross_attention_dim: int = None, time_embedding_dim: int = None, additive_scale: float = 1.0, ): super().__init__() if not hasattr(F, "scaled_dot_product_attention"): raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") self.additive_scale = additive_scale def __call__( self, attn, hidden_states, encoder_hidden_states=None, external_kv=None, attention_mask=None, temb=None, ): assert temb is not None, "Timestep embedding is needed for a time-aware attention processor." 
residual = hidden_states if attn.spatial_norm is not None: hidden_states = attn.spatial_norm(hidden_states, temb) input_ndim = hidden_states.ndim if input_ndim == 4: batch_size, channel, height, width = hidden_states.shape hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) batch_size, sequence_length, _ = ( hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape ) if attention_mask is not None: attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) # scaled_dot_product_attention expects attention_mask shape to be # (batch, heads, source_length, target_length) attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) query = attn.to_q(hidden_states) if encoder_hidden_states is None: encoder_hidden_states = hidden_states elif attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) inner_dim = key.shape[-1] head_dim = inner_dim // attn.heads query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) # the output of sdp = (batch, num_heads, seq_len, head_dim) # TODO: add support for attn.scale when we move to Torch 2.1 hidden_states = F.scaled_dot_product_attention( query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False ) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) if external_kv: key = external_kv.k value = external_kv.v key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) external_attn_output = 
F.scaled_dot_product_attention( query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False ) external_attn_output = external_attn_output.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states + self.additive_scale * external_attn_output hidden_states = hidden_states.to(query.dtype) # linear proj hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) if input_ndim == 4: hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) if attn.residual_connection: hidden_states = hidden_states + residual hidden_states = hidden_states / attn.rescale_output_factor return hidden_states class TA_AdditiveKV_AttnProcessor2_0(torch.nn.Module): r""" Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). """ def __init__( self, hidden_size: int = None, cross_attention_dim: int = None, time_embedding_dim: int = None, additive_scale: float = 1.0, ): super().__init__() if not hasattr(F, "scaled_dot_product_attention"): raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") self.ln_k = AdaLayerNorm(hidden_size, time_embedding_dim) self.ln_v = AdaLayerNorm(hidden_size, time_embedding_dim) self.additive_scale = additive_scale def __call__( self, attn, hidden_states, encoder_hidden_states=None, external_kv=None, attention_mask=None, temb=None, ): assert temb is not None, "Timestep embedding is needed for a time-aware attention processor." 
residual = hidden_states if attn.spatial_norm is not None: hidden_states = attn.spatial_norm(hidden_states, temb) input_ndim = hidden_states.ndim if input_ndim == 4: batch_size, channel, height, width = hidden_states.shape hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) batch_size, sequence_length, _ = ( hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape ) if attention_mask is not None: attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) # scaled_dot_product_attention expects attention_mask shape to be # (batch, heads, source_length, target_length) attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) query = attn.to_q(hidden_states) if encoder_hidden_states is None: encoder_hidden_states = hidden_states elif attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) inner_dim = key.shape[-1] head_dim = inner_dim // attn.heads query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) # the output of sdp = (batch, num_heads, seq_len, head_dim) # TODO: add support for attn.scale when we move to Torch 2.1 hidden_states = F.scaled_dot_product_attention( query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False ) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) if external_kv: key = external_kv.k value = external_kv.v # time-dependent adaLN key = self.ln_k(key, temb) value = self.ln_v(value, temb) key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) value = 
value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) external_attn_output = F.scaled_dot_product_attention( query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False ) external_attn_output = external_attn_output.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states + self.additive_scale * external_attn_output hidden_states = hidden_states.to(query.dtype) # linear proj hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) if input_ndim == 4: hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) if attn.residual_connection: hidden_states = hidden_states + residual hidden_states = hidden_states / attn.rescale_output_factor return hidden_states class IPAttnProcessor2_0(torch.nn.Module): r""" Attention processor for IP-Adapater for PyTorch 2.0. Args: hidden_size (`int`): The hidden size of the attention layer. cross_attention_dim (`int`): The number of channels in the `encoder_hidden_states`. scale (`float`, defaults to 1.0): the weight scale of image prompt. num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16): The context length of the image features. 
""" def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4): super().__init__() if not hasattr(F, "scaled_dot_product_attention"): raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") self.hidden_size = hidden_size self.cross_attention_dim = cross_attention_dim self.scale = scale self.num_tokens = num_tokens self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) def __call__( self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None, ): residual = hidden_states if attn.spatial_norm is not None: hidden_states = attn.spatial_norm(hidden_states, temb) input_ndim = hidden_states.ndim if input_ndim == 4: batch_size, channel, height, width = hidden_states.shape hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) if isinstance(encoder_hidden_states, tuple): # FIXME: now hard coded to single image prompt. 
batch_size, _, hid_dim = encoder_hidden_states[0].shape ip_tokens = encoder_hidden_states[1][0] encoder_hidden_states = torch.cat([encoder_hidden_states[0], ip_tokens], dim=1) batch_size, sequence_length, _ = ( hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape ) if attention_mask is not None: attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) # scaled_dot_product_attention expects attention_mask shape to be # (batch, heads, source_length, target_length) attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) query = attn.to_q(hidden_states) if encoder_hidden_states is None: encoder_hidden_states = hidden_states else: # get encoder_hidden_states, ip_hidden_states end_pos = encoder_hidden_states.shape[1] - self.num_tokens encoder_hidden_states, ip_hidden_states = ( encoder_hidden_states[:, :end_pos, :], encoder_hidden_states[:, end_pos:, :], ) if attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) inner_dim = key.shape[-1] head_dim = inner_dim // attn.heads query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) # the output of sdp = (batch, num_heads, seq_len, head_dim) # TODO: add support for attn.scale when we move to Torch 2.1 hidden_states = F.scaled_dot_product_attention( query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False ) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype) # for ip-adapter ip_key = self.to_k_ip(ip_hidden_states) ip_value = 
class TA_IPAttnProcessor2_0(torch.nn.Module):
    r"""
    Time-aware attention processor for IP-Adapter for PyTorch 2.0.

    The image-prompt keys and values are passed through timestep-conditioned adaptive layer norms
    (`ln_k_ip`, `ln_v_ip`) before attention, and the result is added to the text attention output
    weighted by `scale`.

    Args:
        hidden_size (`int`):
            The hidden size of the attention layer.
        cross_attention_dim (`int`):
            The number of channels in the `encoder_hidden_states`.
        time_embedding_dim (`int`, *optional*):
            Forwarded to `AdaLayerNorm` as the timestep-embedding width.
        scale (`float`, defaults to 1.0):
            the weight scale of image prompt.
        num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16):
            The context length of the image features.
    """

    def __init__(self, hidden_size, cross_attention_dim=None, time_embedding_dim: int = None, scale=1.0, num_tokens=4):
        super().__init__()

        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.scale = scale
        self.num_tokens = num_tokens

        self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
        self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
        self.ln_k_ip = AdaLayerNorm(hidden_size, time_embedding_dim)
        self.ln_v_ip = AdaLayerNorm(hidden_size, time_embedding_dim)

    @staticmethod
    def _attend(query, key, value, attn_mask, batch_size, heads, head_dim):
        """Split ``key``/``value`` into heads, attend with the pre-split ``query`` and merge the heads."""
        key = key.view(batch_size, -1, heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        out = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False
        )
        out = out.transpose(1, 2).reshape(batch_size, -1, heads * head_dim)
        return out.to(query.dtype)

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        external_kv=None,
        temb=None,
    ):
        """Attend over the text context plus a timestep-modulated image-prompt branch.

        Args:
            attn: Attention module providing ``to_q``/``to_k``/``to_v``/``to_out``, ``heads`` and the
                optional normalisation layers read below.
            hidden_states: Query-side input; a 4-D input is flattened to a token sequence and restored
                to its original shape before returning.
            encoder_hidden_states: ``None`` (self-attention, image branch skipped), a tensor whose last
                ``num_tokens`` positions along dim 1 are the image tokens, or a tuple
                ``(text_states, [image_tokens, ...])`` of which only the first image entry is used.
            attention_mask: Optional mask, applied to the text attention only.
            external_kv: Optional object whose ``k`` and ``v`` attributes are concatenated to the
                projected keys/values along dim 1.
            temb: Timestep embedding; required.

        Returns:
            Tensor with the same shape as the ``hidden_states`` input.

        Raises:
            ValueError: If ``temb`` is ``None``.
        """
        # Explicit raise instead of ``assert`` so the check survives ``python -O``.
        if temb is None:
            raise ValueError("Timestep embedding is needed for a time-aware attention processor.")

        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        # Stays None when there is no context at all, so the image branch is skipped.
        ip_hidden_states = None
        if isinstance(encoder_hidden_states, tuple):
            # FIXME: now hard coded to single image prompt.
            ip_hidden_states = encoder_hidden_states[1][0]
            encoder_hidden_states = encoder_hidden_states[0]
        elif encoder_hidden_states is not None:
            # get encoder_hidden_states, ip_hidden_states
            end_pos = encoder_hidden_states.shape[1] - self.num_tokens
            encoder_hidden_states, ip_hidden_states = (
                encoder_hidden_states[:, :end_pos, :],
                encoder_hidden_states[:, end_pos:, :],
            )

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        if external_kv:
            key = torch.cat([key, external_kv.k], dim=1)
            value = torch.cat([value, external_kv.v], dim=1)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        hidden_states = self._attend(query, key, value, attention_mask, batch_size, attn.heads, head_dim)

        if ip_hidden_states is not None:
            # for ip-adapter
            ip_key = self.to_k_ip(ip_hidden_states)
            ip_value = self.to_v_ip(ip_hidden_states)

            # time-dependent adaLN
            ip_key = self.ln_k_ip(ip_key, temb)
            ip_value = self.ln_v_ip(ip_value, temb)

            # The image branch is never masked.
            ip_hidden_states = self._attend(query, ip_key, ip_value, None, batch_size, attn.heads, head_dim)

            hidden_states = hidden_states + self.scale * ip_hidden_states

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
## for controlnet
class CNAttnProcessor:
    r"""
    Default processor for performing attention-related computations.

    When a context is supplied, its last `num_tokens` positions along dim 1 are dropped before the
    key/value projections, so only the leading (text) part is attended to.
    """

    def __init__(self, num_tokens=4):
        self.num_tokens = num_tokens

    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None):
        """Attend over the context with its trailing ``num_tokens`` entries removed.

        Returns a tensor shaped like the ``hidden_states`` input.
        """
        shortcut = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            # Flatten the two trailing dims into one token axis.
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        shape_source = hidden_states if encoder_hidden_states is None else encoder_hidden_states
        batch_size, sequence_length, _ = shape_source.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            context = hidden_states
        else:
            text_len = encoder_hidden_states.shape[1] - self.num_tokens
            context = encoder_hidden_states[:, :text_len]  # only use text
            if attn.norm_cross:
                context = attn.norm_encoder_hidden_states(context)

        key = attn.to_k(context)
        value = attn.to_v(context)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        weights = attn.get_attention_scores(query, key, attention_mask)
        attended = attn.batch_to_head_dim(torch.bmm(weights, value))

        # linear proj, then dropout
        projected = attn.to_out[0](attended)
        output = attn.to_out[1](projected)

        if input_ndim == 4:
            output = output.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            output = output + shortcut

        return output / attn.rescale_output_factor
class CNAttnProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).

    When a context is supplied, its last `num_tokens` positions along dim 1 are dropped before the
    key/value projections, so only the leading (text) part is attended to.
    """

    def __init__(self, num_tokens=4):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
        self.num_tokens = num_tokens

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        """Attend over the context with its trailing ``num_tokens`` entries removed.

        Returns a tensor shaped like the ``hidden_states`` input.
        """
        shortcut = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            # Flatten the two trailing dims into one token axis.
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        shape_source = hidden_states if encoder_hidden_states is None else encoder_hidden_states
        batch_size, sequence_length, _ = shape_source.shape

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            context = hidden_states
        else:
            text_len = encoder_hidden_states.shape[1] - self.num_tokens
            context = encoder_hidden_states[:, :text_len]  # only use text
            if attn.norm_cross:
                context = attn.norm_encoder_hidden_states(context)

        key = attn.to_k(context)
        value = attn.to_v(context)

        num_heads = attn.heads
        head_dim = key.shape[-1] // num_heads

        def to_heads(tensor):
            return tensor.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)

        query, key, value = to_heads(query), to_heads(key), to_heads(value)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        attended = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
        attended = attended.transpose(1, 2).reshape(batch_size, -1, num_heads * head_dim)
        attended = attended.to(query.dtype)

        # linear proj, then dropout
        projected = attn.to_out[0](attended)
        output = attn.to_out[1](projected)

        if input_ndim == 4:
            output = output.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            output = output + shortcut

        return output / attn.rescale_output_factor
def _attn_hidden_size(name, block_out_channels):
    """Return the channel width of the UNet block that owns attention processor ``name``.

    Raises:
        ValueError: If ``name`` does not start with ``mid_block``, ``up_blocks`` or ``down_blocks``.
    """
    if name.startswith("mid_block"):
        return block_out_channels[-1]
    if name.startswith("up_blocks"):
        # The second dotted component is the block index; parse it whole so indices >= 10 work.
        return list(reversed(block_out_channels))[int(name.split(".")[1])]
    if name.startswith("down_blocks"):
        return block_out_channels[int(name.split(".")[1])]
    raise ValueError(f"Cannot infer hidden size for attention processor {name!r}.")


def init_attn_proc(
    unet,
    ip_adapter_tokens=16,
    use_lcm=False,
    use_adaln=True,
    use_external_kv=False,
    time_embedding_dim=1280,
):
    """Build one attention processor per entry of ``unet.attn_processors``.

    Layers whose cross-attention dim is ``None`` (names ending in ``attn1.processor``) get a plain or
    additive-KV processor. The remaining layers get a time-aware IP-Adapter processor when
    ``use_adaln`` is set, with ``to_k_ip``/``to_v_ip`` initialised from the layer's own
    ``to_k``/``to_v`` weights; otherwise they get a plain processor.

    Args:
        unet: Model exposing ``attn_processors``, ``state_dict()`` and a ``config`` with
            ``cross_attention_dim`` and ``block_out_channels``.
        ip_adapter_tokens: ``num_tokens`` passed to the IP-Adapter processors.
        use_lcm: Read the source weights from ``.base_layer.weight`` keys instead of ``.weight``.
        use_adaln: Use the time-aware IP-Adapter processor on cross-attention layers.
        use_external_kv: Use the additive-KV processor on self-attention layers.
        time_embedding_dim: Timestep-embedding width passed to the time-aware processors.

    Returns:
        Dict mapping processor name to the newly created processor.

    Raises:
        ValueError: If a processor name does not belong to a mid/up/down block.
    """
    has_sdpa = hasattr(F, "scaled_dot_product_attention")
    weight_suffix = ".base_layer.weight" if use_lcm else ".weight"

    attn_procs = {}
    unet_sd = unet.state_dict()
    for name in unet.attn_processors:
        cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
        hidden_size = _attn_hidden_size(name, unet.config.block_out_channels)

        if cross_attention_dim is None:
            if not use_external_kv:
                attn_procs[name] = AttnProcessor2_0() if has_sdpa else AttnProcessor()
            elif has_sdpa:
                attn_procs[name] = AdditiveKV_AttnProcessor2_0(
                    hidden_size=hidden_size,
                    cross_attention_dim=cross_attention_dim,
                    time_embedding_dim=time_embedding_dim,
                )
            else:
                attn_procs[name] = AdditiveKV_AttnProcessor()
        elif use_adaln:
            layer_name = name.split(".processor")[0]
            weights = {
                "to_k_ip.weight": unet_sd[layer_name + ".to_k" + weight_suffix],
                "to_v_ip.weight": unet_sd[layer_name + ".to_v" + weight_suffix],
            }
            processor_cls = TA_IPAttnProcessor2_0 if has_sdpa else TA_IPAttnProcessor
            attn_procs[name] = processor_cls(
                hidden_size=hidden_size,
                cross_attention_dim=cross_attention_dim,
                num_tokens=ip_adapter_tokens,
                time_embedding_dim=time_embedding_dim,
            )
            # strict=False: only the two image projections are provided here.
            attn_procs[name].load_state_dict(weights, strict=False)
        else:
            attn_procs[name] = AttnProcessor2_0() if has_sdpa else AttnProcessor()

    return attn_procs


def init_aggregator_attn_proc(unet, use_adaln=False, split_attn=False, time_embedding_dim=1280):
    """Build one attention processor per entry of ``unet.attn_processors`` for the aggregator.

    Args:
        unet: Model exposing ``attn_processors`` and a ``config`` with ``cross_attention_dim`` and
            ``block_out_channels``.
        use_adaln: With ``split_attn``, choose ``sep_split_AttnProcessor2_0`` over
            ``split_AttnProcessor2_0``.
        split_attn: Use the split-attention processors instead of the plain ones.
        time_embedding_dim: Timestep-embedding width passed to the split processors.

    Returns:
        Dict mapping processor name to the newly created processor.

    Raises:
        ValueError: If a processor name does not belong to a mid/up/down block.
    """
    has_sdpa = hasattr(F, "scaled_dot_product_attention")

    attn_procs = {}
    for name in unet.attn_processors:
        hidden_size = _attn_hidden_size(name, unet.config.block_out_channels)

        if not split_attn:
            processor_cls = AttnProcessor2_0 if has_sdpa else AttnProcessor
            attn_procs[name] = processor_cls(
                hidden_size=hidden_size,
                cross_attention_dim=hidden_size,
            )
        elif use_adaln:
            attn_procs[name] = sep_split_AttnProcessor2_0(
                hidden_size=hidden_size,
                cross_attention_dim=hidden_size,
                time_embedding_dim=time_embedding_dim,
            )
        else:
            cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
            attn_procs[name] = split_AttnProcessor2_0(
                hidden_size=hidden_size,
                cross_attention_dim=cross_attention_dim,
                time_embedding_dim=time_embedding_dim,
            )

    return attn_procs