Spaces:
Runtime error
Runtime error
File size: 15,416 Bytes
69f3483 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 |
from importlib import import_module
from typing import Callable, Optional, Union
from collections import deque
import torch
import torch.nn.functional as F
from torch import nn
from diffusers.models.attention_processor import Attention
from diffusers.utils import USE_PEFT_BACKEND, deprecate, logging
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.torch_utils import maybe_allow_in_graph
from diffusers.models.lora import LoRACompatibleLinear, LoRALinearLayer
from .utils import get_nn_feats, random_bipartite_soft_matching
if is_xformers_available():
import xformers
import xformers.ops
else:
xformers = None
class CachedSTAttnProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).

    In self-attention calls, this variant also attends over keys/values cached from
    earlier calls. It can blend the output with features taken from a cached output.
    Frame bookkeeping is held in tensors, and the cache is concatenated on every
    self-attention call (including the very first). This keeps the traced graph
    static, per the inline "static graph" / "ONNX" notes.

    Args:
        name: Module path of the attention layer this processor serves. It is only
            used to decide whether feature injection applies (``mid_block``,
            ``up_blocks.0`` or ``up_blocks.1``). ``None`` disables injection.
        use_feature_injection: Blend self-attention outputs with
            ``get_nn_feats(output, cached_output)``.
        feature_injection_strength: Weight of the injected features in the blend.
        feature_similarity_threshold: ``threshold`` forwarded to ``get_nn_feats``.
        interval: The cache is merged on calls whose index is a multiple of this.
        max_frames: Stored for API parity; not used by this processor.
        use_tome_cache: Stored for API parity. This processor always compacts the
            cache with ``random_bipartite_soft_matching``.
        tome_metric: Which tensor is used as the token-merging metric: ``"keys"``,
            ``"values"`` or ``"outputs"``.
        use_grid: Forwarded to ``random_bipartite_soft_matching``.
        tome_ratio: Forwarded to ``random_bipartite_soft_matching``.
    """

    # Substrings of ``name`` that identify layers where feature injection is applied.
    _FI_LAYER_KEYS = ("up_blocks.0", "up_blocks.1", "mid_block")

    def __init__(self, name=None, use_feature_injection=False,
                 feature_injection_strength=0.8,
                 feature_similarity_threshold=0.98,
                 interval=4,
                 max_frames=1,
                 use_tome_cache=False,
                 tome_metric="keys",
                 use_grid=False,
                 tome_ratio=0.5):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
        self.name = name
        self.use_feature_injection = use_feature_injection
        self.fi_strength = feature_injection_strength
        self.threshold = feature_similarity_threshold
        # Counters kept as tensors so comparisons use torch ops (static-graph friendly).
        self.zero_tensor = torch.tensor(0)
        self.frame_id = torch.tensor(0)
        self.interval = torch.tensor(interval)
        self.max_frames = max_frames
        self.cached_key = None
        self.cached_value = None
        self.cached_output = None
        self.use_tome_cache = use_tome_cache
        self.tome_metric = tome_metric
        self.use_grid = use_grid
        self.tome_ratio = tome_ratio

    def _is_feature_injection_layer(self) -> bool:
        """Return True when feature injection is enabled and ``name`` marks an eligible layer."""
        if not self.use_feature_injection or self.name is None:
            return False
        return any(layer_key in self.name for layer_key in self._FI_LAYER_KEYS)

    def _select_metric(self, **candidates):
        """Return the candidate tensor named by ``self.tome_metric``.

        Raises:
            ValueError: if ``tome_metric`` is not one of the supplied candidate names.
        """
        try:
            return candidates[self.tome_metric]
        except KeyError:
            raise ValueError(
                f"tome_metric must be one of {sorted(candidates)}, got {self.tome_metric!r}"
            ) from None

    def _tome_step_kvout(self, keys, values, outputs):
        """Concatenate the cache with new keys/values/outputs and compact it by token merging.

        The concatenation is along dim 1. The merged result replaces
        ``cached_key`` / ``cached_value`` / ``cached_output``. On call 0 the cache
        was seeded with this same snapshot, so the merge then sees it twice.
        """
        keys = torch.cat([self.cached_key, keys], dim=1)
        values = torch.cat([self.cached_value, values], dim=1)
        outputs = torch.cat([self.cached_output, outputs], dim=1)
        m_kv_out, _, _ = random_bipartite_soft_matching(
            metric=self._select_metric(keys=keys, values=values, outputs=outputs),
            use_grid=self.use_grid,
            ratio=self.tome_ratio,
        )
        compact_keys, compact_values, compact_outputs = m_kv_out(keys, values, outputs)
        self.cached_key = compact_keys
        self.cached_value = compact_values
        self.cached_output = compact_outputs

    def __call__(
        self,
        attn: "Attention",
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        temb: Optional[torch.FloatTensor] = None,
        scale: float = 1.0,
    ) -> torch.FloatTensor:
        """Run SDPA attention; self-attention calls also read and update the cache.

        ``frame_id`` is incremented on every call, including cross-attention calls.
        """
        residual = hidden_states
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        args = () if USE_PEFT_BACKEND else (scale,)
        query = attn.to_q(hidden_states, *args)

        is_selfattn = encoder_hidden_states is None
        if is_selfattn:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states, *args)
        value = attn.to_v(encoder_hidden_states, *args)

        if is_selfattn:
            # Snapshot this call's keys/values before they are extended with the cache.
            cached_key = key.clone()
            cached_value = value.clone()
            # Avoid if statement -> replace the dynamic graph to static graph:
            # on the first call, seed the cache with this call's own keys/values so
            # the concatenation below always runs.
            if torch.equal(self.frame_id, self.zero_tensor):
                # ONNX
                self.cached_key = cached_key
                self.cached_value = cached_value
            key = torch.cat([key, self.cached_key], dim=1)
            value = torch.cat([value, self.cached_value], dim=1)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states, *args)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        if is_selfattn:
            # Snapshot the output before feature injection modifies hidden_states.
            cached_output = hidden_states.clone()
            if torch.equal(self.frame_id, self.zero_tensor):
                self.cached_output = cached_output
            if self._is_feature_injection_layer():
                nn_hidden_states = get_nn_feats(hidden_states, self.cached_output, threshold=self.threshold)
                hidden_states = hidden_states * (1 - self.fi_strength) + self.fi_strength * nn_hidden_states

        mod_result = torch.remainder(self.frame_id, self.interval)
        if torch.equal(mod_result, self.zero_tensor) and is_selfattn:
            self._tome_step_kvout(cached_key, cached_value, cached_output)

        self.frame_id = self.frame_id + 1
        return hidden_states
class CachedSTXFormersAttnProcessor:
    r"""
    Processor for implementing memory efficient attention using xFormers.

    In self-attention calls, this variant also attends over keys/values cached from
    earlier calls. The caches are bounded ``deque``s. It can optionally blend the
    output with features taken from the cached outputs.

    Args:
        attention_op (`Callable`, *optional*, defaults to `None`):
            The base
            [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
            use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
            operator.
        name: Module path of the attention layer this processor serves. It is only
            used to decide whether feature injection applies (``mid_block``,
            ``up_blocks.0`` or ``up_blocks.1``). ``None`` disables injection.
        use_feature_injection: Blend self-attention outputs with
            ``get_nn_feats(output, cached_outputs)``.
        feature_injection_strength: Weight of the injected features in the blend.
        feature_similarity_threshold: ``threshold`` forwarded to ``get_nn_feats``.
        interval: The cache is updated on calls whose index is a multiple of this.
        max_frames: ``maxlen`` of each cache deque.
        use_tome_cache: Compact the cache by token merging instead of appending raw
            snapshots.
        tome_metric: Which tensor is used as the token-merging metric: ``"keys"``,
            ``"values"`` or ``"outputs"``.
        use_grid: Forwarded to ``random_bipartite_soft_matching``.
        tome_ratio: Forwarded to ``random_bipartite_soft_matching``.
    """

    # Substrings of ``name`` that identify layers where feature injection is applied.
    _FI_LAYER_KEYS = ("up_blocks.0", "up_blocks.1", "mid_block")

    def __init__(self, attention_op: Optional[Callable] = None, name=None,
                 use_feature_injection=False, feature_injection_strength=0.8, feature_similarity_threshold=0.98,
                 interval=4, max_frames=4, use_tome_cache=False, tome_metric="keys", use_grid=False, tome_ratio=0.5):
        self.attention_op = attention_op
        self.name = name
        self.use_feature_injection = use_feature_injection
        self.fi_strength = feature_injection_strength
        self.threshold = feature_similarity_threshold
        self.frame_id = 0
        self.interval = interval
        self.cached_key = deque(maxlen=max_frames)
        self.cached_value = deque(maxlen=max_frames)
        self.cached_output = deque(maxlen=max_frames)
        self.use_tome_cache = use_tome_cache
        self.tome_metric = tome_metric
        self.use_grid = use_grid
        self.tome_ratio = tome_ratio

    def _is_feature_injection_layer(self) -> bool:
        """Return True when feature injection is enabled and ``name`` marks an eligible layer."""
        if not self.use_feature_injection or self.name is None:
            return False
        return any(layer_key in self.name for layer_key in self._FI_LAYER_KEYS)

    def _select_metric(self, **candidates):
        """Return the candidate tensor named by ``self.tome_metric``.

        This replaces an ``eval`` of the config string; only the supplied names are accepted.

        Raises:
            ValueError: if ``tome_metric`` is not one of the supplied candidate names.
        """
        try:
            return candidates[self.tome_metric]
        except KeyError:
            raise ValueError(
                f"tome_metric must be one of {sorted(candidates)}, got {self.tome_metric!r}"
            ) from None

    def _tome_step_kvout(self, keys, values, outputs):
        """Add keys/values/outputs to the cache, compacting by token merging when it is non-empty.

        The first snapshot is stored as-is. Later snapshots are concatenated
        (dim 1) with every cached entry and merged. The merged result then
        *replaces* the cache, so the cache holds a single compacted entry
        regardless of ``max_frames``.
        """
        if not self.cached_value:
            self.cached_key.append(keys)
            self.cached_value.append(values)
            self.cached_output.append(outputs)
            return
        keys = torch.cat(list(self.cached_key) + [keys], dim=1)
        values = torch.cat(list(self.cached_value) + [values], dim=1)
        outputs = torch.cat(list(self.cached_output) + [outputs], dim=1)
        m_kv_out, _, _ = random_bipartite_soft_matching(
            metric=self._select_metric(keys=keys, values=values, outputs=outputs),
            use_grid=self.use_grid,
            ratio=self.tome_ratio,
        )
        compact_keys, compact_values, compact_outputs = m_kv_out(keys, values, outputs)
        # Replace rather than append: the merged entry already contains the old cache.
        self.cached_key.clear()
        self.cached_value.clear()
        self.cached_output.clear()
        self.cached_key.append(compact_keys)
        self.cached_value.append(compact_values)
        self.cached_output.append(compact_outputs)

    def _tome_step_kv(self, keys, values):
        """Like ``_tome_step_kvout`` but for the key/value caches only (metric: keys or values)."""
        if not self.cached_value:
            self.cached_key.append(keys)
            self.cached_value.append(values)
            return
        keys = torch.cat(list(self.cached_key) + [keys], dim=1)
        values = torch.cat(list(self.cached_value) + [values], dim=1)
        _, m_kv, _ = random_bipartite_soft_matching(
            metric=self._select_metric(keys=keys, values=values),
            use_grid=self.use_grid,
            ratio=self.tome_ratio,
        )
        compact_keys, compact_values = m_kv(keys, values)
        self.cached_key.clear()
        self.cached_value.clear()
        self.cached_key.append(compact_keys)
        self.cached_value.append(compact_values)

    def _tome_step_out(self, outputs):
        """Like ``_tome_step_kvout`` but for the output cache only, using the outputs as metric."""
        # Decide on the output cache itself, not on the value cache.
        if not self.cached_output:
            self.cached_output.append(outputs)
            return
        outputs = torch.cat(list(self.cached_output) + [outputs], dim=1)
        _, _, m_out = random_bipartite_soft_matching(metric=outputs, use_grid=self.use_grid, ratio=self.tome_ratio)
        compact_outputs = m_out(outputs)
        self.cached_output.clear()
        self.cached_output.append(compact_outputs)

    def __call__(
        self,
        attn: "Attention",
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        temb: Optional[torch.FloatTensor] = None,
        scale: float = 1.0,
    ) -> torch.FloatTensor:
        """Run xFormers attention; self-attention calls also read and update the cache.

        ``frame_id`` is incremented on every call, including cross-attention calls.
        """
        residual = hidden_states
        args = () if USE_PEFT_BACKEND else (scale,)

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, key_tokens, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        attention_mask = attn.prepare_attention_mask(attention_mask, key_tokens, batch_size)
        if attention_mask is not None:
            # expand our mask's singleton query_tokens dimension:
            #   [batch*heads,            1, key_tokens] ->
            #   [batch*heads, query_tokens, key_tokens]
            # so that it can be added as a bias onto the attention scores that xformers computes:
            #   [batch*heads, query_tokens, key_tokens]
            # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
            _, query_tokens, _ = hidden_states.shape
            attention_mask = attention_mask.expand(-1, query_tokens, -1)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states, *args)

        is_selfattn = encoder_hidden_states is None
        if is_selfattn:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states, *args)
        value = attn.to_v(encoder_hidden_states, *args)

        if is_selfattn:
            # Snapshot this call's keys/values before they are extended with the cache.
            cached_key = key.clone()
            cached_value = value.clone()
            if self.cached_key:
                key = torch.cat([key] + list(self.cached_key), dim=1)
                value = torch.cat([value] + list(self.cached_value), dim=1)

        query = attn.head_to_batch_dim(query).contiguous()
        key = attn.head_to_batch_dim(key).contiguous()
        value = attn.head_to_batch_dim(value).contiguous()

        hidden_states = xformers.ops.memory_efficient_attention(
            query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
        )
        hidden_states = hidden_states.to(query.dtype)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states, *args)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        if is_selfattn:
            # Snapshot the output before feature injection modifies hidden_states.
            cached_output = hidden_states.clone()
            if self._is_feature_injection_layer() and self.cached_output:
                nn_hidden_states = get_nn_feats(hidden_states, self.cached_output, threshold=self.threshold)
                hidden_states = hidden_states * (1 - self.fi_strength) + self.fi_strength * nn_hidden_states

        if self.frame_id % self.interval == 0 and is_selfattn:
            if self.use_tome_cache:
                self._tome_step_kvout(cached_key, cached_value, cached_output)
            else:
                self.cached_key.append(cached_key)
                self.cached_value.append(cached_value)
                self.cached_output.append(cached_output)

        self.frame_id += 1
        return hidden_states
|