from typing import Any, Dict, Optional, Tuple, Union import torch import torch.utils.checkpoint from diffusers import UNet2DConditionModel from diffusers.models.unet_2d_condition import UNet2DConditionOutput from diffusers.utils import logging from torch.fft import fftn, ifftn, fftshift, ifftshift """ This is a small extension of the standard UNet2DConditionModel with the small addition of the Free-U trick (https://github.com/ChenyangSi/FreeU). """ logger = logging.get_logger(__name__) # pylint: disable=invalid-name def Fourier_filter(x, threshold, scale): # FFT x_freq = fftn(x, dim=(-2, -1)) x_freq = fftshift(x_freq, dim=(-2, -1)) B, C, H, W = x_freq.shape mask = torch.ones((B, C, H, W)).cuda() # CUDA için crow, ccol = H // 2, W // 2 mask[..., crow - threshold:crow + threshold, ccol - threshold:ccol + threshold] = scale x_freq = x_freq * mask # IFFT x_freq = ifftshift(x_freq, dim=(-2, -1)) x_filtered = ifftn(x_freq, dim=(-2, -1)).real return x_filtered class FreeUUNet2DConditionModel(UNet2DConditionModel): def forward( self, sample: torch.FloatTensor, timestep: Union[torch.Tensor, float, int], encoder_hidden_states: torch.Tensor, class_labels: Optional[torch.Tensor] = None, timestep_cond: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, mid_block_additional_residual: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, return_dict: bool = True, ) -> Union[UNet2DConditionOutput, Tuple]: r""" The [`UNet2DConditionModel`] forward method. Args: sample (`torch.FloatTensor`): The noisy input tensor with the following shape `(batch, channel, height, width)`. timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input. encoder_hidden_states (`torch.FloatTensor`): The encoder hidden states with shape `(batch, sequence_length, feature_dim)`. encoder_attention_mask (`torch.Tensor`): A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias, which adds large negative values to the attention scores corresponding to "discard" tokens. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the [`AttnProcessor`]. added_cond_kwargs: (`dict`, *optional*): A kwargs dictionary containin additional embeddings that if specified are added to the embeddings that are passed along to the UNet blocks. Returns: [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise a `tuple` is returned where the first element is the sample tensor. """ # By default samples have to be AT least a multiple of the overall upsampling factor. # The overall upsampling factor is equal to 2 ** (# num of upsampling layers). # However, the upsampling interpolation output size can be forced to fit any upsampling size # on the fly if necessary. default_overall_up_factor = 2 ** self.num_upsamplers # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` forward_upsample_size = False upsample_size = None if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): logger.info("Forward upsample size to force interpolation output size.") forward_upsample_size = True # ensure attention_mask is a bias, and give it a singleton query_tokens dimension # expects mask of shape: # [batch, key_tokens] # adds singleton query_tokens dimension: # [batch, 1, key_tokens] # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) if attention_mask is not None: # assume that mask is expressed as: # (1 = keep, 0 = discard) # convert mask into a bias that can be added to attention scores: # (keep = +0, discard = -10000.0) attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 attention_mask = attention_mask.unsqueeze(1) # convert encoder_attention_mask to a bias the same way we do for attention_mask if encoder_attention_mask is not None: encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0 encoder_attention_mask = encoder_attention_mask.unsqueeze(1) # 0. center input if necessary if self.config.center_input_sample: sample = 2 * sample - 1.0 # 1. time timesteps = timestep if not torch.is_tensor(timesteps): # This would be a good case for the `match` statement (Python 3.10+) is_mps = sample.device.type == "mps" if isinstance(timestep, float): dtype = torch.float32 if is_mps else torch.float64 else: dtype = torch.int32 if is_mps else torch.int64 timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) elif len(timesteps.shape) == 0: timesteps = timesteps[None].to(sample.device) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timesteps = timesteps.expand(sample.shape[0]) t_emb = self.time_proj(timesteps) # `Timesteps` does not contain any weights and will always return f32 tensors # but time_embedding might actually be running in fp16. so we need to cast here. # there might be better ways to encapsulate this. t_emb = t_emb.to(dtype=sample.dtype) emb = self.time_embedding(t_emb, timestep_cond) aug_emb = None if self.class_embedding is not None: if class_labels is None: raise ValueError("class_labels should be provided when num_class_embeds > 0") if self.config.class_embed_type == "timestep": class_labels = self.time_proj(class_labels) # `Timesteps` does not contain any weights and will always return f32 tensors # there might be better ways to encapsulate this. class_labels = class_labels.to(dtype=sample.dtype) class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype) if self.config.class_embeddings_concat: emb = torch.cat([emb, class_emb], dim=-1) else: emb = emb + class_emb if self.config.addition_embed_type == "text": aug_emb = self.add_embedding(encoder_hidden_states) elif self.config.addition_embed_type == "text_image": # Kandinsky 2.1 - style if "image_embeds" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" ) image_embs = added_cond_kwargs.get("image_embeds") text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states) aug_emb = self.add_embedding(text_embs, image_embs) elif self.config.addition_embed_type == "text_time": # SDXL - style if "text_embeds" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" ) text_embeds = added_cond_kwargs.get("text_embeds") if "time_ids" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" ) time_ids = added_cond_kwargs.get("time_ids") time_embeds = self.add_time_proj(time_ids.flatten()) time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) add_embeds = add_embeds.to(emb.dtype) aug_emb = self.add_embedding(add_embeds) elif self.config.addition_embed_type == "image": # Kandinsky 2.2 - style if "image_embeds" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" ) image_embs = added_cond_kwargs.get("image_embeds") aug_emb = self.add_embedding(image_embs) elif self.config.addition_embed_type == "image_hint": # Kandinsky 2.2 - style if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`" ) image_embs = added_cond_kwargs.get("image_embeds") hint = added_cond_kwargs.get("hint") aug_emb, hint = self.add_embedding(image_embs, hint) sample = torch.cat([sample, hint], dim=1) emb = emb + aug_emb if aug_emb is not None else emb if self.time_embed_act is not None: emb = self.time_embed_act(emb) if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj": encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states) elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj": # Kadinsky 2.1 - style if "image_embeds" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" ) image_embeds = added_cond_kwargs.get("image_embeds") encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds) elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj": # Kandinsky 2.2 - style if "image_embeds" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" ) image_embeds = added_cond_kwargs.get("image_embeds") encoder_hidden_states = self.encoder_hid_proj(image_embeds) # 2. pre-process sample = self.conv_in(sample) # 3. down is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None down_block_res_samples = (sample,) for downsample_block in self.down_blocks: if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: # For t2i-adapter CrossAttnDownBlock2D additional_residuals = {} if is_adapter and len(down_block_additional_residuals) > 0: additional_residuals["additional_residuals"] = down_block_additional_residuals.pop(0) sample, res_samples = downsample_block( hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, cross_attention_kwargs=cross_attention_kwargs, encoder_attention_mask=encoder_attention_mask, **additional_residuals, ) else: sample, res_samples = downsample_block(hidden_states=sample, temb=emb) if is_adapter and len(down_block_additional_residuals) > 0: sample += down_block_additional_residuals.pop(0) down_block_res_samples += res_samples if is_controlnet: new_down_block_res_samples = () for down_block_res_sample, down_block_additional_residual in zip( down_block_res_samples, down_block_additional_residuals ): down_block_res_sample = down_block_res_sample + down_block_additional_residual new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,) down_block_res_samples = new_down_block_res_samples # 4. mid if self.mid_block is not None: sample = self.mid_block( sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, cross_attention_kwargs=cross_attention_kwargs, encoder_attention_mask=encoder_attention_mask, ) if is_controlnet: sample = sample + mid_block_additional_residual # 5. up for i, upsample_block in enumerate(self.up_blocks): is_final_block = i == len(self.up_blocks) - 1 res_samples = down_block_res_samples[-len(upsample_block.resnets):] down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] # Add the Free-U trick here! # Fourier Filter if sample.shape[1] == 1280: sample[:, :640] *= 1.2 # 1.1 # For SD2.1 sample = Fourier_filter(sample, threshold=1, scale=0.9) if sample.shape[1] == 640: sample[:, :320] *= 1.4 # 1.2 # For SD2.1 sample = Fourier_filter(sample, threshold=1, scale=0.2) # if we have not reached the final block and need to forward the # upsample size, we do it here if not is_final_block and forward_upsample_size: upsample_size = down_block_res_samples[-1].shape[2:] if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: sample = upsample_block( hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs, upsample_size=upsample_size, attention_mask=attention_mask, encoder_attention_mask=encoder_attention_mask, ) else: sample = upsample_block( hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size ) # 6. post-process if self.conv_norm_out: sample = self.conv_norm_out(sample) sample = self.conv_act(sample) sample = self.conv_out(sample) if not return_dict: return (sample,) return UNet2DConditionOutput(sample=sample)