1inkusFace committed on
Commit ec04216 · verified · 1 Parent(s): dbcf177

Create models/attention.py

Files changed (1)
  1. models/attention.py +1245 -0
models/attention.py ADDED
@@ -0,0 +1,1245 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Any, Dict, List, Optional, Tuple
15
+
16
+ import torch
17
+ import torch.nn.functional as F
18
+ from torch import nn
19
+
20
+ from diffusers.utils import deprecate, logging
21
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
22
+ from diffusers.models.activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, SwiGLU
23
+ from diffusers.models.attention_processor import Attention, JointAttnProcessor2_0
24
+ from diffusers.models.embeddings import SinusoidalPositionalEmbedding
25
+ from diffusers.models.normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm, SD35AdaLayerNormZeroX
26
+
27
+
28
+ logger = logging.get_logger(__name__)
29
+
30
+
31
+ def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int):
32
+ # "feed_forward_chunk_size" can be used to save memory
33
+ if hidden_states.shape[chunk_dim] % chunk_size != 0:
34
+ raise ValueError(
35
+ f"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]} has to be divisible by chunk size: {chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
36
+ )
37
+
38
+ num_chunks = hidden_states.shape[chunk_dim] // chunk_size
39
+ ff_output = torch.cat(
40
+ [ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
41
+ dim=chunk_dim,
42
+ )
43
+ return ff_output
44
+
45
+
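A minimal sketch of how this helper is meant to be called, assuming the file is importable as `models.attention` (the import path is an assumption) and using a plain `nn.Linear` as a stand-in for a feed-forward module:

import torch
from torch import nn

from models.attention import _chunked_feed_forward  # assumed import path

ff = nn.Linear(64, 64)               # stand-in for a FeedForward module
x = torch.randn(2, 128, 64)          # (batch, seq_len, dim); 128 is divisible by chunk_size=32
out = _chunked_feed_forward(ff, x, chunk_dim=1, chunk_size=32)
assert torch.allclose(out, ff(x), atol=1e-6)  # chunked pass matches the full pass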
46
+ @maybe_allow_in_graph
47
+ class GatedSelfAttentionDense(nn.Module):
48
+ r"""
49
+ A gated self-attention dense layer that combines visual features and object features.
50
+
51
+ Parameters:
52
+ query_dim (`int`): The number of channels in the query.
53
+ context_dim (`int`): The number of channels in the context.
54
+ n_heads (`int`): The number of heads to use for attention.
55
+ d_head (`int`): The number of channels in each head.
56
+ """
57
+
58
+ def __init__(self, query_dim: int, context_dim: int, n_heads: int, d_head: int):
59
+ super().__init__()
60
+
61
+ # we need a linear projection since we concatenate the visual features and object features
62
+ self.linear = nn.Linear(context_dim, query_dim)
63
+
64
+ self.attn = Attention(query_dim=query_dim, heads=n_heads, dim_head=d_head)
65
+ self.ff = FeedForward(query_dim, activation_fn="geglu")
66
+
67
+ self.norm1 = nn.LayerNorm(query_dim)
68
+ self.norm2 = nn.LayerNorm(query_dim)
69
+
70
+ self.register_parameter("alpha_attn", nn.Parameter(torch.tensor(0.0)))
71
+ self.register_parameter("alpha_dense", nn.Parameter(torch.tensor(0.0)))
72
+
73
+ self.enabled = True
74
+
75
+ def forward(self, x: torch.Tensor, objs: torch.Tensor) -> torch.Tensor:
76
+ if not self.enabled:
77
+ return x
78
+
79
+ n_visual = x.shape[1]
80
+ objs = self.linear(objs)
81
+
82
+ x = x + self.alpha_attn.tanh() * self.attn(self.norm1(torch.cat([x, objs], dim=1)))[:, :n_visual, :]
83
+ x = x + self.alpha_dense.tanh() * self.ff(self.norm2(x))
84
+
85
+ return x
86
+
87
+
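A minimal usage sketch for the block above, assuming the same `models.attention` import path and an installed `diffusers`; shapes are illustrative:

import torch

from models.attention import GatedSelfAttentionDense  # assumed import path

fuser = GatedSelfAttentionDense(query_dim=64, context_dim=32, n_heads=4, d_head=16)
visual = torch.randn(1, 16, 64)   # visual tokens
objs = torch.randn(1, 3, 32)      # grounding ("object") tokens, projected to query_dim internally
print(fuser(visual, objs).shape)  # torch.Size([1, 16, 64]); fuser.enabled = False returns the input unchanged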
88
+ @maybe_allow_in_graph
89
+ class JointTransformerBlock(nn.Module):
90
+ r"""
91
+ A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
92
+
93
+ Reference: https://arxiv.org/abs/2403.03206
94
+
95
+ Parameters:
96
+ dim (`int`): The number of channels in the input and output.
97
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
98
+ attention_head_dim (`int`): The number of channels in each head.
99
+ context_pre_only (`bool`): Whether the `context` (encoder) stream is only pre-processed, i.e. whether to
100
+ skip the attention-output and feed-forward blocks that further process the `context` conditions.
101
+ """
102
+
103
+ def __init__(
104
+ self,
105
+ dim: int,
106
+ num_attention_heads: int,
107
+ attention_head_dim: int,
108
+ context_pre_only: bool = False,
109
+ qk_norm: Optional[str] = None,
110
+ use_dual_attention: bool = False,
111
+ ):
112
+ super().__init__()
113
+
114
+ self.use_dual_attention = use_dual_attention
115
+ self.context_pre_only = context_pre_only
116
+ context_norm_type = "ada_norm_continous" if context_pre_only else "ada_norm_zero"
117
+
118
+ if use_dual_attention:
119
+ self.norm1 = SD35AdaLayerNormZeroX(dim)
120
+ else:
121
+ self.norm1 = AdaLayerNormZero(dim)
122
+
123
+ if context_norm_type == "ada_norm_continous":
124
+ self.norm1_context = AdaLayerNormContinuous(
125
+ dim, dim, elementwise_affine=False, eps=1e-6, bias=True, norm_type="layer_norm"
126
+ )
127
+ elif context_norm_type == "ada_norm_zero":
128
+ self.norm1_context = AdaLayerNormZero(dim)
129
+ else:
130
+ raise ValueError(
131
+ f"Unknown context_norm_type: {context_norm_type}, currently only support `ada_norm_continous`, `ada_norm_zero`"
132
+ )
133
+
134
+ if hasattr(F, "scaled_dot_product_attention"):
135
+ processor = JointAttnProcessor2_0()
136
+ else:
137
+ raise ValueError(
138
+ "The current PyTorch version does not support the `scaled_dot_product_attention` function."
139
+ )
140
+
141
+ self.attn = Attention(
142
+ query_dim=dim,
143
+ cross_attention_dim=None,
144
+ added_kv_proj_dim=dim,
145
+ dim_head=attention_head_dim,
146
+ heads=num_attention_heads,
147
+ out_dim=dim,
148
+ context_pre_only=context_pre_only,
149
+ bias=True,
150
+ processor=processor,
151
+ qk_norm=qk_norm,
152
+ eps=1e-6,
153
+ )
154
+
155
+ if use_dual_attention:
156
+ self.attn2 = Attention(
157
+ query_dim=dim,
158
+ cross_attention_dim=None,
159
+ dim_head=attention_head_dim,
160
+ heads=num_attention_heads,
161
+ out_dim=dim,
162
+ bias=True,
163
+ processor=processor,
164
+ qk_norm=qk_norm,
165
+ eps=1e-6,
166
+ )
167
+ else:
168
+ self.attn2 = None
169
+
170
+ self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
171
+ self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
172
+
173
+ if not context_pre_only:
174
+ self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
175
+ self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
176
+ else:
177
+ self.norm2_context = None
178
+ self.ff_context = None
179
+
180
+ # let chunk size default to None
181
+ self._chunk_size = None
182
+ self._chunk_dim = 0
183
+
184
+ # Copied from diffusers.models.attention.BasicTransformerBlock.set_chunk_feed_forward
185
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
186
+ # Sets chunk feed-forward
187
+ self._chunk_size = chunk_size
188
+ self._chunk_dim = dim
189
+
190
+ def forward(
191
+ self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor,
192
+ joint_attention_kwargs=None,
193
+ ):
194
+ if self.use_dual_attention:
195
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 = self.norm1(
196
+ hidden_states, emb=temb
197
+ )
198
+ else:
199
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
200
+
201
+ if self.context_pre_only:
202
+ norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb)
203
+ else:
204
+ norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
205
+ encoder_hidden_states, emb=temb
206
+ )
207
+
208
+ # Attention.
209
+ attn_output, context_attn_output = self.attn(
210
+ hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states,
211
+ **({} if joint_attention_kwargs is None else joint_attention_kwargs),
212
+ )
213
+
214
+ # Process attention outputs for the `hidden_states`.
215
+ attn_output = gate_msa.unsqueeze(1) * attn_output
216
+ hidden_states = hidden_states + attn_output
217
+
218
+ if self.use_dual_attention:
219
+ attn_output2 = self.attn2(hidden_states=norm_hidden_states2, **({} if joint_attention_kwargs is None else joint_attention_kwargs),)
220
+ attn_output2 = gate_msa2.unsqueeze(1) * attn_output2
221
+ hidden_states = hidden_states + attn_output2
222
+
223
+ norm_hidden_states = self.norm2(hidden_states)
224
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
225
+ if self._chunk_size is not None:
226
+ # "feed_forward_chunk_size" can be used to save memory
227
+ ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
228
+ else:
229
+ ff_output = self.ff(norm_hidden_states)
230
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
231
+
232
+ hidden_states = hidden_states + ff_output
233
+
234
+ # Process attention outputs for the `encoder_hidden_states`.
235
+ if self.context_pre_only:
236
+ encoder_hidden_states = None
237
+ else:
238
+ context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
239
+ encoder_hidden_states = encoder_hidden_states + context_attn_output
240
+
241
+ norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
242
+ norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
243
+ if self._chunk_size is not None:
244
+ # "feed_forward_chunk_size" can be used to save memory
245
+ context_ff_output = _chunked_feed_forward(
246
+ self.ff_context, norm_encoder_hidden_states, self._chunk_dim, self._chunk_size
247
+ )
248
+ else:
249
+ context_ff_output = self.ff_context(norm_encoder_hidden_states)
250
+ encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
251
+
252
+ return encoder_hidden_states, hidden_states
253
+
254
+
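A minimal smoke-test sketch for the block above, under the same import-path assumption and with `diffusers` installed:

import torch

from models.attention import JointTransformerBlock  # assumed import path

block = JointTransformerBlock(dim=64, num_attention_heads=4, attention_head_dim=16)
hidden = torch.randn(1, 16, 64)    # image-token stream
context = torch.randn(1, 8, 64)    # text-token stream
temb = torch.randn(1, 64)          # pooled conditioning / timestep embedding
context_out, hidden_out = block(hidden, context, temb)
print(hidden_out.shape, context_out.shape)  # torch.Size([1, 16, 64]) torch.Size([1, 8, 64])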
255
+ @maybe_allow_in_graph
256
+ class BasicTransformerBlock(nn.Module):
257
+ r"""
258
+ A basic Transformer block.
259
+
260
+ Parameters:
261
+ dim (`int`): The number of channels in the input and output.
262
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
263
+ attention_head_dim (`int`): The number of channels in each head.
264
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
265
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
266
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
267
+ num_embeds_ada_norm (`int`, *optional*):
268
+ The number of diffusion steps used during training. See `Transformer2DModel`.
269
+ attention_bias (`bool`, *optional*, defaults to `False`):
270
+ Configure if the attentions should contain a bias parameter.
271
+ only_cross_attention (`bool`, *optional*):
272
+ Whether to use only cross-attention layers. In this case two cross attention layers are used.
273
+ double_self_attention (`bool`, *optional*):
274
+ Whether to use two self-attention layers. In this case no cross attention layers are used.
275
+ upcast_attention (`bool`, *optional*):
276
+ Whether to upcast the attention computation to float32. This is useful for mixed precision training.
277
+ norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
278
+ Whether to use learnable elementwise affine parameters for normalization.
279
+ norm_type (`str`, *optional*, defaults to `"layer_norm"`):
280
+ The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
281
+ final_dropout (`bool`, *optional*, defaults to `False`):
282
+ Whether to apply a final dropout after the last feed-forward layer.
283
+ attention_type (`str`, *optional*, defaults to `"default"`):
284
+ The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
285
+ positional_embeddings (`str`, *optional*, defaults to `None`):
286
+ The type of positional embeddings to apply.
287
+ num_positional_embeddings (`int`, *optional*, defaults to `None`):
288
+ The maximum number of positional embeddings to apply.
289
+ """
290
+
291
+ def __init__(
292
+ self,
293
+ dim: int,
294
+ num_attention_heads: int,
295
+ attention_head_dim: int,
296
+ dropout=0.0,
297
+ cross_attention_dim: Optional[int] = None,
298
+ activation_fn: str = "geglu",
299
+ num_embeds_ada_norm: Optional[int] = None,
300
+ attention_bias: bool = False,
301
+ only_cross_attention: bool = False,
302
+ double_self_attention: bool = False,
303
+ upcast_attention: bool = False,
304
+ norm_elementwise_affine: bool = True,
305
+ norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single', 'ada_norm_continuous', 'layer_norm_i2vgen'
306
+ norm_eps: float = 1e-5,
307
+ final_dropout: bool = False,
308
+ attention_type: str = "default",
309
+ positional_embeddings: Optional[str] = None,
310
+ num_positional_embeddings: Optional[int] = None,
311
+ ada_norm_continous_conditioning_embedding_dim: Optional[int] = None,
312
+ ada_norm_bias: Optional[int] = None,
313
+ ff_inner_dim: Optional[int] = None,
314
+ ff_bias: bool = True,
315
+ attention_out_bias: bool = True,
316
+ ):
317
+ super().__init__()
318
+ self.dim = dim
319
+ self.num_attention_heads = num_attention_heads
320
+ self.attention_head_dim = attention_head_dim
321
+ self.dropout = dropout
322
+ self.cross_attention_dim = cross_attention_dim
323
+ self.activation_fn = activation_fn
324
+ self.attention_bias = attention_bias
325
+ self.double_self_attention = double_self_attention
326
+ self.norm_elementwise_affine = norm_elementwise_affine
327
+ self.positional_embeddings = positional_embeddings
328
+ self.num_positional_embeddings = num_positional_embeddings
329
+ self.only_cross_attention = only_cross_attention
330
+
331
+ # We keep these boolean flags for backward-compatibility.
332
+ self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
333
+ self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
334
+ self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
335
+ self.use_layer_norm = norm_type == "layer_norm"
336
+ self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous"
337
+
338
+ if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
339
+ raise ValueError(
340
+ f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
341
+ f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
342
+ )
343
+
344
+ self.norm_type = norm_type
345
+ self.num_embeds_ada_norm = num_embeds_ada_norm
346
+
347
+ if positional_embeddings and (num_positional_embeddings is None):
348
+ raise ValueError(
349
+ "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined."
350
+ )
351
+
352
+ if positional_embeddings == "sinusoidal":
353
+ self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
354
+ else:
355
+ self.pos_embed = None
356
+
357
+ # Define 3 blocks. Each block has its own normalization layer.
358
+ # 1. Self-Attn
359
+ if norm_type == "ada_norm":
360
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
361
+ elif norm_type == "ada_norm_zero":
362
+ self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
363
+ elif norm_type == "ada_norm_continuous":
364
+ self.norm1 = AdaLayerNormContinuous(
365
+ dim,
366
+ ada_norm_continous_conditioning_embedding_dim,
367
+ norm_elementwise_affine,
368
+ norm_eps,
369
+ ada_norm_bias,
370
+ "rms_norm",
371
+ )
372
+ else:
373
+ self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
374
+
375
+ self.attn1 = Attention(
376
+ query_dim=dim,
377
+ heads=num_attention_heads,
378
+ dim_head=attention_head_dim,
379
+ dropout=dropout,
380
+ bias=attention_bias,
381
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
382
+ upcast_attention=upcast_attention,
383
+ out_bias=attention_out_bias,
384
+ )
385
+
386
+ # 2. Cross-Attn
387
+ if cross_attention_dim is not None or double_self_attention:
388
+ # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
389
+ # I.e. the number of returned modulation chunks from AdaLayerNormZero would not make sense if returned during
390
+ # the second cross attention block.
391
+ if norm_type == "ada_norm":
392
+ self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm)
393
+ elif norm_type == "ada_norm_continuous":
394
+ self.norm2 = AdaLayerNormContinuous(
395
+ dim,
396
+ ada_norm_continous_conditioning_embedding_dim,
397
+ norm_elementwise_affine,
398
+ norm_eps,
399
+ ada_norm_bias,
400
+ "rms_norm",
401
+ )
402
+ else:
403
+ self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
404
+
405
+ self.attn2 = Attention(
406
+ query_dim=dim,
407
+ cross_attention_dim=cross_attention_dim if not double_self_attention else None,
408
+ heads=num_attention_heads,
409
+ dim_head=attention_head_dim,
410
+ dropout=dropout,
411
+ bias=attention_bias,
412
+ upcast_attention=upcast_attention,
413
+ out_bias=attention_out_bias,
414
+ ) # is self-attn if encoder_hidden_states is none
415
+ else:
416
+ if norm_type == "ada_norm_single": # For Latte
417
+ self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
418
+ else:
419
+ self.norm2 = None
420
+ self.attn2 = None
421
+
422
+ # 3. Feed-forward
423
+ if norm_type == "ada_norm_continuous":
424
+ self.norm3 = AdaLayerNormContinuous(
425
+ dim,
426
+ ada_norm_continous_conditioning_embedding_dim,
427
+ norm_elementwise_affine,
428
+ norm_eps,
429
+ ada_norm_bias,
430
+ "layer_norm",
431
+ )
432
+
433
+ elif norm_type in ["ada_norm_zero", "ada_norm", "layer_norm"]:
434
+ self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
435
+ elif norm_type == "layer_norm_i2vgen":
436
+ self.norm3 = None
437
+
438
+ self.ff = FeedForward(
439
+ dim,
440
+ dropout=dropout,
441
+ activation_fn=activation_fn,
442
+ final_dropout=final_dropout,
443
+ inner_dim=ff_inner_dim,
444
+ bias=ff_bias,
445
+ )
446
+
447
+ # 4. Fuser
448
+ if attention_type == "gated" or attention_type == "gated-text-image":
449
+ self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)
450
+
451
+ # 5. Scale-shift for PixArt-Alpha.
452
+ if norm_type == "ada_norm_single":
453
+ self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
454
+
455
+ # let chunk size default to None
456
+ self._chunk_size = None
457
+ self._chunk_dim = 0
458
+
459
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
460
+ # Sets chunk feed-forward
461
+ self._chunk_size = chunk_size
462
+ self._chunk_dim = dim
463
+
464
+ def forward(
465
+ self,
466
+ hidden_states: torch.Tensor,
467
+ attention_mask: Optional[torch.Tensor] = None,
468
+ encoder_hidden_states: Optional[torch.Tensor] = None,
469
+ encoder_attention_mask: Optional[torch.Tensor] = None,
470
+ timestep: Optional[torch.LongTensor] = None,
471
+ cross_attention_kwargs: Dict[str, Any] = None,
472
+ class_labels: Optional[torch.LongTensor] = None,
473
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
474
+ ) -> torch.Tensor:
475
+ if cross_attention_kwargs is not None:
476
+ if cross_attention_kwargs.get("scale", None) is not None:
477
+ logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
478
+
479
+ # Notice that normalization is always applied before the real computation in the following blocks.
480
+ # 0. Self-Attention
481
+ batch_size = hidden_states.shape[0]
482
+
483
+ if self.norm_type == "ada_norm":
484
+ norm_hidden_states = self.norm1(hidden_states, timestep)
485
+ elif self.norm_type == "ada_norm_zero":
486
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
487
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
488
+ )
489
+ elif self.norm_type in ["layer_norm", "layer_norm_i2vgen"]:
490
+ norm_hidden_states = self.norm1(hidden_states)
491
+ elif self.norm_type == "ada_norm_continuous":
492
+ norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"])
493
+ elif self.norm_type == "ada_norm_single":
494
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
495
+ self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
496
+ ).chunk(6, dim=1)
497
+ norm_hidden_states = self.norm1(hidden_states)
498
+ norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
499
+ else:
500
+ raise ValueError("Incorrect norm used")
501
+
502
+ if self.pos_embed is not None:
503
+ norm_hidden_states = self.pos_embed(norm_hidden_states)
504
+
505
+ # 1. Prepare GLIGEN inputs
506
+ cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
507
+ gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
508
+
509
+ attn_output = self.attn1(
510
+ norm_hidden_states,
511
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
512
+ attention_mask=attention_mask,
513
+ **cross_attention_kwargs,
514
+ )
515
+
516
+ if self.norm_type == "ada_norm_zero":
517
+ attn_output = gate_msa.unsqueeze(1) * attn_output
518
+ elif self.norm_type == "ada_norm_single":
519
+ attn_output = gate_msa * attn_output
520
+
521
+ hidden_states = attn_output + hidden_states
522
+ if hidden_states.ndim == 4:
523
+ hidden_states = hidden_states.squeeze(1)
524
+
525
+ # 1.2 GLIGEN Control
526
+ if gligen_kwargs is not None:
527
+ hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
528
+
529
+ # 3. Cross-Attention
530
+ if self.attn2 is not None:
531
+ if self.norm_type == "ada_norm":
532
+ norm_hidden_states = self.norm2(hidden_states, timestep)
533
+ elif self.norm_type in ["ada_norm_zero", "layer_norm", "layer_norm_i2vgen"]:
534
+ norm_hidden_states = self.norm2(hidden_states)
535
+ elif self.norm_type == "ada_norm_single":
536
+ # For PixArt norm2 isn't applied here:
537
+ # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
538
+ norm_hidden_states = hidden_states
539
+ elif self.norm_type == "ada_norm_continuous":
540
+ norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"])
541
+ else:
542
+ raise ValueError("Incorrect norm")
543
+
544
+ if self.pos_embed is not None and self.norm_type != "ada_norm_single":
545
+ norm_hidden_states = self.pos_embed(norm_hidden_states)
546
+
547
+ attn_output = self.attn2(
548
+ norm_hidden_states,
549
+ encoder_hidden_states=encoder_hidden_states,
550
+ attention_mask=encoder_attention_mask,
551
+ **cross_attention_kwargs,
552
+ )
553
+ hidden_states = attn_output + hidden_states
554
+
555
+ # 4. Feed-forward
556
+ # i2vgen doesn't have this norm 🤷‍♂️
557
+ if self.norm_type == "ada_norm_continuous":
558
+ norm_hidden_states = self.norm3(hidden_states, added_cond_kwargs["pooled_text_emb"])
559
+ elif not self.norm_type == "ada_norm_single":
560
+ norm_hidden_states = self.norm3(hidden_states)
561
+
562
+ if self.norm_type == "ada_norm_zero":
563
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
564
+
565
+ if self.norm_type == "ada_norm_single":
566
+ norm_hidden_states = self.norm2(hidden_states)
567
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
568
+
569
+ if self._chunk_size is not None:
570
+ # "feed_forward_chunk_size" can be used to save memory
571
+ ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
572
+ else:
573
+ ff_output = self.ff(norm_hidden_states)
574
+
575
+ if self.norm_type == "ada_norm_zero":
576
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
577
+ elif self.norm_type == "ada_norm_single":
578
+ ff_output = gate_mlp * ff_output
579
+
580
+ hidden_states = ff_output + hidden_states
581
+ if hidden_states.ndim == 4:
582
+ hidden_states = hidden_states.squeeze(1)
583
+
584
+ return hidden_states
585
+
586
+
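A minimal usage sketch under the same import-path assumption, using the default `layer_norm` branch with cross-attention and optional feed-forward chunking:

import torch

from models.attention import BasicTransformerBlock  # assumed import path

block = BasicTransformerBlock(dim=64, num_attention_heads=4, attention_head_dim=16, cross_attention_dim=32)
block.set_chunk_feed_forward(chunk_size=8, dim=1)   # optional: trades speed for lower peak memory
hidden = torch.randn(2, 16, 64)
context = torch.randn(2, 5, 32)
out = block(hidden, encoder_hidden_states=context)
print(out.shape)  # torch.Size([2, 16, 64])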
587
+ class LuminaFeedForward(nn.Module):
588
+ r"""
589
+ A feed-forward layer.
590
+
591
+ Parameters:
592
+ hidden_size (`int`):
593
+ The dimensionality of the hidden layers in the model. This parameter determines the width of the model's
594
+ hidden representations.
595
+ intermediate_size (`int`): The intermediate dimension of the feedforward layer.
596
+ multiple_of (`int`, *optional*): Value to ensure hidden dimension is a multiple
597
+ of this value.
598
+ ffn_dim_multiplier (float, *optional*): Custom multiplier for hidden
599
+ dimension. Defaults to None.
600
+ """
601
+
602
+ def __init__(
603
+ self,
604
+ dim: int,
605
+ inner_dim: int,
606
+ multiple_of: Optional[int] = 256,
607
+ ffn_dim_multiplier: Optional[float] = None,
608
+ ):
609
+ super().__init__()
610
+ inner_dim = int(2 * inner_dim / 3)
611
+ # custom hidden_size factor multiplier
612
+ if ffn_dim_multiplier is not None:
613
+ inner_dim = int(ffn_dim_multiplier * inner_dim)
614
+ inner_dim = multiple_of * ((inner_dim + multiple_of - 1) // multiple_of)
615
+
616
+ self.linear_1 = nn.Linear(
617
+ dim,
618
+ inner_dim,
619
+ bias=False,
620
+ )
621
+ self.linear_2 = nn.Linear(
622
+ inner_dim,
623
+ dim,
624
+ bias=False,
625
+ )
626
+ self.linear_3 = nn.Linear(
627
+ dim,
628
+ inner_dim,
629
+ bias=False,
630
+ )
631
+ self.silu = FP32SiLU()
632
+
633
+ def forward(self, x):
634
+ return self.linear_2(self.silu(self.linear_1(x)) * self.linear_3(x))
635
+
636
+
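For reference, the effective hidden width follows from the 2/3 scaling and the `multiple_of` round-up; a small worked example (values are illustrative):

inner_dim, multiple_of = int(2 * 1024 / 3), 256   # 682 after the 2/3 SwiGLU scaling
inner_dim = multiple_of * ((inner_dim + multiple_of - 1) // multiple_of)
assert inner_dim == 768                           # rounded up to the next multiple of 256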
637
+ @maybe_allow_in_graph
638
+ class TemporalBasicTransformerBlock(nn.Module):
639
+ r"""
640
+ A basic Transformer block for video like data.
641
+
642
+ Parameters:
643
+ dim (`int`): The number of channels in the input and output.
644
+ time_mix_inner_dim (`int`): The number of channels for temporal attention.
645
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
646
+ attention_head_dim (`int`): The number of channels in each head.
647
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
648
+ """
649
+
650
+ def __init__(
651
+ self,
652
+ dim: int,
653
+ time_mix_inner_dim: int,
654
+ num_attention_heads: int,
655
+ attention_head_dim: int,
656
+ cross_attention_dim: Optional[int] = None,
657
+ ):
658
+ super().__init__()
659
+ self.is_res = dim == time_mix_inner_dim
660
+
661
+ self.norm_in = nn.LayerNorm(dim)
662
+
663
+ # Define 3 blocks. Each block has its own normalization layer.
664
+ # 1. Self-Attn
665
+ self.ff_in = FeedForward(
666
+ dim,
667
+ dim_out=time_mix_inner_dim,
668
+ activation_fn="geglu",
669
+ )
670
+
671
+ self.norm1 = nn.LayerNorm(time_mix_inner_dim)
672
+ self.attn1 = Attention(
673
+ query_dim=time_mix_inner_dim,
674
+ heads=num_attention_heads,
675
+ dim_head=attention_head_dim,
676
+ cross_attention_dim=None,
677
+ )
678
+
679
+ # 2. Cross-Attn
680
+ if cross_attention_dim is not None:
681
+ # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
682
+ # I.e. the number of returned modulation chunks from AdaLayerNormZero would not make sense if returned during
683
+ # the second cross attention block.
684
+ self.norm2 = nn.LayerNorm(time_mix_inner_dim)
685
+ self.attn2 = Attention(
686
+ query_dim=time_mix_inner_dim,
687
+ cross_attention_dim=cross_attention_dim,
688
+ heads=num_attention_heads,
689
+ dim_head=attention_head_dim,
690
+ ) # is self-attn if encoder_hidden_states is none
691
+ else:
692
+ self.norm2 = None
693
+ self.attn2 = None
694
+
695
+ # 3. Feed-forward
696
+ self.norm3 = nn.LayerNorm(time_mix_inner_dim)
697
+ self.ff = FeedForward(time_mix_inner_dim, activation_fn="geglu")
698
+
699
+ # let chunk size default to None
700
+ self._chunk_size = None
701
+ self._chunk_dim = None
702
+
703
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], **kwargs):
704
+ # Sets chunk feed-forward
705
+ self._chunk_size = chunk_size
706
+ # chunk dim should be hardcoded to 1 to have better speed vs. memory trade-off
707
+ self._chunk_dim = 1
708
+
709
+ def forward(
710
+ self,
711
+ hidden_states: torch.Tensor,
712
+ num_frames: int,
713
+ encoder_hidden_states: Optional[torch.Tensor] = None,
714
+ ) -> torch.Tensor:
715
+ # Notice that normalization is always applied before the real computation in the following blocks.
716
+ # 0. Self-Attention
717
+ batch_size = hidden_states.shape[0]
718
+
719
+ batch_frames, seq_length, channels = hidden_states.shape
720
+ batch_size = batch_frames // num_frames
721
+
722
+ hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, seq_length, channels)
723
+ hidden_states = hidden_states.permute(0, 2, 1, 3)
724
+ hidden_states = hidden_states.reshape(batch_size * seq_length, num_frames, channels)
725
+
726
+ residual = hidden_states
727
+ hidden_states = self.norm_in(hidden_states)
728
+
729
+ if self._chunk_size is not None:
730
+ hidden_states = _chunked_feed_forward(self.ff_in, hidden_states, self._chunk_dim, self._chunk_size)
731
+ else:
732
+ hidden_states = self.ff_in(hidden_states)
733
+
734
+ if self.is_res:
735
+ hidden_states = hidden_states + residual
736
+
737
+ norm_hidden_states = self.norm1(hidden_states)
738
+ attn_output = self.attn1(norm_hidden_states, encoder_hidden_states=None)
739
+ hidden_states = attn_output + hidden_states
740
+
741
+ # 3. Cross-Attention
742
+ if self.attn2 is not None:
743
+ norm_hidden_states = self.norm2(hidden_states)
744
+ attn_output = self.attn2(norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
745
+ hidden_states = attn_output + hidden_states
746
+
747
+ # 4. Feed-forward
748
+ norm_hidden_states = self.norm3(hidden_states)
749
+
750
+ if self._chunk_size is not None:
751
+ ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
752
+ else:
753
+ ff_output = self.ff(norm_hidden_states)
754
+
755
+ if self.is_res:
756
+ hidden_states = ff_output + hidden_states
757
+ else:
758
+ hidden_states = ff_output
759
+
760
+ hidden_states = hidden_states[None, :].reshape(batch_size, seq_length, num_frames, channels)
761
+ hidden_states = hidden_states.permute(0, 2, 1, 3)
762
+ hidden_states = hidden_states.reshape(batch_size * num_frames, seq_length, channels)
763
+
764
+ return hidden_states
765
+
766
+
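A minimal usage sketch under the same import-path assumption; the block expects the frame dimension folded into the batch, i.e. `hidden_states` of shape `(batch * num_frames, seq_len, dim)`:

import torch

from models.attention import TemporalBasicTransformerBlock  # assumed import path

block = TemporalBasicTransformerBlock(dim=64, time_mix_inner_dim=64, num_attention_heads=4, attention_head_dim=16)
video_tokens = torch.randn(2 * 8, 16, 64)   # batch=2, num_frames=8, 16 spatial tokens per frame
out = block(video_tokens, num_frames=8)
print(out.shape)  # torch.Size([16, 16, 64])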
767
+ class SkipFFTransformerBlock(nn.Module):
768
+ def __init__(
769
+ self,
770
+ dim: int,
771
+ num_attention_heads: int,
772
+ attention_head_dim: int,
773
+ kv_input_dim: int,
774
+ kv_input_dim_proj_use_bias: bool,
775
+ dropout=0.0,
776
+ cross_attention_dim: Optional[int] = None,
777
+ attention_bias: bool = False,
778
+ attention_out_bias: bool = True,
779
+ ):
780
+ super().__init__()
781
+ if kv_input_dim != dim:
782
+ self.kv_mapper = nn.Linear(kv_input_dim, dim, kv_input_dim_proj_use_bias)
783
+ else:
784
+ self.kv_mapper = None
785
+
786
+ self.norm1 = RMSNorm(dim, 1e-06)
787
+
788
+ self.attn1 = Attention(
789
+ query_dim=dim,
790
+ heads=num_attention_heads,
791
+ dim_head=attention_head_dim,
792
+ dropout=dropout,
793
+ bias=attention_bias,
794
+ cross_attention_dim=cross_attention_dim,
795
+ out_bias=attention_out_bias,
796
+ )
797
+
798
+ self.norm2 = RMSNorm(dim, 1e-06)
799
+
800
+ self.attn2 = Attention(
801
+ query_dim=dim,
802
+ cross_attention_dim=cross_attention_dim,
803
+ heads=num_attention_heads,
804
+ dim_head=attention_head_dim,
805
+ dropout=dropout,
806
+ bias=attention_bias,
807
+ out_bias=attention_out_bias,
808
+ )
809
+
810
+ def forward(self, hidden_states, encoder_hidden_states, cross_attention_kwargs):
811
+ cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
812
+
813
+ if self.kv_mapper is not None:
814
+ encoder_hidden_states = self.kv_mapper(F.silu(encoder_hidden_states))
815
+
816
+ norm_hidden_states = self.norm1(hidden_states)
817
+
818
+ attn_output = self.attn1(
819
+ norm_hidden_states,
820
+ encoder_hidden_states=encoder_hidden_states,
821
+ **cross_attention_kwargs,
822
+ )
823
+
824
+ hidden_states = attn_output + hidden_states
825
+
826
+ norm_hidden_states = self.norm2(hidden_states)
827
+
828
+ attn_output = self.attn2(
829
+ norm_hidden_states,
830
+ encoder_hidden_states=encoder_hidden_states,
831
+ **cross_attention_kwargs,
832
+ )
833
+
834
+ hidden_states = attn_output + hidden_states
835
+
836
+ return hidden_states
837
+
838
+
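A minimal usage sketch under the same import-path assumption; when `kv_input_dim != dim`, the key/value features are first projected to `dim` by `kv_mapper`:

import torch

from models.attention import SkipFFTransformerBlock  # assumed import path

block = SkipFFTransformerBlock(
    dim=64, num_attention_heads=4, attention_head_dim=16,
    kv_input_dim=32, kv_input_dim_proj_use_bias=False,
)
hidden = torch.randn(1, 16, 64)
context = torch.randn(1, 8, 32)
out = block(hidden, context, cross_attention_kwargs=None)
print(out.shape)  # torch.Size([1, 16, 64])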
839
+ @maybe_allow_in_graph
840
+ class FreeNoiseTransformerBlock(nn.Module):
841
+ r"""
842
+ A FreeNoise Transformer block.
843
+
844
+ Parameters:
845
+ dim (`int`):
846
+ The number of channels in the input and output.
847
+ num_attention_heads (`int`):
848
+ The number of heads to use for multi-head attention.
849
+ attention_head_dim (`int`):
850
+ The number of channels in each head.
851
+ dropout (`float`, *optional*, defaults to 0.0):
852
+ The dropout probability to use.
853
+ cross_attention_dim (`int`, *optional*):
854
+ The size of the encoder_hidden_states vector for cross attention.
855
+ activation_fn (`str`, *optional*, defaults to `"geglu"`):
856
+ Activation function to be used in feed-forward.
857
+ num_embeds_ada_norm (`int`, *optional*):
858
+ The number of diffusion steps used during training. See `Transformer2DModel`.
859
+ attention_bias (`bool`, defaults to `False`):
860
+ Configure if the attentions should contain a bias parameter.
861
+ only_cross_attention (`bool`, defaults to `False`):
862
+ Whether to use only cross-attention layers. In this case two cross attention layers are used.
863
+ double_self_attention (`bool`, defaults to `False`):
864
+ Whether to use two self-attention layers. In this case no cross attention layers are used.
865
+ upcast_attention (`bool`, defaults to `False`):
866
+ Whether to upcast the attention computation to float32. This is useful for mixed precision training.
867
+ norm_elementwise_affine (`bool`, defaults to `True`):
868
+ Whether to use learnable elementwise affine parameters for normalization.
869
+ norm_type (`str`, defaults to `"layer_norm"`):
870
+ The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
871
+ final_dropout (`bool` defaults to `False`):
872
+ Whether to apply a final dropout after the last feed-forward layer.
873
+ attention_type (`str`, defaults to `"default"`):
874
+ The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
875
+ positional_embeddings (`str`, *optional*):
876
+ The type of positional embeddings to apply.
877
+ num_positional_embeddings (`int`, *optional*, defaults to `None`):
878
+ The maximum number of positional embeddings to apply.
879
+ ff_inner_dim (`int`, *optional*):
880
+ Hidden dimension of feed-forward MLP.
881
+ ff_bias (`bool`, defaults to `True`):
882
+ Whether or not to use bias in feed-forward MLP.
883
+ attention_out_bias (`bool`, defaults to `True`):
884
+ Whether or not to use bias in attention output project layer.
885
+ context_length (`int`, defaults to `16`):
886
+ The maximum number of frames that the FreeNoise block processes at once.
887
+ context_stride (`int`, defaults to `4`):
888
+ The number of frames to be skipped before starting to process a new batch of `context_length` frames.
889
+ weighting_scheme (`str`, defaults to `"pyramid"`):
890
+ The weighting scheme to use for weighted averaging of processed latent frames. As described in
891
+ Equation 9 of the [FreeNoise](https://arxiv.org/abs/2310.15169) paper, "pyramid" is the default setting
892
+ used.
893
+ """
894
+
895
+ def __init__(
896
+ self,
897
+ dim: int,
898
+ num_attention_heads: int,
899
+ attention_head_dim: int,
900
+ dropout: float = 0.0,
901
+ cross_attention_dim: Optional[int] = None,
902
+ activation_fn: str = "geglu",
903
+ num_embeds_ada_norm: Optional[int] = None,
904
+ attention_bias: bool = False,
905
+ only_cross_attention: bool = False,
906
+ double_self_attention: bool = False,
907
+ upcast_attention: bool = False,
908
+ norm_elementwise_affine: bool = True,
909
+ norm_type: str = "layer_norm",
910
+ norm_eps: float = 1e-5,
911
+ final_dropout: bool = False,
912
+ positional_embeddings: Optional[str] = None,
913
+ num_positional_embeddings: Optional[int] = None,
914
+ ff_inner_dim: Optional[int] = None,
915
+ ff_bias: bool = True,
916
+ attention_out_bias: bool = True,
917
+ context_length: int = 16,
918
+ context_stride: int = 4,
919
+ weighting_scheme: str = "pyramid",
920
+ ):
921
+ super().__init__()
922
+ self.dim = dim
923
+ self.num_attention_heads = num_attention_heads
924
+ self.attention_head_dim = attention_head_dim
925
+ self.dropout = dropout
926
+ self.cross_attention_dim = cross_attention_dim
927
+ self.activation_fn = activation_fn
928
+ self.attention_bias = attention_bias
929
+ self.double_self_attention = double_self_attention
930
+ self.norm_elementwise_affine = norm_elementwise_affine
931
+ self.positional_embeddings = positional_embeddings
932
+ self.num_positional_embeddings = num_positional_embeddings
933
+ self.only_cross_attention = only_cross_attention
934
+
935
+ self.set_free_noise_properties(context_length, context_stride, weighting_scheme)
936
+
937
+ # We keep these boolean flags for backward-compatibility.
938
+ self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
939
+ self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
940
+ self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
941
+ self.use_layer_norm = norm_type == "layer_norm"
942
+ self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous"
943
+
944
+ if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
945
+ raise ValueError(
946
+ f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
947
+ f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
948
+ )
949
+
950
+ self.norm_type = norm_type
951
+ self.num_embeds_ada_norm = num_embeds_ada_norm
952
+
953
+ if positional_embeddings and (num_positional_embeddings is None):
954
+ raise ValueError(
955
+ "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined."
956
+ )
957
+
958
+ if positional_embeddings == "sinusoidal":
959
+ self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
960
+ else:
961
+ self.pos_embed = None
962
+
963
+ # Define 3 blocks. Each block has its own normalization layer.
964
+ # 1. Self-Attn
965
+ self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
966
+
967
+ self.attn1 = Attention(
968
+ query_dim=dim,
969
+ heads=num_attention_heads,
970
+ dim_head=attention_head_dim,
971
+ dropout=dropout,
972
+ bias=attention_bias,
973
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
974
+ upcast_attention=upcast_attention,
975
+ out_bias=attention_out_bias,
976
+ )
977
+
978
+ # 2. Cross-Attn
979
+ if cross_attention_dim is not None or double_self_attention:
980
+ self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
981
+
982
+ self.attn2 = Attention(
983
+ query_dim=dim,
984
+ cross_attention_dim=cross_attention_dim if not double_self_attention else None,
985
+ heads=num_attention_heads,
986
+ dim_head=attention_head_dim,
987
+ dropout=dropout,
988
+ bias=attention_bias,
989
+ upcast_attention=upcast_attention,
990
+ out_bias=attention_out_bias,
991
+ ) # is self-attn if encoder_hidden_states is none
992
+
993
+ # 3. Feed-forward
994
+ self.ff = FeedForward(
995
+ dim,
996
+ dropout=dropout,
997
+ activation_fn=activation_fn,
998
+ final_dropout=final_dropout,
999
+ inner_dim=ff_inner_dim,
1000
+ bias=ff_bias,
1001
+ )
1002
+
1003
+ self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
1004
+
1005
+ # let chunk size default to None
1006
+ self._chunk_size = None
1007
+ self._chunk_dim = 0
1008
+
1009
+ def _get_frame_indices(self, num_frames: int) -> List[Tuple[int, int]]:
1010
+ frame_indices = []
1011
+ for i in range(0, num_frames - self.context_length + 1, self.context_stride):
1012
+ window_start = i
1013
+ window_end = min(num_frames, i + self.context_length)
1014
+ frame_indices.append((window_start, window_end))
1015
+ return frame_indices
1016
+
1017
+ def _get_frame_weights(self, num_frames: int, weighting_scheme: str = "pyramid") -> List[float]:
1018
+ if weighting_scheme == "flat":
1019
+ weights = [1.0] * num_frames
1020
+
1021
+ elif weighting_scheme == "pyramid":
1022
+ if num_frames % 2 == 0:
1023
+ # num_frames = 4 => [1, 2, 2, 1]
1024
+ mid = num_frames // 2
1025
+ weights = list(range(1, mid + 1))
1026
+ weights = weights + weights[::-1]
1027
+ else:
1028
+ # num_frames = 5 => [1, 2, 3, 2, 1]
1029
+ mid = (num_frames + 1) // 2
1030
+ weights = list(range(1, mid))
1031
+ weights = weights + [mid] + weights[::-1]
1032
+
1033
+ elif weighting_scheme == "delayed_reverse_sawtooth":
1034
+ if num_frames % 2 == 0:
1035
+ # num_frames = 4 => [0.01, 2, 2, 1]
1036
+ mid = num_frames // 2
1037
+ weights = [0.01] * (mid - 1) + [mid]
1038
+ weights = weights + list(range(mid, 0, -1))
1039
+ else:
1040
+ # num_frames = 5 => [0.01, 0.01, 3, 2, 1]
1041
+ mid = (num_frames + 1) // 2
1042
+ weights = [0.01] * mid
1043
+ weights = weights + list(range(mid, 0, -1))
1044
+ else:
1045
+ raise ValueError(f"Unsupported value for weighting_scheme={weighting_scheme}")
1046
+
1047
+ return weights
1048
+
1049
+ def set_free_noise_properties(
1050
+ self, context_length: int, context_stride: int, weighting_scheme: str = "pyramid"
1051
+ ) -> None:
1052
+ self.context_length = context_length
1053
+ self.context_stride = context_stride
1054
+ self.weighting_scheme = weighting_scheme
1055
+
1056
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0) -> None:
1057
+ # Sets chunk feed-forward
1058
+ self._chunk_size = chunk_size
1059
+ self._chunk_dim = dim
1060
+
1061
+ def forward(
1062
+ self,
1063
+ hidden_states: torch.Tensor,
1064
+ attention_mask: Optional[torch.Tensor] = None,
1065
+ encoder_hidden_states: Optional[torch.Tensor] = None,
1066
+ encoder_attention_mask: Optional[torch.Tensor] = None,
1067
+ cross_attention_kwargs: Dict[str, Any] = None,
1068
+ *args,
1069
+ **kwargs,
1070
+ ) -> torch.Tensor:
1071
+ if cross_attention_kwargs is not None:
1072
+ if cross_attention_kwargs.get("scale", None) is not None:
1073
+ logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
1074
+
1075
+ cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
1076
+
1077
+ # hidden_states: [B x H x W, F, C]
1078
+ device = hidden_states.device
1079
+ dtype = hidden_states.dtype
1080
+
1081
+ num_frames = hidden_states.size(1)
1082
+ frame_indices = self._get_frame_indices(num_frames)
1083
+ frame_weights = self._get_frame_weights(self.context_length, self.weighting_scheme)
1084
+ frame_weights = torch.tensor(frame_weights, device=device, dtype=dtype).unsqueeze(0).unsqueeze(-1)
1085
+ is_last_frame_batch_complete = frame_indices[-1][1] == num_frames
1086
+
1087
+ # Handle out-of-bounds case if num_frames isn't perfectly divisible by context_length
1088
+ # For example, num_frames=25, context_length=16, context_stride=4, then we expect the ranges:
1089
+ # [(0, 16), (4, 20), (8, 24), (9, 25)]
1090
+ if not is_last_frame_batch_complete:
1091
+ if num_frames < self.context_length:
1092
+ raise ValueError(f"Expected {num_frames=} to be greater or equal than {self.context_length=}")
1093
+ last_frame_batch_length = num_frames - frame_indices[-1][1]
1094
+ frame_indices.append((num_frames - self.context_length, num_frames))
1095
+
1096
+ num_times_accumulated = torch.zeros((1, num_frames, 1), device=device)
1097
+ accumulated_values = torch.zeros_like(hidden_states)
1098
+
1099
+ for i, (frame_start, frame_end) in enumerate(frame_indices):
1100
+ # The reason for slicing here is to handle cases like frame_indices=[(0, 16), (16, 20)], which arise when
1101
+ # the user provides a video whose frame count is not a multiple of `context_length` (e.g. 19 frames); the
1102
+ # weights are built to match the actual window length.
1103
+ weights = torch.ones_like(num_times_accumulated[:, frame_start:frame_end])
1104
+ weights *= frame_weights
1105
+
1106
+ hidden_states_chunk = hidden_states[:, frame_start:frame_end]
1107
+
1108
+ # Notice that normalization is always applied before the real computation in the following blocks.
1109
+ # 1. Self-Attention
1110
+ norm_hidden_states = self.norm1(hidden_states_chunk)
1111
+
1112
+ if self.pos_embed is not None:
1113
+ norm_hidden_states = self.pos_embed(norm_hidden_states)
1114
+
1115
+ attn_output = self.attn1(
1116
+ norm_hidden_states,
1117
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
1118
+ attention_mask=attention_mask,
1119
+ **cross_attention_kwargs,
1120
+ )
1121
+
1122
+ hidden_states_chunk = attn_output + hidden_states_chunk
1123
+ if hidden_states_chunk.ndim == 4:
1124
+ hidden_states_chunk = hidden_states_chunk.squeeze(1)
1125
+
1126
+ # 2. Cross-Attention
1127
+ if self.attn2 is not None:
1128
+ norm_hidden_states = self.norm2(hidden_states_chunk)
1129
+
1130
+ if self.pos_embed is not None and self.norm_type != "ada_norm_single":
1131
+ norm_hidden_states = self.pos_embed(norm_hidden_states)
1132
+
1133
+ attn_output = self.attn2(
1134
+ norm_hidden_states,
1135
+ encoder_hidden_states=encoder_hidden_states,
1136
+ attention_mask=encoder_attention_mask,
1137
+ **cross_attention_kwargs,
1138
+ )
1139
+ hidden_states_chunk = attn_output + hidden_states_chunk
1140
+
1141
+ if i == len(frame_indices) - 1 and not is_last_frame_batch_complete:
1142
+ accumulated_values[:, -last_frame_batch_length:] += (
1143
+ hidden_states_chunk[:, -last_frame_batch_length:] * weights[:, -last_frame_batch_length:]
1144
+ )
1145
+ num_times_accumulated[:, -last_frame_batch_length:] += weights[:, -last_frame_batch_length:]
1146
+ else:
1147
+ accumulated_values[:, frame_start:frame_end] += hidden_states_chunk * weights
1148
+ num_times_accumulated[:, frame_start:frame_end] += weights
1149
+
1150
+ # TODO(aryan): Maybe this could be done in a better way.
1151
+ #
1152
+ # Previously, this was:
1153
+ # hidden_states = torch.where(
1154
+ # num_times_accumulated > 0, accumulated_values / num_times_accumulated, accumulated_values
1155
+ # )
1156
+ #
1157
+ # The reasoning for the change here is that `torch.where` became a bottleneck at some point when golfing memory
1158
+ # spikes. It is particularly noticeable when the number of frames is high. My understanding is that this comes
1159
+ # from tensors being copied - which is why we resort to splitting and concatenating here. I've not particularly
1160
+ # looked into this deeply because other memory optimizations led to more pronounced reductions.
1161
+ hidden_states = torch.cat(
1162
+ [
1163
+ torch.where(num_times_split > 0, accumulated_split / num_times_split, accumulated_split)
1164
+ for accumulated_split, num_times_split in zip(
1165
+ accumulated_values.split(self.context_length, dim=1),
1166
+ num_times_accumulated.split(self.context_length, dim=1),
1167
+ )
1168
+ ],
1169
+ dim=1,
1170
+ ).to(dtype)
1171
+
1172
+ # 3. Feed-forward
1173
+ norm_hidden_states = self.norm3(hidden_states)
1174
+
1175
+ if self._chunk_size is not None:
1176
+ ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
1177
+ else:
1178
+ ff_output = self.ff(norm_hidden_states)
1179
+
1180
+ hidden_states = ff_output + hidden_states
1181
+ if hidden_states.ndim == 4:
1182
+ hidden_states = hidden_states.squeeze(1)
1183
+
1184
+ return hidden_states
1185
+
1186
+
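A small pure-Python mirror of the windowing and weighting helpers above, to make the sliding-window behaviour concrete (values match the comment in `forward`):

context_length, context_stride, num_frames = 16, 4, 25
indices = [(i, i + context_length) for i in range(0, num_frames - context_length + 1, context_stride)]
if indices[-1][1] != num_frames:                          # out-of-bounds handling, as in forward()
    indices.append((num_frames - context_length, num_frames))
print(indices)                                            # [(0, 16), (4, 20), (8, 24), (9, 25)]

mid = context_length // 2                                 # "pyramid" weights for an even window length
print(list(range(1, mid + 1)) + list(range(mid, 0, -1)))  # [1, 2, ..., 8, 8, ..., 2, 1]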
1187
+ class FeedForward(nn.Module):
1188
+ r"""
1189
+ A feed-forward layer.
1190
+
1191
+ Parameters:
1192
+ dim (`int`): The number of channels in the input.
1193
+ dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
1194
+ mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
1195
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
1196
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
1197
+ final_dropout (`bool`, *optional*, defaults to `False`): Apply a final dropout.
1198
+ bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
1199
+ """
1200
+
1201
+ def __init__(
1202
+ self,
1203
+ dim: int,
1204
+ dim_out: Optional[int] = None,
1205
+ mult: int = 4,
1206
+ dropout: float = 0.0,
1207
+ activation_fn: str = "geglu",
1208
+ final_dropout: bool = False,
1209
+ inner_dim=None,
1210
+ bias: bool = True,
1211
+ ):
1212
+ super().__init__()
1213
+ if inner_dim is None:
1214
+ inner_dim = int(dim * mult)
1215
+ dim_out = dim_out if dim_out is not None else dim
1216
+
1217
+ if activation_fn == "gelu":
1218
+ act_fn = GELU(dim, inner_dim, bias=bias)
1219
+ if activation_fn == "gelu-approximate":
1220
+ act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias)
1221
+ elif activation_fn == "geglu":
1222
+ act_fn = GEGLU(dim, inner_dim, bias=bias)
1223
+ elif activation_fn == "geglu-approximate":
1224
+ act_fn = ApproximateGELU(dim, inner_dim, bias=bias)
1225
+ elif activation_fn == "swiglu":
1226
+ act_fn = SwiGLU(dim, inner_dim, bias=bias)
1227
+
1228
+ self.net = nn.ModuleList([])
1229
+ # project in
1230
+ self.net.append(act_fn)
1231
+ # project dropout
1232
+ self.net.append(nn.Dropout(dropout))
1233
+ # project out
1234
+ self.net.append(nn.Linear(inner_dim, dim_out, bias=bias))
1235
+ # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
1236
+ if final_dropout:
1237
+ self.net.append(nn.Dropout(dropout))
1238
+
1239
+ def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
1240
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
1241
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
1242
+ deprecate("scale", "1.0.0", deprecation_message)
1243
+ for module in self.net:
1244
+ hidden_states = module(hidden_states)
1245
+ return hidden_states
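A minimal usage sketch under the same import-path assumption; the module maps `(batch, seq_len, dim)` to `(batch, seq_len, dim_out)`:

import torch

from models.attention import FeedForward  # assumed import path

ff = FeedForward(dim=64, mult=4, activation_fn="geglu")
x = torch.randn(2, 16, 64)
print(ff(x).shape)  # torch.Size([2, 16, 64])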