from importlib import import_module from typing import Callable, Optional, Union from collections import deque import torch import torch.nn.functional as F from torch import nn from diffusers.models.attention_processor import Attention from diffusers.utils import USE_PEFT_BACKEND, deprecate, logging from diffusers.utils.import_utils import is_xformers_available from diffusers.utils.torch_utils import maybe_allow_in_graph from diffusers.models.lora import LoRACompatibleLinear, LoRALinearLayer from .utils import get_nn_feats, random_bipartite_soft_matching if is_xformers_available(): import xformers import xformers.ops else: xformers = None class CachedSTAttnProcessor2_0: r""" Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). """ def __init__(self, name=None, use_feature_injection=False, feature_injection_strength=0.8, feature_similarity_threshold=0.98, interval=4, max_frames=1, use_tome_cache=False, tome_metric="keys", use_grid=False, tome_ratio=0.5): if not hasattr(F, "scaled_dot_product_attention"): raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") self.name = name self.use_feature_injection = use_feature_injection self.fi_strength = feature_injection_strength self.threshold = feature_similarity_threshold self.zero_tensor = torch.tensor(0) self.frame_id = torch.tensor(0) self.interval = torch.tensor(interval) self.max_frames = max_frames self.cached_key = None self.cached_value = None self.cached_output = None self.use_tome_cache = use_tome_cache self.tome_metric = tome_metric self.use_grid = use_grid self.tome_ratio = tome_ratio def _tome_step_kvout(self, keys, values, outputs): keys = torch.cat([self.cached_key, keys], dim=1) values = torch.cat([self.cached_value, values], dim=1) outputs = torch.cat([self.cached_output, outputs], dim=1) m_kv_out, _, _= random_bipartite_soft_matching(metric=keys, use_grid=self.use_grid, ratio=self.tome_ratio) compact_keys, compact_values, compact_outputs = m_kv_out(keys, values, outputs) self.cached_key = compact_keys self.cached_value = compact_values self.cached_output = compact_outputs def __call__( self, attn: Attention, hidden_states: torch.FloatTensor, encoder_hidden_states: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0, ) -> torch.FloatTensor: residual = hidden_states if attn.spatial_norm is not None: hidden_states = attn.spatial_norm(hidden_states, temb) input_ndim = hidden_states.ndim if input_ndim == 4: batch_size, channel, height, width = hidden_states.shape hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) batch_size, sequence_length, _ = ( hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape ) if attention_mask is not None: attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) # scaled_dot_product_attention expects attention_mask shape to be # (batch, heads, source_length, target_length) attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) args = () if USE_PEFT_BACKEND else (scale,) query = attn.to_q(hidden_states, *args) is_selfattn = False if encoder_hidden_states is None: is_selfattn = True encoder_hidden_states = hidden_states elif attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) key = attn.to_k(encoder_hidden_states, *args) value = attn.to_v(encoder_hidden_states, *args) if is_selfattn: cached_key = key.clone() cached_value = value.clone() # Avoid if statement -> replace the dynamic graph to static graph if torch.equal(self.frame_id, self.zero_tensor): # ONNX self.cached_key = cached_key self.cached_value = cached_value key = torch.cat([key, self.cached_key], dim=1) value = torch.cat([value, self.cached_value], dim=1) inner_dim = key.shape[-1] head_dim = inner_dim // attn.heads query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) # the output of sdp = (batch, num_heads, seq_len, head_dim) # TODO: add support for attn.scale when we move to Torch 2.1 hidden_states = F.scaled_dot_product_attention( query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False ) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype) # linear proj hidden_states = attn.to_out[0](hidden_states, *args) # dropout hidden_states = attn.to_out[1](hidden_states) if input_ndim == 4: hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) if attn.residual_connection: hidden_states = hidden_states + residual hidden_states = hidden_states / attn.rescale_output_factor if is_selfattn: cached_output = hidden_states.clone() if torch.equal(self.frame_id, self.zero_tensor): self.cached_output = cached_output if self.use_feature_injection and ("up_blocks.0" in self.name or "up_blocks.1" in self.name or 'mid_block' in self.name): nn_hidden_states = get_nn_feats(hidden_states, self.cached_output, threshold=self.threshold) hidden_states = hidden_states * (1-self.fi_strength) + self.fi_strength * nn_hidden_states mod_result = torch.remainder(self.frame_id, self.interval) if torch.equal(mod_result, self.zero_tensor) and is_selfattn: self._tome_step_kvout(cached_key, cached_value, cached_output) self.frame_id = self.frame_id + 1 return hidden_states class CachedSTXFormersAttnProcessor: r""" Processor for implementing memory efficient attention using xFormers. Args: attention_op (`Callable`, *optional*, defaults to `None`): The base [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best operator. """ def __init__(self, attention_op: Optional[Callable] = None, name=None, use_feature_injection=False, feature_injection_strength=0.8, feature_similarity_threshold=0.98, interval=4, max_frames=4, use_tome_cache=False, tome_metric="keys", use_grid=False, tome_ratio=0.5): self.attention_op = attention_op self.name = name self.use_feature_injection = use_feature_injection self.fi_strength = feature_injection_strength self.threshold = feature_similarity_threshold self.frame_id = 0 self.interval = interval self.cached_key = deque(maxlen=max_frames) self.cached_value = deque(maxlen=max_frames) self.cached_output = deque(maxlen=max_frames) self.use_tome_cache = use_tome_cache self.tome_metric = tome_metric self.use_grid = use_grid self.tome_ratio = tome_ratio def _tome_step_kvout(self, keys, values, outputs): if len(self.cached_value) == 1: keys = torch.cat(list(self.cached_key) + [keys], dim=1) values = torch.cat(list(self.cached_value) + [values], dim=1) outputs = torch.cat(list(self.cached_output) + [outputs], dim=1) m_kv_out, _, _= random_bipartite_soft_matching(metric=eval(self.tome_metric), use_grid=self.use_grid, ratio=self.tome_ratio) compact_keys, compact_values, compact_outputs = m_kv_out(keys, values, outputs) self.cached_key.append(compact_keys) self.cached_value.append(compact_values) self.cached_output.append(compact_outputs) else: self.cached_key.append(keys) self.cached_value.append(values) self.cached_output.append(outputs) def _tome_step_kv(self, keys, values): if len(self.cached_value) == 1: keys = torch.cat(list(self.cached_key) + [keys], dim=1) values = torch.cat(list(self.cached_value) + [values], dim=1) _, m_kv, _= random_bipartite_soft_matching(metric=eval(self.tome_metric), use_grid=self.use_grid, ratio=self.tome_ratio) compact_keys, compact_values = m_kv(keys, values) self.cached_key.append(compact_keys) self.cached_value.append(compact_values) else: self.cached_key.append(keys) self.cached_value.append(values) def _tome_step_out(self, outputs): if len(self.cached_value) == 1: outputs = torch.cat(list(self.cached_output) + [outputs], dim=1) _, _, m_out= random_bipartite_soft_matching(metric=outputs, use_grid=self.use_grid, ratio=self.tome_ratio) compact_outputs = m_out(outputs) self.cached_output.append(compact_outputs) else: self.cached_output.append(outputs) def __call__( self, attn: Attention, hidden_states: torch.FloatTensor, encoder_hidden_states: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0, ) -> torch.FloatTensor: residual = hidden_states args = () if USE_PEFT_BACKEND else (scale,) if attn.spatial_norm is not None: hidden_states = attn.spatial_norm(hidden_states, temb) input_ndim = hidden_states.ndim if input_ndim == 4: batch_size, channel, height, width = hidden_states.shape hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) batch_size, key_tokens, _ = ( hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape ) attention_mask = attn.prepare_attention_mask(attention_mask, key_tokens, batch_size) if attention_mask is not None: # expand our mask's singleton query_tokens dimension: # [batch*heads, 1, key_tokens] -> # [batch*heads, query_tokens, key_tokens] # so that it can be added as a bias onto the attention scores that xformers computes: # [batch*heads, query_tokens, key_tokens] # we do this explicitly because xformers doesn't broadcast the singleton dimension for us. _, query_tokens, _ = hidden_states.shape attention_mask = attention_mask.expand(-1, query_tokens, -1) if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) query = attn.to_q(hidden_states, *args) is_selfattn = False if encoder_hidden_states is None: is_selfattn = True encoder_hidden_states = hidden_states elif attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) key = attn.to_k(encoder_hidden_states, *args) value = attn.to_v(encoder_hidden_states, *args) if is_selfattn: cached_key = key.clone() cached_value = value.clone() if len(self.cached_key) > 0: key = torch.cat([key] + list(self.cached_key), dim=1) value = torch.cat([value] + list(self.cached_value), dim=1) ## Code for storing and visualizing features # if self.frame_id % self.interval == 0: # # if "down_blocks.0" in self.name or "up_blocks.3" in self.name: # # feats = { # # "hidden_states": hidden_states.clone().cpu(), # # "query": query.clone().cpu(), # # "key": cached_key.cpu(), # # "value": cached_value.cpu(), # # } # # torch.save(feats, f'./outputs/self_attn_feats_SD/{self.name}.frame{self.frame_id}.pt') # if self.use_tome_cache: # cached_key, cached_value = self._tome_step(cached_key, cached_value) query = attn.head_to_batch_dim(query).contiguous() key = attn.head_to_batch_dim(key).contiguous() value = attn.head_to_batch_dim(value).contiguous() hidden_states = xformers.ops.memory_efficient_attention( query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale ) hidden_states = hidden_states.to(query.dtype) hidden_states = attn.batch_to_head_dim(hidden_states) # linear proj hidden_states = attn.to_out[0](hidden_states, *args) # dropout hidden_states = attn.to_out[1](hidden_states) if input_ndim == 4: hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) if attn.residual_connection: hidden_states = hidden_states + residual hidden_states = hidden_states / attn.rescale_output_factor if is_selfattn: cached_output = hidden_states.clone() if self.use_feature_injection and ("up_blocks.0" in self.name or "up_blocks.1" in self.name or 'mid_block' in self.name): if len(self.cached_output) > 0: nn_hidden_states = get_nn_feats(hidden_states, self.cached_output, threshold=self.threshold) hidden_states = hidden_states * (1-self.fi_strength) + self.fi_strength * nn_hidden_states if self.frame_id % self.interval == 0: if is_selfattn: if self.use_tome_cache: self._tome_step_kvout(cached_key, cached_value, cached_output) else: self.cached_key.append(cached_key) self.cached_value.append(cached_value) self.cached_output.append(cached_output) self.frame_id += 1 return hidden_states