root committed on
Commit
00e3192
1 Parent(s): 39ae121

add mvdiffusion models

Browse files
mvdiffusion/models/transformer_mv2d_image.py ADDED
@@ -0,0 +1,1029 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import Any, Dict, Optional
16
+
17
+ import torch
18
+ import torch.nn.functional as F
19
+ from torch import nn
20
+
21
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
22
+ from diffusers.models.embeddings import ImagePositionalEmbeddings
23
+ from diffusers.utils import BaseOutput, deprecate
24
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
25
+ from diffusers.models.attention import FeedForward, AdaLayerNorm, AdaLayerNormZero, Attention
26
+ from diffusers.models.embeddings import PatchEmbed
27
+ from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
28
+ from diffusers.models.modeling_utils import ModelMixin
29
+ from diffusers.utils.import_utils import is_xformers_available
30
+
31
+ from einops import rearrange, repeat
32
+ import pdb
33
+ import random
34
+
35
+
36
+ if is_xformers_available():
37
+ import xformers
38
+ import xformers.ops
39
+ else:
40
+ xformers = None
41
+
42
def my_repeat(tensor, num_repeats):
    """
    Repeat every entry of the first dimension of `tensor` `num_repeats` times.

    The copies of one entry are adjacent in the output (`[x0, x0, ..., x1, x1, ...]`), which is the
    ordering produced by the `"(b v)"` / `"(a v)"` grouping of the einops patterns this helper used before.

    Args:
        tensor: a 3-D or 4-D `torch.Tensor`.
        num_repeats: how many copies of each entry of dimension 0 to produce.

    Returns:
        A tensor whose first dimension is `num_repeats` times larger; all other dimensions are unchanged.

    Raises:
        ValueError: if `tensor` is neither 3-D nor 4-D. The previous implementation fell through and
            returned `None` in that case, which only surfaced as an error further downstream.
    """
    rank = len(tensor.shape)
    if rank not in (3, 4):
        raise ValueError(f"my_repeat expects a 3-D or 4-D tensor, but got a tensor with {rank} dimension(s).")
    # Both supported ranks repeat along dim 0 with the copies interleaved, so one call covers them.
    return torch.repeat_interleave(tensor, num_repeats, dim=0)
50
+
51
+
52
@dataclass
class TransformerMV2DModelOutput(BaseOutput):
    """
    Return container produced by the forward pass of [`TransformerMV2DModel`].

    Args:
        sample (`torch.FloatTensor`):
            The hidden states output conditioned on the `encoder_hidden_states` input. Its shape is
            `(batch_size, num_channels, height, width)`, or `(batch size, num_vector_embeds - 1, num_latent_pixels)`
            when the model runs on discrete input, in which case it holds the probability distributions for the
            unnoised latent pixels.
    """

    # The only field of the output; see the class docstring for its shape.
    sample: torch.FloatTensor
64
+
65
+
66
class TransformerMV2DModel(ModelMixin, ConfigMixin):
    """
    A 2D Transformer model for image-like data whose layers are `BasicMVTransformerBlock`s.

    The constructor selects exactly one of three input modes from its arguments:

    * continuous: `in_channels` is set and `patch_size` is `None`;
    * vectorized (discrete): `num_vector_embeds` is set;
    * patches: both `in_channels` and `patch_size` are set.

    Parameters:
        num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
        attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
        in_channels (`int`, *optional*):
            The number of channels in the input and output (specify if the input is **continuous**).
        out_channels (`int`, *optional*):
            Stored as `self.out_channels` (falls back to `in_channels` when `None`). Only the patched output
            projection reads it; the continuous output projection uses `in_channels`.
        num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        norm_num_groups (`int`, *optional*, defaults to 32):
            `num_groups` of the `GroupNorm` applied to continuous input.
        cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
        sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
            This is fixed during training since it is used to learn a number of position embeddings.
        num_vector_embeds (`int`, *optional*):
            The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
            Includes the class for the masked latent pixel.
        patch_size (`int`, *optional*):
            Patch size handed to `PatchEmbed`; setting it together with `in_channels` selects the patches mode.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
        num_embeds_ada_norm ( `int`, *optional*):
            The number of diffusion steps used during training. Pass if at least one of the norm_layers is
            `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
            added to the hidden states.

            During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`.
        attention_bias (`bool`, *optional*):
            Configure if the `TransformerBlocks` attention should contain a bias parameter.
        use_linear_projection (`bool`, *optional*, defaults to `False`):
            In continuous mode, use linear layers instead of 1x1 convolutions for `proj_in` / `proj_out`.
        only_cross_attention, upcast_attention, norm_type, norm_elementwise_affine:
            Forwarded unchanged to every `BasicMVTransformerBlock`. `norm_type` is overridden to `"ada_norm"`
            (with a deprecation warning) when it is `"layer_norm"` and `num_embeds_ada_norm` is set.
        num_views, cd_attention_last, cd_attention_mid, multiview_attention, sparse_mv_attention, mvcd_attention:
            Multi-view / cross-domain attention options, forwarded unchanged to every `BasicMVTransformerBlock`.
    """

    @register_to_config
    def __init__(
        self,
        num_attention_heads: int = 16,
        attention_head_dim: int = 88,
        in_channels: Optional[int] = None,
        out_channels: Optional[int] = None,
        num_layers: int = 1,
        dropout: float = 0.0,
        norm_num_groups: int = 32,
        cross_attention_dim: Optional[int] = None,
        attention_bias: bool = False,
        sample_size: Optional[int] = None,
        num_vector_embeds: Optional[int] = None,
        patch_size: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        use_linear_projection: bool = False,
        only_cross_attention: bool = False,
        upcast_attention: bool = False,
        norm_type: str = "layer_norm",
        norm_elementwise_affine: bool = True,
        num_views: int = 1,
        cd_attention_last: bool=False,
        cd_attention_mid: bool=False,
        multiview_attention: bool=True,
        sparse_mv_attention: bool = False,
        mvcd_attention: bool=False
    ):
        super().__init__()
        self.use_linear_projection = use_linear_projection
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        # Width of the token embeddings processed by the transformer blocks.
        inner_dim = num_attention_heads * attention_head_dim

        # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
        # Define whether input is continuous or discrete depending on configuration
        self.is_input_continuous = (in_channels is not None) and (patch_size is None)
        self.is_input_vectorized = num_vector_embeds is not None
        self.is_input_patches = in_channels is not None and patch_size is not None

        # Older configs left `norm_type` at its default while setting `num_embeds_ada_norm`; warn and
        # switch to "ada_norm" so the blocks are built with adaptive layer norm.
        if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
            deprecation_message = (
                f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
                " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config."
                " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect"
                " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
                " would be very nice if you could open a Pull request for the `transformer/config.json` file"
            )
            deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False)
            norm_type = "ada_norm"

        # Reject ambiguous or empty mode selections.
        if self.is_input_continuous and self.is_input_vectorized:
            raise ValueError(
                f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
                " sure that either `in_channels` or `num_vector_embeds` is None."
            )
        elif self.is_input_vectorized and self.is_input_patches:
            raise ValueError(
                f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
                " sure that either `num_vector_embeds` or `num_patches` is None."
            )
        elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches:
            raise ValueError(
                f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:"
                f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
            )

        # 2. Define input layers
        if self.is_input_continuous:
            self.in_channels = in_channels

            self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
            if use_linear_projection:
                self.proj_in = LoRACompatibleLinear(in_channels, inner_dim)
            else:
                self.proj_in = LoRACompatibleConv(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
        elif self.is_input_vectorized:
            assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
            assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"

            self.height = sample_size
            self.width = sample_size
            self.num_vector_embeds = num_vector_embeds
            self.num_latent_pixels = self.height * self.width

            self.latent_image_embedding = ImagePositionalEmbeddings(
                num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
            )
        elif self.is_input_patches:
            assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size"

            self.height = sample_size
            self.width = sample_size

            self.patch_size = patch_size
            self.pos_embed = PatchEmbed(
                height=sample_size,
                width=sample_size,
                patch_size=patch_size,
                in_channels=in_channels,
                embed_dim=inner_dim,
            )

        # 3. Define transformers blocks
        # All `num_layers` blocks are built with the same arguments; the multi-view options are passed through.
        self.transformer_blocks = nn.ModuleList(
            [
                BasicMVTransformerBlock(
                    inner_dim,
                    num_attention_heads,
                    attention_head_dim,
                    dropout=dropout,
                    cross_attention_dim=cross_attention_dim,
                    activation_fn=activation_fn,
                    num_embeds_ada_norm=num_embeds_ada_norm,
                    attention_bias=attention_bias,
                    only_cross_attention=only_cross_attention,
                    upcast_attention=upcast_attention,
                    norm_type=norm_type,
                    norm_elementwise_affine=norm_elementwise_affine,
                    num_views=num_views,
                    cd_attention_last=cd_attention_last,
                    cd_attention_mid=cd_attention_mid,
                    multiview_attention=multiview_attention,
                    sparse_mv_attention=sparse_mv_attention,
                    mvcd_attention=mvcd_attention
                )
                for d in range(num_layers)
            ]
        )

        # 4. Define output layers
        self.out_channels = in_channels if out_channels is None else out_channels
        if self.is_input_continuous:
            # TODO: should use out_channels for continuous projections
            if use_linear_projection:
                self.proj_out = LoRACompatibleLinear(inner_dim, in_channels)
            else:
                self.proj_out = LoRACompatibleConv(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
        elif self.is_input_vectorized:
            self.norm_out = nn.LayerNorm(inner_dim)
            # One logit per embedding class except one (`num_vector_embeds - 1`).
            self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
        elif self.is_input_patches:
            self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
            # `proj_out_1` yields the shift/scale pair (hence 2 * inner_dim) used to modulate `norm_out` in forward.
            self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
            self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        class_labels: Optional[torch.LongTensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ):
        """
        The [`TransformerMV2DModel`] forward method.

        Args:
            hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
                Input `hidden_states`.
            encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
                Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
                self-attention.
            timestep ( `torch.LongTensor`, *optional*):
                Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
            class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
                Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
                `AdaLayerZeroNorm`.
            cross_attention_kwargs (`dict`, *optional*):
                Passed unchanged to every transformer block.
            attention_mask ( `torch.Tensor`, *optional*):
                Self-attention mask. A 2-D mask is converted to an additive bias here and passed to every block;
                note that `BasicMVTransformerBlock.forward` in this file asserts that it is `None`.
            encoder_attention_mask ( `torch.Tensor`, *optional*):
                Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:

                    * Mask `(batch, sequence_length)` True = keep, False = discard.
                    * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.

                If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
                above. This bias will be added to the cross-attention scores.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return a [`TransformerMV2DModelOutput`] instead of a plain tuple.

        Returns:
            If `return_dict` is True, a [`TransformerMV2DModelOutput`] is returned, otherwise a
            `tuple` where the first element is the sample tensor.
        """
        # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
        #   we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
        #   we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
        # expects mask of shape:
        #   [batch, key_tokens]
        # adds singleton query_tokens dimension:
        #   [batch,                    1, key_tokens]
        # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
        #   [batch,  heads, query_tokens, key_tokens] (e.g. torch sdp attn)
        #   [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
        if attention_mask is not None and attention_mask.ndim == 2:
            # assume that mask is expressed as:
            #   (1 = keep,      0 = discard)
            # convert mask into a bias that can be added to attention scores:
            #       (keep = +0,     discard = -10000.0)
            attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
            attention_mask = attention_mask.unsqueeze(1)

        # convert encoder_attention_mask to a bias the same way we do for attention_mask
        if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
            encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

        # 1. Input
        if self.is_input_continuous:
            batch, _, height, width = hidden_states.shape
            # Kept for the residual connection added after the output projection.
            residual = hidden_states

            hidden_states = self.norm(hidden_states)
            if not self.use_linear_projection:
                # Conv projection runs on the 4-D map; `inner_dim` is read after it, then tokens are flattened.
                hidden_states = self.proj_in(hidden_states)
                inner_dim = hidden_states.shape[1]
                hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
            else:
                # Linear projection runs on flattened tokens; here `inner_dim` is the channel count *before*
                # `proj_in`, which is what the reshape after `proj_out` below relies on.
                inner_dim = hidden_states.shape[1]
                hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
                hidden_states = self.proj_in(hidden_states)
        elif self.is_input_vectorized:
            hidden_states = self.latent_image_embedding(hidden_states)
        elif self.is_input_patches:
            hidden_states = self.pos_embed(hidden_states)

        # 2. Blocks
        for block in self.transformer_blocks:
            hidden_states = block(
                hidden_states,
                attention_mask=attention_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                timestep=timestep,
                cross_attention_kwargs=cross_attention_kwargs,
                class_labels=class_labels,
            )

        # 3. Output
        if self.is_input_continuous:
            if not self.use_linear_projection:
                hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
                hidden_states = self.proj_out(hidden_states)
            else:
                hidden_states = self.proj_out(hidden_states)
                hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()

            output = hidden_states + residual
        elif self.is_input_vectorized:
            hidden_states = self.norm_out(hidden_states)
            logits = self.out(hidden_states)
            # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
            logits = logits.permute(0, 2, 1)

            # log(p(x_0))
            # The softmax is computed in float64 and the result cast back to float32.
            output = F.log_softmax(logits.double(), dim=1).float()
        elif self.is_input_patches:
            # TODO: cleanup!
            # The conditioning embedding is taken from the first block's `norm1.emb`.
            conditioning = self.transformer_blocks[0].norm1.emb(
                timestep, class_labels, hidden_dtype=hidden_states.dtype
            )
            shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
            hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
            hidden_states = self.proj_out_2(hidden_states)

            # unpatchify
            # The token count is treated as a square grid (height == width).
            height = width = int(hidden_states.shape[1] ** 0.5)
            hidden_states = hidden_states.reshape(
                shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
            )
            hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
            output = hidden_states.reshape(
                shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
            )

        if not return_dict:
            return (output,)

        return TransformerMV2DModelOutput(sample=output)
377
+
378
+
379
@maybe_allow_in_graph
class BasicMVTransformerBlock(nn.Module):
    r"""
    A basic Transformer block whose first attention layer is a `CustomAttention` (multi-view attention) and which can
    optionally add `CustomJointAttention` layers after the self-attention and/or after the feed-forward.

    Parameters:
        dim (`int`): The number of channels in the input and output.
        num_attention_heads (`int`): The number of heads to use for multi-head attention.
        attention_head_dim (`int`): The number of channels in each head.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
        only_cross_attention (`bool`, *optional*):
            Whether to use only cross-attention layers. In this case two cross attention layers are used.
        double_self_attention (`bool`, *optional*):
            Whether to use two self-attention layers. In this case no cross attention layers are used.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
        num_embeds_ada_norm (:
            obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
        attention_bias (:
            obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
        upcast_attention (`bool`, *optional*): Forwarded to every attention layer built here.
        norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
            `elementwise_affine` of the plain `nn.LayerNorm` layers `norm1`, `norm2` and `norm3`.
        norm_type (`str`, *optional*, defaults to `"layer_norm"`):
            `"ada_norm"` / `"ada_norm_zero"` select `AdaLayerNorm` / `AdaLayerNormZero` for `norm1` and require
            `num_embeds_ada_norm`; any other value uses `nn.LayerNorm`.
        final_dropout (`bool`, *optional*): Forwarded to `FeedForward`.
        num_views (`int`, *optional*, defaults to 1): Passed to `attn1` as the `num_views` keyword on every forward.
        cd_attention_last (`bool`, *optional*):
            Build `attn_joint_last` / `norm_joint_last` and apply them after the feed-forward.
        cd_attention_mid (`bool`, *optional*):
            Build `attn_joint_mid` / `norm_joint_mid` and apply them between self-attention and cross-attention.
        multiview_attention, sparse_mv_attention, mvcd_attention (`bool`, *optional*):
            Stored on the block and passed to `attn1` as keyword arguments on every forward.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        dropout=0.0,
        cross_attention_dim: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        attention_bias: bool = False,
        only_cross_attention: bool = False,
        double_self_attention: bool = False,
        upcast_attention: bool = False,
        norm_elementwise_affine: bool = True,
        norm_type: str = "layer_norm",
        final_dropout: bool = False,
        num_views: int = 1,
        cd_attention_last: bool = False,
        cd_attention_mid: bool = False,
        multiview_attention: bool = True,
        sparse_mv_attention: bool = False,
        mvcd_attention: bool = False
    ):
        super().__init__()
        self.only_cross_attention = only_cross_attention

        self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
        self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"

        if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
            raise ValueError(
                f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
                f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
            )

        # Define 3 blocks. Each block has its own normalization layer.
        # 1. Self-Attn
        if self.use_ada_layer_norm:
            self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
        elif self.use_ada_layer_norm_zero:
            self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
        else:
            self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)

        self.multiview_attention = multiview_attention
        self.sparse_mv_attention = sparse_mv_attention
        self.mvcd_attention = mvcd_attention

        # The first attention layer starts with the non-xformers multi-view processor.
        self.attn1 = CustomAttention(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            dropout=dropout,
            bias=attention_bias,
            cross_attention_dim=cross_attention_dim if only_cross_attention else None,
            upcast_attention=upcast_attention,
            processor=MVAttnProcessor()
        )

        # 2. Cross-Attn
        if cross_attention_dim is not None or double_self_attention:
            # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
            # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
            # the second cross attention block.
            self.norm2 = (
                AdaLayerNorm(dim, num_embeds_ada_norm)
                if self.use_ada_layer_norm
                else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
            )
            self.attn2 = Attention(
                query_dim=dim,
                cross_attention_dim=cross_attention_dim if not double_self_attention else None,
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                upcast_attention=upcast_attention,
            )  # is self-attn if encoder_hidden_states is none
        else:
            self.norm2 = None
            self.attn2 = None

        # 3. Feed-forward
        self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
        self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)

        # let chunk size default to None
        self._chunk_size = None
        self._chunk_dim = 0

        self.num_views = num_views

        self.cd_attention_last = cd_attention_last

        if self.cd_attention_last:
            # Joint task -Attn
            self.attn_joint_last = CustomJointAttention(
                query_dim=dim,
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                cross_attention_dim=cross_attention_dim if only_cross_attention else None,
                upcast_attention=upcast_attention,
                processor=JointAttnProcessor()
            )
            # Zero the weight of the output projection (its bias, if any, is left as initialised).
            nn.init.zeros_(self.attn_joint_last.to_out[0].weight.data)
            # Unlike norm1-3, the plain LayerNorm here does not receive `norm_elementwise_affine`.
            self.norm_joint_last = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)


        self.cd_attention_mid = cd_attention_mid

        if self.cd_attention_mid:
            # NOTE(review): this message is printed once per constructed block.
            print("cross-domain attn in the middle")
            # Joint task -Attn
            self.attn_joint_mid = CustomJointAttention(
                query_dim=dim,
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                cross_attention_dim=cross_attention_dim if only_cross_attention else None,
                upcast_attention=upcast_attention,
                processor=JointAttnProcessor()
            )
            # Same zeroing of the output projection weight as for `attn_joint_last`.
            nn.init.zeros_(self.attn_joint_mid.to_out[0].weight.data)
            self.norm_joint_mid = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)

    def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
        """Store the chunk size and the dimension along which `forward` splits the feed-forward input."""
        # Sets chunk feed-forward
        self._chunk_size = chunk_size
        self._chunk_dim = dim

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        class_labels: Optional[torch.LongTensor] = None,
    ):
        """
        Run the block: self-attention, optional mid joint attention, optional cross-attention, feed-forward and
        optional last joint attention, each added residually to `hidden_states`.

        Args:
            hidden_states: input tokens; the same object shape is returned.
            attention_mask: must be `None` (checked with an `assert`).
            encoder_hidden_states: conditioning for `attn2`; also given to `attn1` when `only_cross_attention` is set.
            encoder_attention_mask: passed to `attn2` as its `attention_mask`.
            timestep: consumed by the adaptive norm layers when they are in use.
            cross_attention_kwargs: extra keyword arguments forwarded to both `attn1` and `attn2`.
            class_labels: consumed by `norm1` when it is an `AdaLayerNormZero`.

        Returns:
            The updated `hidden_states`.
        """
        assert attention_mask is None # not supported yet
        # Notice that normalization is always applied before the real computation in the following blocks.
        # 1. Self-Attention
        if self.use_ada_layer_norm:
            norm_hidden_states = self.norm1(hidden_states, timestep)
        elif self.use_ada_layer_norm_zero:
            # AdaLayerNormZero also returns the gates and the shift/scale used around the feed-forward below.
            norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
                hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
            )
        else:
            norm_hidden_states = self.norm1(hidden_states)

        cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}

        # NOTE(review): `MVAttnProcessor.__call__` as declared in this file accepts `num_views` and
        # `multiview_attention` but not `sparse_mv_attention` / `mvcd_attention` - confirm that the processor
        # installed on `attn1` at call time accepts all four keywords.
        attn_output = self.attn1(
            norm_hidden_states,
            encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
            attention_mask=attention_mask,
            num_views=self.num_views,
            multiview_attention=self.multiview_attention,
            sparse_mv_attention=self.sparse_mv_attention,
            mvcd_attention=self.mvcd_attention,
            **cross_attention_kwargs,
        )


        if self.use_ada_layer_norm_zero:
            attn_output = gate_msa.unsqueeze(1) * attn_output
        hidden_states = attn_output + hidden_states

        # joint attention twice
        if self.cd_attention_mid:
            norm_hidden_states = (
                self.norm_joint_mid(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_joint_mid(hidden_states)
            )
            hidden_states = self.attn_joint_mid(norm_hidden_states) + hidden_states

        # 2. Cross-Attention
        if self.attn2 is not None:
            norm_hidden_states = (
                self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
            )

            attn_output = self.attn2(
                norm_hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                **cross_attention_kwargs,
            )
            hidden_states = attn_output + hidden_states

        # 3. Feed-forward
        norm_hidden_states = self.norm3(hidden_states)

        if self.use_ada_layer_norm_zero:
            norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]

        if self._chunk_size is not None:
            # "feed_forward_chunk_size" can be used to save memory
            if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
                raise ValueError(
                    f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
                )

            # Run the feed-forward slice by slice and stitch the results back along the same dimension.
            num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
            ff_output = torch.cat(
                [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)],
                dim=self._chunk_dim,
            )
        else:
            ff_output = self.ff(norm_hidden_states)

        if self.use_ada_layer_norm_zero:
            ff_output = gate_mlp.unsqueeze(1) * ff_output

        hidden_states = ff_output + hidden_states

        if self.cd_attention_last:
            norm_hidden_states = (
                self.norm_joint_last(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_joint_last(hidden_states)
            )
            hidden_states = self.attn_joint_last(norm_hidden_states) + hidden_states

        return hidden_states
628
+
629
+
630
class CustomAttention(Attention):
    """`Attention` layer that switches between the multi-view attention processors defined in this file."""

    def set_use_memory_efficient_attention_xformers(
        self, use_memory_efficient_attention_xformers: bool, *args, **kwargs
    ):
        """
        Install the multi-view attention processor matching the requested backend.

        The flag used to be ignored, so the xformers processor was installed even when xformers attention was
        being *disabled*. It is now honoured.

        Args:
            use_memory_efficient_attention_xformers: `True` installs `XFormersMVAttnProcessor`, `False` installs
                `MVAttnProcessor` (the processor `BasicMVTransformerBlock` constructs this layer with).
            *args, **kwargs: accepted for signature compatibility with the parent method and ignored.

        Raises:
            ModuleNotFoundError: if xformers attention is requested but xformers is not installed.
        """
        if use_memory_efficient_attention_xformers:
            # Fail here rather than on the first forward pass of the xformers processor.
            if not is_xformers_available():
                raise ModuleNotFoundError(
                    "xformers is required to enable memory efficient multi-view attention. Install it with"
                    " `pip install xformers`.",
                    name="xformers",
                )
            processor = XFormersMVAttnProcessor()
        else:
            # NOTE(review): BasicMVTransformerBlock.forward passes `sparse_mv_attention` / `mvcd_attention` to
            # this layer; confirm MVAttnProcessor accepts them before relying on the non-xformers path.
            processor = MVAttnProcessor()
        self.set_processor(processor)
637
+
638
+
639
class CustomJointAttention(Attention):
    """`Attention` layer that switches between the joint (cross-domain) attention processors of this file."""

    def set_use_memory_efficient_attention_xformers(
        self, use_memory_efficient_attention_xformers: bool, *args, **kwargs
    ):
        """
        Install the joint attention processor matching the requested backend.

        The flag used to be ignored, so the xformers processor was installed even when xformers attention was
        being *disabled*. It is now honoured.

        Args:
            use_memory_efficient_attention_xformers: `True` installs `XFormersJointAttnProcessor`, `False`
                installs `JointAttnProcessor` (the processor `BasicMVTransformerBlock` constructs this layer with).
            *args, **kwargs: accepted for signature compatibility with the parent method and ignored.

        Raises:
            ModuleNotFoundError: if xformers attention is requested but xformers is not installed.
        """
        if use_memory_efficient_attention_xformers:
            # Fail here rather than on the first forward pass of the xformers processor.
            if not is_xformers_available():
                raise ModuleNotFoundError(
                    "xformers is required to enable memory efficient joint attention. Install it with"
                    " `pip install xformers`.",
                    name="xformers",
                )
            processor = XFormersJointAttnProcessor()
        else:
            processor = JointAttnProcessor()
        self.set_processor(processor)
646
+
647
class MVAttnProcessor:
    r"""
    Attention processor with optional multi-view key/value sharing (plain `torch.bmm` path).

    When `multiview_attention` is true and `num_views <= 6`, the keys and values of the
    `num_views` consecutive batch entries are merged along the token axis and repeated, so
    each entry attends to the tokens of every entry in its group. For larger `num_views`
    the keys and values are left untouched.
    """

    def __call__(
        self,
        attn: Attention,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        num_views=1,
        multiview_attention=True
    ):
        """Run attention on `hidden_states` and return a tensor with the input's layout.

        Args:
            attn: attention module providing the projections, norms and head reshaping helpers.
            hidden_states: 3D token tensor or 4D feature map (flattened internally).
            encoder_hidden_states: optional source for keys/values; defaults to `hidden_states`.
            attention_mask: optional mask, passed through `attn.prepare_attention_mask`.
            temb: embedding forwarded to `attn.spatial_norm` when that norm is present.
            num_views: group size used as the `t` axis when merging keys/values.
            multiview_attention: enables the key/value merge described on the class.
        """
        shortcut = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim
        came_as_feature_map = input_ndim == 4
        if came_as_feature_map:
            # flatten the two trailing spatial axes into a single token axis
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        shape_source = hidden_states if encoder_hidden_states is None else encoder_hidden_states
        batch_size, sequence_length, _ = shape_source.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        # multi-view self-attention: merge the views of a group along the token axis and give
        # every view its own copy; groups of more than 6 views are deliberately left as-is
        if multiview_attention and num_views <= 6:
            merged_key = rearrange(key, "(b t) d c -> b (t d) c", t=num_views)
            merged_value = rearrange(value, "(b t) d c -> b (t d) c", t=num_views)
            key = merged_key.repeat_interleave(num_views, dim=0)
            value = merged_value.repeat_interleave(num_views, dim=0)

        query = attn.head_to_batch_dim(query).contiguous()
        key = attn.head_to_batch_dim(key).contiguous()
        value = attn.head_to_batch_dim(value).contiguous()

        weights = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = attn.batch_to_head_dim(torch.bmm(weights, value))

        # output projection followed by dropout
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)

        if came_as_feature_map:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + shortcut

        return hidden_states / attn.rescale_output_factor
753
+
754
+
755
class XFormersMVAttnProcessor:
    r"""
    Multi-view attention processor backed by `xformers.ops.memory_efficient_attention`.

    Key/value construction depends on the flags passed to `__call__`:

    * `multiview_attention` and not `sparse_mv_attention`: keys/values of the `num_views`
      consecutive batch entries are merged along the token axis and handed to `my_repeat`.
    * `multiview_attention` and `sparse_mv_attention`: each entry sees the keys/values of the
      first view of its group concatenated with its own.
    * `mvcd_attention` (only inside the multi-view branch): the raw keys/values with their two
      batch halves swapped are appended along the token axis.
    * otherwise the raw keys/values are used unchanged.
    """

    def __call__(
        self,
        attn: Attention,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        num_views=1,
        multiview_attention=True,
        sparse_mv_attention=False,
        mvcd_attention=False,
    ):
        """Run attention on `hidden_states` and return a tensor with the input's layout.

        Args:
            attn: attention module providing the projections, norms and head reshaping helpers.
            hidden_states: 3D token tensor or 4D feature map (flattened internally).
            encoder_hidden_states: optional source for keys/values; defaults to `hidden_states`.
            attention_mask: optional mask; expanded over the query axis before use as a bias.
            temb: embedding forwarded to `attn.spatial_norm` when that norm is present.
            num_views: group size; used as the einops `t` axis length and passed to `my_repeat`.
            multiview_attention: enables sharing keys/values between the views of a group.
            sparse_mv_attention: share only the first view instead of all views.
            mvcd_attention: additionally attend to the other half of the batch.
        """
        # The previous default was the float ``1.``; the value is used as an axis length and
        # a repeat count, so normalise it to an int once up front.
        num_views = int(num_views)

        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attention_mask is not None:
            # expand the mask's singleton query_tokens dimension:
            #   [batch*heads, 1, key_tokens] -> [batch*heads, query_tokens, key_tokens]
            # so it can be added as a bias onto the attention scores; xformers does not
            # broadcast the singleton dimension for us.
            _, query_tokens, _ = hidden_states.shape
            attention_mask = attention_mask.expand(-1, query_tokens, -1)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key_raw = attn.to_k(encoder_hidden_states)
        value_raw = attn.to_v(encoder_hidden_states)

        if multiview_attention:
            # NOTE(review): `my_repeat` is defined elsewhere in this file; presumably it gives
            # every view of a group its own copy of the per-group tensor - confirm.
            if not sparse_mv_attention:
                # dense: every view attends to the tokens of all views in its group
                key = my_repeat(rearrange(key_raw, "(b t) d c -> b (t d) c", t=num_views), num_views)
                value = my_repeat(rearrange(value_raw, "(b t) d c -> b (t d) c", t=num_views), num_views)
            else:
                # sparse: every view attends to the first view of its group plus itself
                key_front = my_repeat(
                    rearrange(key_raw, "(b t) d c -> b t d c", t=num_views)[:, 0, :, :], num_views
                )
                value_front = my_repeat(
                    rearrange(value_raw, "(b t) d c -> b t d c", t=num_views)[:, 0, :, :], num_views
                )
                key = torch.cat([key_front, key_raw], dim=1)
                value = torch.cat([value_front, value_raw], dim=1)

            if mvcd_attention:
                # cross-domain attention: swap the two batch halves of the raw keys/values and
                # append them along the token axis
                key_0, key_1 = torch.chunk(key_raw, dim=0, chunks=2)
                value_0, value_1 = torch.chunk(value_raw, dim=0, chunks=2)
                key_cross = torch.cat([key_1, key_0], dim=0)
                value_cross = torch.cat([value_1, value_0], dim=0)
                key = torch.cat([key, key_cross], dim=1)
                value = torch.cat([value, value_cross], dim=1)
        else:
            key = key_raw
            value = value_raw

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
860
+
861
+
862
+
863
class XFormersJointAttnProcessor:
    r"""
    Joint two-task attention processor backed by `xformers.ops.memory_efficient_attention`.

    The keys and values are split into two halves along the batch axis, the halves are joined
    along the token axis, and the result is duplicated along the batch axis so that both
    halves of the query batch attend to the tokens of both halves.
    """

    @staticmethod
    def _join_task_halves(tensor):
        """Join the two batch halves of `tensor` along dim 1, then duplicate along dim 0."""
        first_half, second_half = tensor.chunk(2, dim=0)
        joined = torch.cat((first_half, second_half), dim=1)
        return torch.cat((joined, joined), dim=0)

    def __call__(
        self,
        attn: Attention,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        num_tasks=2
    ):
        """Run joint attention on `hidden_states` and return a tensor with the input's layout.

        Args:
            attn: attention module providing the projections, norms and head reshaping helpers.
            hidden_states: 3D token tensor or 4D feature map (flattened internally).
            encoder_hidden_states: optional source for keys/values; defaults to `hidden_states`.
            attention_mask: optional mask; expanded over the query axis before use as a bias.
            temb: embedding forwarded to `attn.spatial_norm` when that norm is present.
            num_tasks: number of tasks stacked along the batch axis; must be 2.
        """
        shortcut = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim
        came_as_feature_map = input_ndim == 4
        if came_as_feature_map:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        shape_source = hidden_states if encoder_hidden_states is None else encoder_hidden_states
        batch_size, sequence_length, _ = shape_source.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attention_mask is not None:
            # xformers does not broadcast a singleton query axis, so expand it explicitly:
            #   [batch*heads, 1, key_tokens] -> [batch*heads, query_tokens, key_tokens]
            query_tokens = hidden_states.shape[1]
            attention_mask = attention_mask.expand(-1, query_tokens, -1)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        assert num_tasks == 2  # only support two tasks now

        key = self._join_task_halves(key)
        value = self._join_task_halves(value)

        query = attn.head_to_batch_dim(query).contiguous()
        key = attn.head_to_batch_dim(key).contiguous()
        value = attn.head_to_batch_dim(value).contiguous()

        attended = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
        hidden_states = attn.batch_to_head_dim(attended)

        # output projection followed by dropout
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)

        if came_as_feature_map:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + shortcut

        return hidden_states / attn.rescale_output_factor
949
+
950
+
951
class JointAttnProcessor:
    r"""
    Joint two-task attention processor (plain `torch.bmm` path).

    The keys and values are split into two halves along the batch axis, the halves are joined
    along the token axis, and the result is duplicated along the batch axis so that both
    halves of the query batch attend to the tokens of both halves.
    """

    def __call__(
        self,
        attn: Attention,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        num_tasks=2
    ):
        """Run joint attention on `hidden_states` and return a tensor with the input's layout.

        Args:
            attn: attention module providing the projections, norms and head reshaping helpers.
            hidden_states: 3D token tensor or 4D feature map (flattened internally).
            encoder_hidden_states: optional source for keys/values; defaults to `hidden_states`.
            attention_mask: optional mask, passed through `attn.prepare_attention_mask`.
            temb: embedding forwarded to `attn.spatial_norm` when that norm is present.
            num_tasks: number of tasks stacked along the batch axis; must be 2.
        """
        skip = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim
        is_4d = input_ndim == 4
        if is_4d:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        if encoder_hidden_states is None:
            batch_size, sequence_length, _ = hidden_states.shape
        else:
            batch_size, sequence_length, _ = encoder_hidden_states.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        assert num_tasks == 2  # only support two tasks now

        # join the two task halves along the token axis ...
        key_a, key_b = key.chunk(2, dim=0)
        value_a, value_b = value.chunk(2, dim=0)
        key_pair = torch.cat((key_a, key_b), dim=1)
        value_pair = torch.cat((value_a, value_b), dim=1)
        # ... and hand the same joined tokens to both halves of the query batch
        key = torch.cat((key_pair, key_pair), dim=0)
        value = torch.cat((value_pair, value_pair), dim=0)

        query = attn.head_to_batch_dim(query).contiguous()
        key = attn.head_to_batch_dim(key).contiguous()
        value = attn.head_to_batch_dim(value).contiguous()

        weights = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = attn.batch_to_head_dim(torch.bmm(weights, value))

        # output projection followed by dropout
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)

        if is_4d:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + skip

        return hidden_states / attn.rescale_output_factor
1028
+
1029
+
mvdiffusion/models/transformer_mv2d_rowwise.py ADDED
@@ -0,0 +1,978 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import Any, Dict, Optional
16
+
17
+ import torch
18
+ import torch.nn.functional as F
19
+ from torch import nn
20
+
21
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
22
+ from diffusers.models.embeddings import ImagePositionalEmbeddings
23
+ from diffusers.utils import BaseOutput, deprecate
24
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
25
+ from diffusers.models.attention import FeedForward, AdaLayerNorm, AdaLayerNormZero, Attention
26
+ from diffusers.models.embeddings import PatchEmbed
27
+ from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
28
+ from diffusers.models.modeling_utils import ModelMixin
29
+ from diffusers.utils.import_utils import is_xformers_available
30
+
31
+ from einops import rearrange
32
+ import pdb
33
+ import random
34
+ import math
35
+
36
+
37
+ if is_xformers_available():
38
+ import xformers
39
+ import xformers.ops
40
+ else:
41
+ xformers = None
42
+
43
+
44
@dataclass
class TransformerMV2DModelOutput(BaseOutput):
    """
    The output of [`TransformerMV2DModel`].

    Args:
        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`TransformerMV2DModel`] is discrete):
            The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns
            log-probabilities (`log_softmax` of the logits) for the unnoised latent pixels.
    """

    # the tensor produced by `TransformerMV2DModel.forward`
    sample: torch.FloatTensor
56
+
57
+
58
class TransformerMV2DModel(ModelMixin, ConfigMixin):
    """
    A 2D Transformer for image-like data, built from `BasicMVTransformerBlock` layers.

    Parameters:
        num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
        attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
        in_channels (`int`, *optional*):
            The number of channels in the input (specify if the input is **continuous** or **patched**).
        out_channels (`int`, *optional*):
            The number of output channels; falls back to `in_channels`. Only the patched output
            projection reads it; the continuous path projects back to `in_channels`.
        num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        norm_num_groups (`int`, *optional*, defaults to 32): Groups of the input `GroupNorm` (continuous input).
        cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
        attention_bias (`bool`, *optional*):
            Configure if the `TransformerBlocks` attention should contain a bias parameter.
        sample_size (`int`, *optional*): The width of the latent images (required for **discrete** and
            **patched** input).
        num_vector_embeds (`int`, *optional*):
            The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
        patch_size (`int`, *optional*): Patch size (specify together with `in_channels` for **patched** input).
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
        num_embeds_ada_norm ( `int`, *optional*):
            Forwarded to every block. When given together with `norm_type="layer_norm"`, a
            deprecation warning is emitted and `norm_type` is switched to `"ada_norm"`.
        use_linear_projection (`bool`): Use linear instead of 1x1-conv input/output projections (continuous input).
        only_cross_attention, upcast_attention, norm_type, norm_elementwise_affine:
            Forwarded unchanged to every `BasicMVTransformerBlock`.
        num_views, cd_attention_last, cd_attention_mid, multiview_attention, mvcd_attention:
            Multi-view / cross-domain options, forwarded unchanged to every `BasicMVTransformerBlock`.
        sparse_mv_attention (`bool`): Accepted and stored in the config but not forwarded to the blocks.
    """

    @register_to_config
    def __init__(
        self,
        num_attention_heads: int = 16,
        attention_head_dim: int = 88,
        in_channels: Optional[int] = None,
        out_channels: Optional[int] = None,
        num_layers: int = 1,
        dropout: float = 0.0,
        norm_num_groups: int = 32,
        cross_attention_dim: Optional[int] = None,
        attention_bias: bool = False,
        sample_size: Optional[int] = None,
        num_vector_embeds: Optional[int] = None,
        patch_size: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        use_linear_projection: bool = False,
        only_cross_attention: bool = False,
        upcast_attention: bool = False,
        norm_type: str = "layer_norm",
        norm_elementwise_affine: bool = True,
        num_views: int = 1,
        cd_attention_last: bool=False,
        cd_attention_mid: bool=False,
        multiview_attention: bool=True,
        sparse_mv_attention: bool = True, # not used
        mvcd_attention: bool=False
    ):
        super().__init__()
        self.use_linear_projection = use_linear_projection
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        inner_dim = num_attention_heads * attention_head_dim

        # 1. Work out which of the three input kinds the configuration describes:
        #    continuous feature maps, quantized (vectorized) embeddings, or patches.
        has_channels = in_channels is not None
        has_patch_size = patch_size is not None
        self.is_input_continuous = has_channels and not has_patch_size
        self.is_input_vectorized = num_vector_embeds is not None
        self.is_input_patches = has_channels and has_patch_size

        if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
            deprecation_message = (
                f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
                " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config."
                " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect"
                " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
                " would be very nice if you could open a Pull request for the `transformer/config.json` file"
            )
            deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False)
            norm_type = "ada_norm"

        # reject ambiguous or empty input configurations
        if self.is_input_continuous and self.is_input_vectorized:
            raise ValueError(
                f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
                " sure that either `in_channels` or `num_vector_embeds` is None."
            )
        if self.is_input_vectorized and self.is_input_patches:
            raise ValueError(
                f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
                " sure that either `num_vector_embeds` or `num_patches` is None."
            )
        if not (self.is_input_continuous or self.is_input_vectorized or self.is_input_patches):
            raise ValueError(
                f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:"
                f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
            )

        # 2. Input layers
        if self.is_input_continuous:
            self.in_channels = in_channels

            self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
            if use_linear_projection:
                self.proj_in = LoRACompatibleLinear(in_channels, inner_dim)
            else:
                self.proj_in = LoRACompatibleConv(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
        elif self.is_input_vectorized:
            assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
            assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"

            self.height = sample_size
            self.width = sample_size
            self.num_vector_embeds = num_vector_embeds
            self.num_latent_pixels = self.height * self.width

            self.latent_image_embedding = ImagePositionalEmbeddings(
                num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
            )
        elif self.is_input_patches:
            assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size"

            self.height = sample_size
            self.width = sample_size

            self.patch_size = patch_size
            self.pos_embed = PatchEmbed(
                height=sample_size,
                width=sample_size,
                patch_size=patch_size,
                in_channels=in_channels,
                embed_dim=inner_dim,
            )

        # 3. Transformer blocks (all layers share the same configuration)
        block_options = dict(
            dropout=dropout,
            cross_attention_dim=cross_attention_dim,
            activation_fn=activation_fn,
            num_embeds_ada_norm=num_embeds_ada_norm,
            attention_bias=attention_bias,
            only_cross_attention=only_cross_attention,
            upcast_attention=upcast_attention,
            norm_type=norm_type,
            norm_elementwise_affine=norm_elementwise_affine,
            num_views=num_views,
            cd_attention_last=cd_attention_last,
            cd_attention_mid=cd_attention_mid,
            multiview_attention=multiview_attention,
            mvcd_attention=mvcd_attention,
        )
        self.transformer_blocks = nn.ModuleList(
            [
                BasicMVTransformerBlock(inner_dim, num_attention_heads, attention_head_dim, **block_options)
                for _ in range(num_layers)
            ]
        )

        # 4. Output layers
        self.out_channels = in_channels if out_channels is None else out_channels
        if self.is_input_continuous:
            # TODO: should use out_channels for continuous projections
            if use_linear_projection:
                self.proj_out = LoRACompatibleLinear(inner_dim, in_channels)
            else:
                self.proj_out = LoRACompatibleConv(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
        elif self.is_input_vectorized:
            self.norm_out = nn.LayerNorm(inner_dim)
            self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
        elif self.is_input_patches:
            self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
            self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
            self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        class_labels: Optional[torch.LongTensor] = None,
        cross_attention_kwargs: Dict[str, Any] = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ):
        """
        The [`TransformerMV2DModel`] forward method.

        Args:
            hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
                Input `hidden_states`.
            encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
                Conditional embeddings, forwarded to every block.
            timestep ( `torch.LongTensor`, *optional*):
                Forwarded to every block; also fed to the first block's `norm1.emb` for patched input.
            class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
                Forwarded to every block; also fed to the first block's `norm1.emb` for patched input.
            cross_attention_kwargs (`dict`, *optional*):
                Forwarded unchanged to every block.
            attention_mask ( `torch.Tensor`, *optional*):
                Same two formats as `encoder_attention_mask`; forwarded to every block.
            encoder_attention_mask ( `torch.Tensor`, *optional*):
                Mask applied to `encoder_hidden_states`. Two formats supported:

                    * Mask `(batch, sequence_length)` True = keep, False = discard.
                    * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.

                If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
                above.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return a [`TransformerMV2DModelOutput`] instead of a plain tuple.

        Returns:
            A [`TransformerMV2DModelOutput`] if `return_dict` is True, otherwise a `tuple` whose first element is
            the sample tensor.
        """

        def as_additive_bias(mask):
            # A 2D mask (1 = keep, 0 = discard) becomes a bias (keep = +0, discard = -10000.0)
            # with a singleton query_tokens axis: [batch, key_tokens] -> [batch, 1, key_tokens].
            # Anything else (None, or an already converted bias) is passed through untouched.
            if mask is None or mask.ndim != 2:
                return mask
            bias = (1 - mask.to(hidden_states.dtype)) * -10000.0
            return bias.unsqueeze(1)

        attention_mask = as_additive_bias(attention_mask)
        encoder_attention_mask = as_additive_bias(encoder_attention_mask)

        # 1. Input
        if self.is_input_continuous:
            batch, _, height, width = hidden_states.shape
            residual = hidden_states

            hidden_states = self.norm(hidden_states)
            if self.use_linear_projection:
                # flatten to tokens first, then project
                inner_dim = hidden_states.shape[1]
                hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
                hidden_states = self.proj_in(hidden_states)
            else:
                # project with the 1x1 conv first, then flatten to tokens
                hidden_states = self.proj_in(hidden_states)
                inner_dim = hidden_states.shape[1]
                hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
        elif self.is_input_vectorized:
            hidden_states = self.latent_image_embedding(hidden_states)
        elif self.is_input_patches:
            hidden_states = self.pos_embed(hidden_states)

        # 2. Blocks
        for layer in self.transformer_blocks:
            hidden_states = layer(
                hidden_states,
                attention_mask=attention_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                timestep=timestep,
                cross_attention_kwargs=cross_attention_kwargs,
                class_labels=class_labels,
            )

        # 3. Output
        if self.is_input_continuous:
            if self.use_linear_projection:
                hidden_states = self.proj_out(hidden_states)
                hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
            else:
                hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
                hidden_states = self.proj_out(hidden_states)

            output = hidden_states + residual
        elif self.is_input_vectorized:
            # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
            logits = self.out(self.norm_out(hidden_states)).permute(0, 2, 1)

            # log(p(x_0))
            output = F.log_softmax(logits.double(), dim=1).float()
        elif self.is_input_patches:
            # TODO: cleanup!
            conditioning = self.transformer_blocks[0].norm1.emb(
                timestep, class_labels, hidden_dtype=hidden_states.dtype
            )
            shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
            hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
            hidden_states = self.proj_out_2(hidden_states)

            # unpatchify
            height = width = int(hidden_states.shape[1] ** 0.5)
            patch = self.patch_size
            hidden_states = hidden_states.reshape(
                shape=(-1, height, width, patch, patch, self.out_channels)
            )
            hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
            output = hidden_states.reshape(
                shape=(-1, self.out_channels, height * patch, width * patch)
            )

        if return_dict:
            return TransformerMV2DModelOutput(sample=output)

        return (output,)
368
+
369
+
370
@maybe_allow_in_graph
class BasicMVTransformerBlock(nn.Module):
    r"""
    A basic Transformer block whose first attention layer is a multi-view attention (`CustomAttention` driven by
    `MVAttnProcessor`), optionally followed by joint (cross-domain) attention layers.

    The forward pass runs, in order: multi-view self-attention (`attn1`), optional joint attention
    (`attn_joint_twice`, when `cd_attention_mid`), optional cross-attention (`attn2`), feed-forward (`ff`) and
    optional joint attention (`attn_joint`, when `cd_attention_last`). Every sub-layer is pre-normalized and added
    back to its input as a residual.

    Parameters:
        dim (`int`): The number of channels in the input and output.
        num_attention_heads (`int`): The number of heads to use for multi-head attention.
        attention_head_dim (`int`): The number of channels in each head.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
        only_cross_attention (`bool`, *optional*):
            Whether to use only cross-attention layers. In this case two cross attention layers are used.
        double_self_attention (`bool`, *optional*):
            Whether to use two self-attention layers. In this case no cross attention layers are used.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
        num_embeds_ada_norm (`int`, *optional*):
            The number of diffusion steps used during training. See `Transformer2DModel`.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Configure if the attentions should contain a bias parameter.
        upcast_attention (`bool`, *optional*, defaults to `False`): Forwarded to every attention layer.
        norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
            `elementwise_affine` flag of the plain `nn.LayerNorm` layers.
        norm_type (`str`, *optional*, defaults to `"layer_norm"`):
            `"ada_norm"` / `"ada_norm_zero"` select the adaptive norms and require `num_embeds_ada_norm`.
        final_dropout (`bool`, *optional*, defaults to `False`): Forwarded to `FeedForward`.
        num_views (`int`, *optional*, defaults to 1): Forwarded to `attn1` on every call as `num_views`.
        cd_attention_last (`bool`, *optional*, defaults to `False`):
            Add a zero-initialized joint attention layer (`attn_joint`) after the feed-forward.
        cd_attention_mid (`bool`, *optional*, defaults to `False`):
            Add a zero-initialized joint attention layer (`attn_joint_twice`) right after `attn1`.
        multiview_attention (`bool`, *optional*, defaults to `True`): Forwarded to `attn1` on every call.
        mvcd_attention (`bool`, *optional*, defaults to `False`): Forwarded to `attn1` on every call.
        rowwise_attention (`bool`, *optional*, defaults to `True`):
            Stored as `multiview_attention and rowwise_attention`; not read anywhere else in this class.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        dropout: float = 0.0,
        cross_attention_dim: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        attention_bias: bool = False,
        only_cross_attention: bool = False,
        double_self_attention: bool = False,
        upcast_attention: bool = False,
        norm_elementwise_affine: bool = True,
        norm_type: str = "layer_norm",
        final_dropout: bool = False,
        num_views: int = 1,
        cd_attention_last: bool = False,
        cd_attention_mid: bool = False,
        multiview_attention: bool = True,
        mvcd_attention: bool = False,
        rowwise_attention: bool = True
    ):
        super().__init__()
        self.only_cross_attention = only_cross_attention

        # Adaptive norms are only active when the number of embeddings is known.
        self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
        self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"

        if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
            raise ValueError(
                f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
                f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
            )

        # Define 3 blocks. Each block has its own normalization layer.
        # 1. Self-Attn
        if self.use_ada_layer_norm:
            self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
        elif self.use_ada_layer_norm_zero:
            self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
        else:
            self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)

        self.multiview_attention = multiview_attention
        self.mvcd_attention = mvcd_attention
        self.rowwise_attention = multiview_attention and rowwise_attention

        # rowwise multiview attention
        # NOTE(review): this message is printed unconditionally, once per constructed block,
        # regardless of the value of `rowwise_attention`.
        print('INFO: using row wise attention...')

        self.attn1 = CustomAttention(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            dropout=dropout,
            bias=attention_bias,
            cross_attention_dim=cross_attention_dim if only_cross_attention else None,
            upcast_attention=upcast_attention,
            processor=MVAttnProcessor()
        )

        # 2. Cross-Attn
        if cross_attention_dim is not None or double_self_attention:
            # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
            # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
            # the second cross attention block.
            self.norm2 = (
                AdaLayerNorm(dim, num_embeds_ada_norm)
                if self.use_ada_layer_norm
                else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
            )
            self.attn2 = Attention(
                query_dim=dim,
                cross_attention_dim=cross_attention_dim if not double_self_attention else None,
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                upcast_attention=upcast_attention,
            )  # is self-attn if encoder_hidden_states is none
        else:
            self.norm2 = None
            self.attn2 = None

        # 3. Feed-forward
        self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
        self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)

        # let chunk size default to None
        self._chunk_size = None
        self._chunk_dim = 0

        self.num_views = num_views

        self.cd_attention_last = cd_attention_last

        if self.cd_attention_last:
            # Joint task -Attn
            self.attn_joint = CustomJointAttention(
                query_dim=dim,
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                cross_attention_dim=cross_attention_dim if only_cross_attention else None,
                upcast_attention=upcast_attention,
                processor=JointAttnProcessor()
            )
            # Zero the output projection weight of the joint layer (its bias, if any, is left untouched).
            nn.init.zeros_(self.attn_joint.to_out[0].weight.data)
            self.norm_joint = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)


        self.cd_attention_mid = cd_attention_mid

        if self.cd_attention_mid:
            print("joint twice")
            # Joint task -Attn
            self.attn_joint_twice = CustomJointAttention(
                query_dim=dim,
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                cross_attention_dim=cross_attention_dim if only_cross_attention else None,
                upcast_attention=upcast_attention,
                processor=JointAttnProcessor()
            )
            # Same zero initialization of the output projection weight as for `attn_joint`.
            nn.init.zeros_(self.attn_joint_twice.to_out[0].weight.data)
            self.norm_joint_twice = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)

    def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
        """Enable chunked feed-forward: split the FF input into chunks of `chunk_size` along dimension `dim`.

        Passing `chunk_size=None` disables chunking.
        """
        # Sets chunk feed-forward
        self._chunk_size = chunk_size
        self._chunk_dim = dim

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        class_labels: Optional[torch.LongTensor] = None,
    ):
        """Run the block on `hidden_states` and return the updated hidden states.

        Args:
            hidden_states: Input tokens; `attn1` receives them after `norm1`.
            attention_mask: Must be `None`; a self-attention mask is not supported (asserted below).
            encoder_hidden_states: Conditioning passed to `attn2`, and to `attn1` only when
                `only_cross_attention` is set.
            encoder_attention_mask: Mask passed to `attn2` as its `attention_mask`.
            timestep: Consumed by the adaptive norm layers when they are in use.
            cross_attention_kwargs: Extra keyword arguments forwarded to both `attn1` and `attn2`.
            class_labels: Consumed by `AdaLayerNormZero` when `norm_type == "ada_norm_zero"`.

        Raises:
            AssertionError: If `attention_mask` is not `None`.
            ValueError: If chunked feed-forward is enabled and the chunked dimension is not divisible by the
                chunk size.
        """
        assert attention_mask is None # not supported yet
        # Notice that normalization is always applied before the real computation in the following blocks.
        # 1. Self-Attention
        if self.use_ada_layer_norm:
            norm_hidden_states = self.norm1(hidden_states, timestep)
        elif self.use_ada_layer_norm_zero:
            # AdaLayerNormZero also returns the gates/shift/scale used further down in this method.
            norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
                hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
            )
        else:
            norm_hidden_states = self.norm1(hidden_states)

        cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}

        # The multi-view settings are handed to the attention processor as keyword arguments.
        attn_output = self.attn1(
            norm_hidden_states,
            encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
            attention_mask=attention_mask,
            multiview_attention=self.multiview_attention,
            mvcd_attention=self.mvcd_attention,
            num_views=self.num_views,
            **cross_attention_kwargs,
        )

        if self.use_ada_layer_norm_zero:
            attn_output = gate_msa.unsqueeze(1) * attn_output
        hidden_states = attn_output + hidden_states

        # joint attention twice
        if self.cd_attention_mid:
            norm_hidden_states = (
                self.norm_joint_twice(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_joint_twice(hidden_states)
            )
            hidden_states = self.attn_joint_twice(norm_hidden_states) + hidden_states

        # 2. Cross-Attention
        if self.attn2 is not None:
            norm_hidden_states = (
                self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
            )

            attn_output = self.attn2(
                norm_hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                **cross_attention_kwargs,
            )
            hidden_states = attn_output + hidden_states

        # 3. Feed-forward
        norm_hidden_states = self.norm3(hidden_states)

        if self.use_ada_layer_norm_zero:
            norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]

        if self._chunk_size is not None:
            # "feed_forward_chunk_size" can be used to save memory
            if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
                raise ValueError(
                    f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
                )

            num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
            ff_output = torch.cat(
                [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)],
                dim=self._chunk_dim,
            )
        else:
            ff_output = self.ff(norm_hidden_states)

        if self.use_ada_layer_norm_zero:
            ff_output = gate_mlp.unsqueeze(1) * ff_output

        hidden_states = ff_output + hidden_states

        # Optional joint attention after the feed-forward.
        if self.cd_attention_last:
            norm_hidden_states = (
                self.norm_joint(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_joint(hidden_states)
            )
            hidden_states = self.attn_joint(norm_hidden_states) + hidden_states

        return hidden_states
621
+
622
+
623
class CustomAttention(Attention):
    """`Attention` variant whose xformers toggle switches between the multi-view processors."""

    def set_use_memory_efficient_attention_xformers(
        self, use_memory_efficient_attention_xformers: bool, *args, **kwargs
    ):
        """Install the multi-view attention processor matching the requested backend.

        Args:
            use_memory_efficient_attention_xformers: `True` installs `XFormersMVAttnProcessor`;
                `False` restores the plain `MVAttnProcessor`. Previously the flag was ignored and the
                xformers processor was installed even when xformers was being disabled.
            *args, **kwargs: Accepted for signature compatibility with the parent class and ignored.
        """
        if use_memory_efficient_attention_xformers:
            processor = XFormersMVAttnProcessor()
        else:
            processor = MVAttnProcessor()
        self.set_processor(processor)
630
+
631
+
632
class CustomJointAttention(Attention):
    """`Attention` variant whose xformers toggle switches between the joint-attention processors."""

    def set_use_memory_efficient_attention_xformers(
        self, use_memory_efficient_attention_xformers: bool, *args, **kwargs
    ):
        """Install the joint attention processor matching the requested backend.

        Args:
            use_memory_efficient_attention_xformers: `True` installs `XFormersJointAttnProcessor`;
                `False` restores the plain `JointAttnProcessor`. Previously the flag was ignored and the
                xformers processor was installed even when xformers was being disabled.
            *args, **kwargs: Accepted for signature compatibility with the parent class and ignored.
        """
        if use_memory_efficient_attention_xformers:
            processor = XFormersJointAttnProcessor()
        else:
            processor = JointAttnProcessor()
        self.set_processor(processor)
639
+
640
class MVAttnProcessor:
    r"""
    Row-wise multi-view attention processor using the plain (non-xformers) attention path.

    Queries, keys and values are regrouped from `(b v) (h w) c` to `(b h) (v w) c`, so that the tokens of one
    image row attend to the tokens of the same row in all `num_views` views. The result is regrouped back to
    `(b v) (h w) c` before it is returned.
    """

    def __call__(
        self,
        attn: Attention,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        num_views=1,
        multiview_attention=True,
        mvcd_attention=False,
    ):
        """Compute row-wise multi-view attention.

        Args:
            attn: The attention module providing the projections and helper methods.
            hidden_states: Input of rank 3 (`batch, tokens, channels`) or rank 4 (`batch, channels, height, width`).
            encoder_hidden_states: Optional source for keys/values; defaults to `hidden_states`.
            attention_mask: Optional mask, passed through `attn.prepare_attention_mask`.
            temb: Only used by `attn.spatial_norm` when that layer exists.
            num_views: Number of views stacked along the batch dimension.
            multiview_attention: Accepted for interface compatibility; not read by this processor.
            mvcd_attention: Accepted because `BasicMVTransformerBlock.forward` always passes it (the signature
                previously lacked it, which raises `TypeError` when keyword arguments are forwarded verbatim).
                Not read by this processor; the cross-domain branch exists only in `XFormersMVAttnProcessor`.

        Returns:
            The attended hidden states, in the same layout as the input.
        """
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            # Flatten the spatial grid into a token axis: (B, C, H, W) -> (B, H*W, C).
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        # The token grid is assumed to be square: the row count is the integer square root of the token count.
        height = int(math.sqrt(sequence_length))
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        # multi-view self-attention: group each image row together with the same row of every view
        key = rearrange(key, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height)
        value = rearrange(value, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height)
        query = rearrange(query, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height)

        query = attn.head_to_batch_dim(query).contiguous()
        key = attn.head_to_batch_dim(key).contiguous()
        value = attn.head_to_batch_dim(value).contiguous()

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)
        # Undo the row-wise grouping.
        hidden_states = rearrange(hidden_states, "(b h) (v w) c -> (b v) (h w) c", v=num_views, h=height)
        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
715
+
716
+
717
class XFormersMVAttnProcessor:
    r"""
    Row-wise multi-view attention processor backed by `xformers.ops.memory_efficient_attention`.

    Queries, keys and values are regrouped from `(b v) (h w) c` to `(b h) (v w) c`, so that the tokens of one
    image row attend to the tokens of the same row in all `num_views` views. With `mvcd_attention` the keys and
    values of the other half of the batch are appended along the token axis as well.
    """

    def __call__(
        self,
        attn: Attention,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        num_views=1,
        multiview_attention=True,
        mvcd_attention=False,
    ):
        """Compute row-wise multi-view attention with xformers.

        Args:
            attn: The attention module providing the projections and helper methods.
            hidden_states: Input of rank 3 (`batch, tokens, channels`) or rank 4 (`batch, channels, height, width`).
            encoder_hidden_states: Optional source for keys/values; defaults to `hidden_states`.
            attention_mask: Optional mask; expanded over the query tokens and passed to xformers as `attn_bias`.
            temb: Only used by `attn.spatial_norm` when that layer exists.
            num_views: Number of views stacked along the batch dimension.
            multiview_attention: Accepted for interface compatibility; not read by this processor.
            mvcd_attention: If `True`, the batch is split in two halves along dim 0 and every half additionally
                attends to the keys/values of the other half.

        Returns:
            The attended hidden states, in the same layout as the input.
        """
        row_layout = "(b v) (h w) c -> (b h) (v w) c"
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            # Flatten the spatial grid into a token axis: (B, C, H, W) -> (B, H*W, C).
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        # The token grid is assumed to be square: the row count is the integer square root of the token count.
        height = int(math.sqrt(sequence_length))
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
        # from yuancheng; here attention_mask is None
        if attention_mask is not None:
            # expand our mask's singleton query_tokens dimension:
            #   [batch*heads,            1, key_tokens] ->
            #   [batch*heads, query_tokens, key_tokens]
            # so that it can be added as a bias onto the attention scores that xformers computes:
            #   [batch*heads, query_tokens, key_tokens]
            # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
            _, query_tokens, _ = hidden_states.shape
            attention_mask = attention_mask.expand(-1, query_tokens, -1)

        if attn.group_norm is not None:
            print('Warning: using group norm, pay attention to use it in row-wise attention')
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key_raw = attn.to_k(encoder_hidden_states)
        value_raw = attn.to_v(encoder_hidden_states)

        # Group each image row together with the same row of every view.
        key = rearrange(key_raw, row_layout, v=num_views, h=height)
        value = rearrange(value_raw, row_layout, v=num_views, h=height)
        query = rearrange(query, row_layout, v=num_views, h=height)
        if mvcd_attention:
            # memory efficient, cross domain attention
            # Swap the two batch halves so every sample is paired with its counterpart from the other half.
            key_0, key_1 = torch.chunk(key_raw, dim=0, chunks=2)
            value_0, value_1 = torch.chunk(value_raw, dim=0, chunks=2)
            key_cross = torch.concat([key_1, key_0], dim=0)
            value_cross = torch.concat([value_1, value_0], dim=0)
            # The swapped tensors are still in `(b v) (h w) c` layout. They must be brought into the same
            # row-wise layout as `key`/`value` before concatenating along the token axis; otherwise the batch
            # dimensions disagree (b*v vs. b*h) and rows would be matched against unrelated tokens.
            key_cross = rearrange(key_cross, row_layout, v=num_views, h=height)
            value_cross = rearrange(value_cross, row_layout, v=num_views, h=height)
            key = torch.cat([key, key_cross], dim=1)  # (b h) (2 v w) c
            value = torch.cat([value, value_cross], dim=1)  # (b h) (2 v w) c

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)
        # Undo the row-wise grouping.
        hidden_states = rearrange(hidden_states, "(b h) (v w) c -> (b v) (h w) c", v=num_views, h=height)
        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
812
+
813
+
814
class XFormersJointAttnProcessor:
    r"""
    Joint two-task attention processor backed by `xformers.ops.memory_efficient_attention`.

    The batch is split into two halves along dim 0. The keys (and values) of both halves are concatenated along
    the token axis and that joint tensor is used for both halves, so every query sees the tokens of both tasks.
    """

    def __call__(
        self,
        attn: Attention,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        num_tasks=2
    ):
        skip_input = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim
        is_feature_map = input_ndim == 4
        if is_feature_map:
            # (B, C, H, W) -> (B, H*W, C)
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        # Batch size and sequence length come from whichever tensor provides the keys/values.
        shape_source = hidden_states if encoder_hidden_states is None else encoder_hidden_states
        batch_size, sequence_length, _ = shape_source.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attention_mask is not None:
            # xformers does not broadcast a singleton query axis, so the mask is expanded explicitly from
            # [batch*heads, 1, key_tokens] to [batch*heads, query_tokens, key_tokens].
            _, query_tokens, _ = hidden_states.shape
            attention_mask = attention_mask.expand(-1, query_tokens, -1)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            context = hidden_states
        elif attn.norm_cross:
            context = attn.norm_encoder_hidden_states(encoder_hidden_states)
        else:
            context = encoder_hidden_states

        key = attn.to_k(context)
        value = attn.to_v(context)

        assert num_tasks == 2  # only support two tasks now

        # Merge the two task halves along the token axis, then repeat the result for both halves.
        key_first, key_second = torch.chunk(key, dim=0, chunks=2)
        value_first, value_second = torch.chunk(value, dim=0, chunks=2)
        joint_key = torch.cat((key_first, key_second), dim=1)
        joint_value = torch.cat((value_first, value_second), dim=1)
        key = torch.cat((joint_key, joint_key), dim=0)
        value = torch.cat((joint_value, joint_value), dim=0)

        query = attn.head_to_batch_dim(query).contiguous()
        key = attn.head_to_batch_dim(key).contiguous()
        value = attn.head_to_batch_dim(value).contiguous()

        attended = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
        attended = attn.batch_to_head_dim(attended)

        # Output projection followed by dropout.
        attended = attn.to_out[0](attended)
        attended = attn.to_out[1](attended)

        if is_feature_map:
            attended = attended.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            attended = attended + skip_input

        return attended / attn.rescale_output_factor
900
+
901
+
902
class JointAttnProcessor:
    r"""
    Joint two-task attention processor using the plain (non-xformers) attention path.

    The batch is split into two halves along dim 0. The keys (and values) of both halves are concatenated along
    the token axis and that joint tensor is used for both halves, so every query sees the tokens of both tasks.
    """

    def __call__(
        self,
        attn: Attention,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        num_tasks=2
    ):
        skip_input = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim
        is_feature_map = input_ndim == 4
        if is_feature_map:
            # (B, C, H, W) -> (B, H*W, C)
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        # Batch size and sequence length come from whichever tensor provides the keys/values.
        shape_source = hidden_states if encoder_hidden_states is None else encoder_hidden_states
        batch_size, sequence_length, _ = shape_source.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            context = hidden_states
        elif attn.norm_cross:
            context = attn.norm_encoder_hidden_states(encoder_hidden_states)
        else:
            context = encoder_hidden_states

        key = attn.to_k(context)
        value = attn.to_v(context)

        assert num_tasks == 2  # only support two tasks now

        # Merge the two task halves along the token axis, then repeat the result for both halves.
        key_first, key_second = torch.chunk(key, dim=0, chunks=2)
        value_first, value_second = torch.chunk(value, dim=0, chunks=2)
        joint_key = torch.cat((key_first, key_second), dim=1)
        joint_value = torch.cat((value_first, value_second), dim=1)
        key = torch.cat((joint_key, joint_key), dim=0)
        value = torch.cat((joint_value, joint_value), dim=0)

        query = attn.head_to_batch_dim(query).contiguous()
        key = attn.head_to_batch_dim(key).contiguous()
        value = attn.head_to_batch_dim(value).contiguous()

        scores = attn.get_attention_scores(query, key, attention_mask)
        attended = attn.batch_to_head_dim(torch.bmm(scores, value))

        # Output projection followed by dropout.
        attended = attn.to_out[0](attended)
        attended = attn.to_out[1](attended)

        if is_feature_map:
            attended = attended.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            attended = attended + skip_input

        return attended / attn.rescale_output_factor
mvdiffusion/models/transformer_mv2d_self_rowwise.py ADDED
@@ -0,0 +1,1038 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import Any, Dict, Optional
16
+
17
+ import torch
18
+ import torch.nn.functional as F
19
+ from torch import nn
20
+
21
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
22
+ from diffusers.models.embeddings import ImagePositionalEmbeddings
23
+ from diffusers.utils import BaseOutput, deprecate
24
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
25
+ from diffusers.models.attention import FeedForward, AdaLayerNorm, AdaLayerNormZero, Attention
26
+ from diffusers.models.embeddings import PatchEmbed
27
+ from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
28
+ from diffusers.models.modeling_utils import ModelMixin
29
+ from diffusers.utils.import_utils import is_xformers_available
30
+
31
+ from einops import rearrange
32
+ import pdb
33
+ import random
34
+ import math
35
+
36
+
37
+ if is_xformers_available():
38
+ import xformers
39
+ import xformers.ops
40
+ else:
41
+ xformers = None
42
+
43
+
44
@dataclass
class TransformerMV2DModelOutput(BaseOutput):
    """
    The output of [`TransformerMV2DModel`].

    Args:
        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`TransformerMV2DModel`] is discrete):
            The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
            distributions for the unnoised latent pixels.
    """

    # The single payload of the output container.
    sample: torch.FloatTensor
56
+
57
+
58
+ class TransformerMV2DModel(ModelMixin, ConfigMixin):
59
+ """
60
+ A 2D Transformer model for image-like data.
61
+
62
+ Parameters:
63
+ num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
64
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
65
+ in_channels (`int`, *optional*):
66
+ The number of channels in the input and output (specify if the input is **continuous**).
67
+ num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
68
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
69
+ cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
70
+ sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
71
+ This is fixed during training since it is used to learn a number of position embeddings.
72
+ num_vector_embeds (`int`, *optional*):
73
+ The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
74
+ Includes the class for the masked latent pixel.
75
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
76
+ num_embeds_ada_norm ( `int`, *optional*):
77
+ The number of diffusion steps used during training. Pass if at least one of the norm_layers is
78
+ `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
79
+ added to the hidden states.
80
+
81
+ During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`.
82
+ attention_bias (`bool`, *optional*):
83
+ Configure if the `TransformerBlocks` attention should contain a bias parameter.
84
+ """
85
+
86
@register_to_config
def __init__(
    self,
    num_attention_heads: int = 16,
    attention_head_dim: int = 88,
    in_channels: Optional[int] = None,
    out_channels: Optional[int] = None,
    num_layers: int = 1,
    dropout: float = 0.0,
    norm_num_groups: int = 32,
    cross_attention_dim: Optional[int] = None,
    attention_bias: bool = False,
    sample_size: Optional[int] = None,
    num_vector_embeds: Optional[int] = None,
    patch_size: Optional[int] = None,
    activation_fn: str = "geglu",
    num_embeds_ada_norm: Optional[int] = None,
    use_linear_projection: bool = False,
    only_cross_attention: bool = False,
    upcast_attention: bool = False,
    norm_type: str = "layer_norm",
    norm_elementwise_affine: bool = True,
    num_views: int = 1,
    cd_attention_mid: bool = False,
    cd_attention_last: bool = False,
    multiview_attention: bool = True,
    sparse_mv_attention: bool = True,  # not used
    mvcd_attention: bool = False,
    use_dino: bool = False,
):
    """Create the input projection, the stack of multi-view blocks and the output head.

    The input mode is derived from the arguments:

    * continuous  - `in_channels` is given and `patch_size` is not;
    * vectorized  - `num_vector_embeds` is given;
    * patches     - both `in_channels` and `patch_size` are given.

    Exactly one mode must be selected, otherwise a `ValueError` is raised.
    `sparse_mv_attention` is accepted but never read in this constructor.
    """
    super().__init__()
    self.use_linear_projection = use_linear_projection
    self.num_attention_heads = num_attention_heads
    self.attention_head_dim = attention_head_dim
    inner_dim = num_attention_heads * attention_head_dim

    # 1. Work out which kind of input this model was configured for.
    has_channels = in_channels is not None
    has_patch_size = patch_size is not None
    self.is_input_continuous = has_channels and not has_patch_size
    self.is_input_vectorized = num_vector_embeds is not None
    self.is_input_patches = has_channels and has_patch_size

    # Old configs leave `norm_type` at "layer_norm" while still providing
    # `num_embeds_ada_norm`; warn and switch to "ada_norm" for them.
    if num_embeds_ada_norm is not None and norm_type == "layer_norm":
        deprecation_message = (
            f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
            " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config."
            " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect"
            " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
            " would be very nice if you could open a Pull request for the `transformer/config.json` file"
        )
        deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False)
        norm_type = "ada_norm"

    # Reject ambiguous or empty input configurations.
    if self.is_input_continuous and self.is_input_vectorized:
        raise ValueError(
            f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
            " sure that either `in_channels` or `num_vector_embeds` is None."
        )
    if self.is_input_vectorized and self.is_input_patches:
        raise ValueError(
            f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
            " sure that either `num_vector_embeds` or `num_patches` is None."
        )
    if not (self.is_input_continuous or self.is_input_vectorized or self.is_input_patches):
        raise ValueError(
            f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:"
            f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
        )

    # 2. Input layers (one branch per input mode).
    if self.is_input_continuous:
        self.in_channels = in_channels

        self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
        self.proj_in = (
            LoRACompatibleLinear(in_channels, inner_dim)
            if use_linear_projection
            else LoRACompatibleConv(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
        )
    elif self.is_input_vectorized:
        assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
        assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"

        self.height = self.width = sample_size
        self.num_vector_embeds = num_vector_embeds
        self.num_latent_pixels = self.height * self.width

        self.latent_image_embedding = ImagePositionalEmbeddings(
            num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
        )
    elif self.is_input_patches:
        assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size"

        self.height = self.width = sample_size

        self.patch_size = patch_size
        self.pos_embed = PatchEmbed(
            height=sample_size,
            width=sample_size,
            patch_size=patch_size,
            in_channels=in_channels,
            embed_dim=inner_dim,
        )

    # 3. Transformer blocks: every layer receives the same configuration.
    block_kwargs = dict(
        dropout=dropout,
        cross_attention_dim=cross_attention_dim,
        activation_fn=activation_fn,
        num_embeds_ada_norm=num_embeds_ada_norm,
        attention_bias=attention_bias,
        only_cross_attention=only_cross_attention,
        upcast_attention=upcast_attention,
        norm_type=norm_type,
        norm_elementwise_affine=norm_elementwise_affine,
        num_views=num_views,
        cd_attention_last=cd_attention_last,
        cd_attention_mid=cd_attention_mid,
        multiview_attention=multiview_attention,
        mvcd_attention=mvcd_attention,
        use_dino=use_dino,
    )
    self.transformer_blocks = nn.ModuleList(
        [
            BasicMVTransformerBlock(inner_dim, num_attention_heads, attention_head_dim, **block_kwargs)
            for _ in range(num_layers)
        ]
    )

    # 4. Output layers.
    self.out_channels = out_channels if out_channels is not None else in_channels
    if self.is_input_continuous:
        # TODO: should use out_channels for continuous projections
        self.proj_out = (
            LoRACompatibleLinear(inner_dim, in_channels)
            if use_linear_projection
            else LoRACompatibleConv(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
        )
    elif self.is_input_vectorized:
        self.norm_out = nn.LayerNorm(inner_dim)
        self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
    elif self.is_input_patches:
        self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
        self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
        self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
233
+
234
def forward(
    self,
    hidden_states: torch.Tensor,
    encoder_hidden_states: Optional[torch.Tensor] = None,
    dino_feature: Optional[torch.Tensor] = None,
    timestep: Optional[torch.LongTensor] = None,
    class_labels: Optional[torch.LongTensor] = None,
    cross_attention_kwargs: Dict[str, Any] = None,
    attention_mask: Optional[torch.Tensor] = None,
    encoder_attention_mask: Optional[torch.Tensor] = None,
    return_dict: bool = True,
):
    """
    Run the input projection, every transformer block and the output head.

    Args:
        hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
            Input `hidden_states`.
        encoder_hidden_states (`torch.FloatTensor`, *optional*):
            Conditioning passed to each block as `encoder_hidden_states`.
        dino_feature (`torch.Tensor`, *optional*):
            Passed unchanged to each block as `dino_feature`.
        timestep (`torch.LongTensor`, *optional*):
            Passed to each block; also fed to the first block's `norm1.emb` in the patch output path.
        class_labels (`torch.LongTensor`, *optional*):
            Passed to each block; also fed to the first block's `norm1.emb` in the patch output path.
        cross_attention_kwargs (`dict`, *optional*):
            Passed unchanged to each block.
        attention_mask (`torch.Tensor`, *optional*):
            If 2-D it is treated as a keep-mask (1 = keep, 0 = discard) and turned into an additive bias
            (keep = 0, discard = -10000) with an extra singleton axis at position 1. Other ranks pass through.
        encoder_attention_mask (`torch.Tensor`, *optional*):
            Converted the same way as `attention_mask`.
        return_dict (`bool`, *optional*, defaults to `True`):
            Return a `TransformerMV2DModelOutput` when True, otherwise a one-element tuple.

    Returns:
        `TransformerMV2DModelOutput(sample=output)` or `(output,)`.
    """
    mask_dtype = hidden_states.dtype

    def _as_additive_bias(mask):
        # A 2-D tensor is a keep/discard mask: map it to (keep = +0, discard = -10000.0) and
        # insert a singleton query-token axis so it broadcasts over attention scores.
        # Anything else (None, or an already converted bias) is returned untouched.
        if mask is None or mask.ndim != 2:
            return mask
        return ((1 - mask.to(mask_dtype)) * -10000.0).unsqueeze(1)

    attention_mask = _as_additive_bias(attention_mask)
    encoder_attention_mask = _as_additive_bias(encoder_attention_mask)

    # 1. Input
    if self.is_input_continuous:
        batch, _, height, width = hidden_states.shape
        residual = hidden_states

        hidden_states = self.norm(hidden_states)
        if self.use_linear_projection:
            # flatten the spatial grid to tokens first, then project with the linear layer
            inner_dim = hidden_states.shape[1]
            hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
            hidden_states = self.proj_in(hidden_states)
        else:
            # project with the 1x1 conv first, then flatten the spatial grid to tokens
            hidden_states = self.proj_in(hidden_states)
            inner_dim = hidden_states.shape[1]
            hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
    elif self.is_input_vectorized:
        hidden_states = self.latent_image_embedding(hidden_states)
    elif self.is_input_patches:
        hidden_states = self.pos_embed(hidden_states)

    # 2. Blocks: the same keyword arguments go to every block.
    block_kwargs = dict(
        attention_mask=attention_mask,
        encoder_hidden_states=encoder_hidden_states,
        dino_feature=dino_feature,
        encoder_attention_mask=encoder_attention_mask,
        timestep=timestep,
        cross_attention_kwargs=cross_attention_kwargs,
        class_labels=class_labels,
    )
    for block in self.transformer_blocks:
        hidden_states = block(hidden_states, **block_kwargs)

    # 3. Output
    if self.is_input_continuous:
        if self.use_linear_projection:
            hidden_states = self.proj_out(hidden_states)
            hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
        else:
            hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
            hidden_states = self.proj_out(hidden_states)

        output = hidden_states + residual
    elif self.is_input_vectorized:
        logits = self.out(self.norm_out(hidden_states))
        # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
        logits = logits.permute(0, 2, 1)

        # log(p(x_0)), computed in float64 and cast back to float32
        output = F.log_softmax(logits.double(), dim=1).float()
    elif self.is_input_patches:
        # TODO: cleanup!
        conditioning = self.transformer_blocks[0].norm1.emb(
            timestep, class_labels, hidden_dtype=hidden_states.dtype
        )
        shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
        hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
        hidden_states = self.proj_out_2(hidden_states)

        # unpatchify: the token count is taken to be a perfect square
        height = width = int(hidden_states.shape[1] ** 0.5)
        hidden_states = hidden_states.reshape(
            shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
        )
        hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
        output = hidden_states.reshape(
            shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
        )

    if return_dict:
        return TransformerMV2DModelOutput(sample=output)
    return (output,)
372
+
373
+
374
@maybe_allow_in_graph
class BasicMVTransformerBlock(nn.Module):
    r"""
    Transformer block with an optional extra row-wise multi-view attention stage.

    Stage order in `forward`: self-attention (`attn1`), row-wise multi-view attention
    (`attn_mv`, only when enabled), cross-attention (`attn2`, only when created), feed-forward.
    Each stage normalizes its input first and adds its output back onto the running hidden states.

    Parameters:
        dim (`int`): The number of channels in the input and output.
        num_attention_heads (`int`): The number of heads to use for multi-head attention.
        attention_head_dim (`int`): The number of channels in each head.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
        num_embeds_ada_norm (`int`, *optional*):
            Required when `norm_type` is `"ada_norm"` or `"ada_norm_zero"`.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Configure if the attentions should contain a bias parameter.
        only_cross_attention (`bool`, *optional*):
            When set, `attn1` and `attn_mv` are built with `cross_attention_dim` and receive
            `encoder_hidden_states`.
        double_self_attention (`bool`, *optional*):
            When set, `attn2` is created without a `cross_attention_dim`.
        upcast_attention (`bool`, *optional*): Forwarded to every attention module.
        norm_elementwise_affine (`bool`, *optional*): Forwarded to the plain `nn.LayerNorm` layers.
        norm_type (`str`, *optional*): `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
        final_dropout (`bool`, *optional*): Forwarded to `FeedForward`.
        num_views (`int`, *optional*): Forwarded to `attn_mv` on every call.
        cd_attention_last (`bool`, *optional*): Accepted but not read by this block.
        cd_attention_mid (`bool`, *optional*):
            Forwarded to `attn_mv`; when False and `mvcd_attention` is True, `attn1` is a
            `CustomJointAttention`.
        multiview_attention (`bool`, *optional*): Forwarded to `attn_mv`; also gates `rowwise_attention`.
        mvcd_attention (`bool`, *optional*): See `cd_attention_mid`.
        rowwise_attention (`bool`, *optional*):
            The row-wise stage is only built when this and `multiview_attention` are both True.
        use_dino (`bool`, *optional*): Accepted but not read by this block.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        dropout=0.0,
        cross_attention_dim: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        attention_bias: bool = False,
        only_cross_attention: bool = False,
        double_self_attention: bool = False,
        upcast_attention: bool = False,
        norm_elementwise_affine: bool = True,
        norm_type: str = "layer_norm",
        final_dropout: bool = False,
        num_views: int = 1,
        cd_attention_last: bool = False,
        cd_attention_mid: bool = False,
        multiview_attention: bool = True,
        mvcd_attention: bool = False,
        rowwise_attention: bool = True,
        use_dino: bool = False,
    ):
        super().__init__()
        self.only_cross_attention = only_cross_attention

        has_ada_embeds = num_embeds_ada_norm is not None
        self.use_ada_layer_norm_zero = has_ada_embeds and norm_type == "ada_norm_zero"
        self.use_ada_layer_norm = has_ada_embeds and norm_type == "ada_norm"

        if not has_ada_embeds and norm_type in ("ada_norm", "ada_norm_zero"):
            raise ValueError(
                f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
                f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
            )

        def _ada_or_layer_norm():
            # Normalization used ahead of the multi-view and cross-attention stages.
            if self.use_ada_layer_norm:
                return AdaLayerNorm(dim, num_embeds_ada_norm)
            return nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)

        # Arguments shared by the self-attention and the row-wise multi-view attention.
        self_attn_kwargs = dict(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            dropout=dropout,
            bias=attention_bias,
            cross_attention_dim=cross_attention_dim if only_cross_attention else None,
            upcast_attention=upcast_attention,
        )

        # 1. Self-Attn (the only stage that may use AdaLayerNormZero)
        if self.use_ada_layer_norm:
            self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
        elif self.use_ada_layer_norm_zero:
            self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
        else:
            self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)

        self.multiview_attention = multiview_attention
        self.mvcd_attention = mvcd_attention
        self.cd_attention_mid = cd_attention_mid
        self.rowwise_attention = multiview_attention and rowwise_attention

        if mvcd_attention and not cd_attention_mid:
            # add cross domain attn to self attn
            self.attn1 = CustomJointAttention(**self_attn_kwargs, processor=JointAttnProcessor())
        else:
            self.attn1 = Attention(**self_attn_kwargs)

        # 1.1 rowwise multiview attention
        if self.rowwise_attention:
            self.norm_mv = _ada_or_layer_norm()
            self.attn_mv = CustomAttention(**self_attn_kwargs, processor=MVAttnProcessor())
            # Zero the output projection weight of the multi-view attention.
            nn.init.zeros_(self.attn_mv.to_out[0].weight.data)
        else:
            self.norm_mv = None
            self.attn_mv = None

        # 2. Cross-Attn
        if cross_attention_dim is not None or double_self_attention:
            # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
            # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
            # the second cross attention block.
            self.norm2 = _ada_or_layer_norm()
            self.attn2 = Attention(
                query_dim=dim,
                cross_attention_dim=None if double_self_attention else cross_attention_dim,
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                upcast_attention=upcast_attention,
            )  # is self-attn if encoder_hidden_states is none
        else:
            self.norm2 = None
            self.attn2 = None

        # 3. Feed-forward
        self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
        self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)

        # let chunk size default to None
        self._chunk_size = None
        self._chunk_dim = 0

        self.num_views = num_views

    def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
        """Store the chunk size and axis used to split the feed-forward input (`None` disables chunking)."""
        self._chunk_size = chunk_size
        self._chunk_dim = dim

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        cross_attention_kwargs: Dict[str, Any] = None,
        class_labels: Optional[torch.LongTensor] = None,
        dino_feature: Optional[torch.FloatTensor] = None,
    ):
        """Apply self-attention, optional row-wise multi-view attention, optional cross-attention and feed-forward.

        `attention_mask` must be None (asserted). `dino_feature` is accepted but not read here.
        Returns the updated `hidden_states`.
        """
        assert attention_mask is None  # not supported yet

        if cross_attention_kwargs is None:
            cross_attention_kwargs = {}
        # attn1 / attn_mv only see the encoder states in the only-cross-attention configuration.
        self_attn_context = encoder_hidden_states if self.only_cross_attention else None

        # Notice that normalization is always applied before the real computation in the following blocks.
        # 1. Self-Attention
        if self.use_ada_layer_norm:
            normed = self.norm1(hidden_states, timestep)
        elif self.use_ada_layer_norm_zero:
            normed, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
                hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
            )
        else:
            normed = self.norm1(hidden_states)

        attn_out = self.attn1(
            normed,
            encoder_hidden_states=self_attn_context,
            attention_mask=attention_mask,
            **cross_attention_kwargs,
        )
        if self.use_ada_layer_norm_zero:
            attn_out = gate_msa.unsqueeze(1) * attn_out
        hidden_states = attn_out + hidden_states

        # 1.1 row wise multiview attention
        if self.rowwise_attention:
            if self.use_ada_layer_norm:
                normed = self.norm_mv(hidden_states, timestep)
            else:
                normed = self.norm_mv(hidden_states)
            attn_out = self.attn_mv(
                normed,
                encoder_hidden_states=self_attn_context,
                attention_mask=attention_mask,
                num_views=self.num_views,
                multiview_attention=self.multiview_attention,
                cd_attention_mid=self.cd_attention_mid,
                **cross_attention_kwargs,
            )
            hidden_states = attn_out + hidden_states

        # 2. Cross-Attention
        if self.attn2 is not None:
            if self.use_ada_layer_norm:
                normed = self.norm2(hidden_states, timestep)
            else:
                normed = self.norm2(hidden_states)
            attn_out = self.attn2(
                normed,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                **cross_attention_kwargs,
            )
            hidden_states = attn_out + hidden_states

        # 3. Feed-forward
        norm_hidden_states = self.norm3(hidden_states)
        if self.use_ada_layer_norm_zero:
            norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]

        if self._chunk_size is None:
            ff_output = self.ff(norm_hidden_states)
        else:
            # "feed_forward_chunk_size" can be used to save memory
            if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
                raise ValueError(
                    f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
                )

            num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
            pieces = [self.ff(piece) for piece in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)]
            ff_output = torch.cat(pieces, dim=self._chunk_dim)

        if self.use_ada_layer_norm_zero:
            ff_output = gate_mlp.unsqueeze(1) * ff_output

        hidden_states = ff_output + hidden_states
        return hidden_states
646
+
647
+
648
class CustomAttention(Attention):
    """`Attention` variant whose xformers toggle switches between the multi-view processors."""

    def set_use_memory_efficient_attention_xformers(
        self, use_memory_efficient_attention_xformers: bool, *args, **kwargs
    ):
        """Install the multi-view processor matching the requested backend.

        Args:
            use_memory_efficient_attention_xformers: True selects `XFormersMVAttnProcessor`;
                False selects the plain `MVAttnProcessor`. Previously the flag was ignored and
                the xformers processor was installed even when the caller asked to disable it.
            *args, **kwargs: Accepted for signature compatibility and ignored.
        """
        if use_memory_efficient_attention_xformers:
            processor = XFormersMVAttnProcessor()
        else:
            processor = MVAttnProcessor()
        self.set_processor(processor)
655
+
656
+
657
class CustomJointAttention(Attention):
    """`Attention` variant whose xformers toggle switches between the joint (cross-domain) processors."""

    def set_use_memory_efficient_attention_xformers(
        self, use_memory_efficient_attention_xformers: bool, *args, **kwargs
    ):
        """Install the joint-attention processor matching the requested backend.

        Args:
            use_memory_efficient_attention_xformers: True selects `XFormersJointAttnProcessor`;
                False selects the plain `JointAttnProcessor`. Previously the flag was ignored and
                the xformers processor was installed even when the caller asked to disable it.
            *args, **kwargs: Accepted for signature compatibility and ignored.
        """
        if use_memory_efficient_attention_xformers:
            processor = XFormersJointAttnProcessor()
        else:
            processor = JointAttnProcessor()
        self.set_processor(processor)
664
+
665
class MVAttnProcessor:
    r"""
    Row-wise multi-view attention processor using plain PyTorch (`torch.bmm`) attention.

    Tokens are regrouped so that each attention sequence holds one image row taken across all
    `num_views` views: `(b v) (h w) c -> (b h) (v w) c`. After attention the original layout is restored.
    """

    def __call__(
        self,
        attn: Attention,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        num_views=1,
        cd_attention_mid=False,
        multiview_attention=True,
    ):
        """Run row-wise multi-view attention.

        Args:
            attn: The attention module providing projections, norms and head reshaping helpers.
            hidden_states: 3-D `(batch * num_views, tokens, channels)` or 4-D
                `(batch * num_views, channels, height, width)` input.
            encoder_hidden_states: Optional key/value source; defaults to `hidden_states`.
            attention_mask: Passed through `attn.prepare_attention_mask`.
            temb: Only used by `attn.spatial_norm` when that norm exists.
            num_views: Number of views folded into the leading axis.
            cd_attention_mid: When True, the two halves of the leading axis are concatenated along
                the width axis before attention and split apart again afterwards.
            multiview_attention: Not read. Accepted because `BasicMVTransformerBlock.forward` passes
                it and `XFormersMVAttnProcessor` takes the same keyword; without it this processor
                raised `TypeError` on diffusers versions that forward kwargs unfiltered.

        Returns:
            Tensor with the same layout as the input `hidden_states`.
        """
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        # Number of token rows per view. For 4-D input it is the real spatial height; it is kept in
        # its own variable so the `height`/`width` needed for the final 4-D reshape are not clobbered.
        grid_height = None
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
            grid_height = height

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        if grid_height is None:
            # 3-D input carries no spatial shape; the token grid is taken to be square.
            grid_height = int(math.sqrt(sequence_length))
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        def _rows_with_joint_domains(tensor):
            # Split the leading axis into its two halves, place them side by side along the
            # width axis, then group by row across views.
            tensor = rearrange(tensor, "(b v) (h w) c -> b v h w c", v=num_views, h=grid_height)
            tensor_0, tensor_1 = torch.chunk(tensor, dim=0, chunks=2)  # b v h w c
            tensor = torch.cat([tensor_0, tensor_1], dim=3)  # b v h 2w c
            return rearrange(tensor, "b v h w c -> (b h) (v w) c", v=num_views, h=grid_height)

        def _rows_across_views(tensor):
            # One attention sequence per image row, spanning all views.
            return rearrange(tensor, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=grid_height)

        regroup = _rows_with_joint_domains if cd_attention_mid else _rows_across_views
        key = regroup(key)
        value = regroup(value)
        query = regroup(query)

        query = attn.head_to_batch_dim(query).contiguous()
        key = attn.head_to_batch_dim(key).contiguous()
        value = attn.head_to_batch_dim(value).contiguous()

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        # Undo the row-wise regrouping.
        if cd_attention_mid:
            hidden_states = rearrange(hidden_states, "(b h) (v w) c -> b v h w c", v=num_views, h=grid_height)
            hidden_states_0, hidden_states_1 = torch.chunk(hidden_states, dim=3, chunks=2)  # b v h w c
            hidden_states = torch.cat([hidden_states_0, hidden_states_1], dim=0)  # 2b v h w c
            hidden_states = rearrange(hidden_states, "b v h w c -> (b v) (h w) c", v=num_views, h=grid_height)
        else:
            hidden_states = rearrange(hidden_states, "(b h) (v w) c -> (b v) (h w) c", v=num_views, h=grid_height)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
758
+
759
+
760
class XFormersMVAttnProcessor:
    r"""
    Row-wise multi-view attention processor backed by `xformers.ops.memory_efficient_attention`.

    Tokens are regrouped so that each attention sequence holds one image row taken across all
    `num_views` views: `(b v) (h w) c -> (b h) (v w) c`. After attention the original layout is restored.
    """

    def __call__(
        self,
        attn: Attention,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        num_views=1,
        multiview_attention=True,
        cd_attention_mid=False,
    ):
        """Run row-wise multi-view attention with xformers.

        Args:
            attn: The attention module providing projections, norms and head reshaping helpers.
            hidden_states: 3-D `(batch * num_views, tokens, channels)` or 4-D
                `(batch * num_views, channels, height, width)` input.
            encoder_hidden_states: Optional key/value source; defaults to `hidden_states`.
            attention_mask: Passed through `attn.prepare_attention_mask`, then expanded over the
                query-token axis when not None.
            temb: Only used by `attn.spatial_norm` when that norm exists.
            num_views: Number of views folded into the leading axis.
            multiview_attention: Accepted but not read by this processor.
            cd_attention_mid: When True, the two halves of the leading axis are concatenated along
                the width axis before attention and split apart again afterwards.

        Returns:
            Tensor with the same layout as the input `hidden_states`.
        """
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        # Number of token rows per view. For 4-D input it is the real spatial height; it is kept in
        # its own variable so the `height`/`width` needed for the final 4-D reshape are not
        # overwritten by the square-grid estimate (which broke non-square 4-D inputs).
        grid_height = None
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
            grid_height = height

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        if grid_height is None:
            # 3-D input carries no spatial shape; the token grid is taken to be square.
            grid_height = int(math.sqrt(sequence_length))
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
        # from yuancheng; here attention_mask is None
        if attention_mask is not None:
            # expand our mask's singleton query_tokens dimension:
            #   [batch*heads,            1, key_tokens] ->
            #   [batch*heads, query_tokens, key_tokens]
            # so that it can be added as a bias onto the attention scores that xformers computes;
            # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
            _, query_tokens, _ = hidden_states.shape
            attention_mask = attention_mask.expand(-1, query_tokens, -1)

        if attn.group_norm is not None:
            print('Warning: using group norm, pay attention to use it in row-wise attention')
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key_raw = attn.to_k(encoder_hidden_states)
        value_raw = attn.to_v(encoder_hidden_states)

        def _rows_with_joint_domains(tensor):
            # Split the leading axis into its two halves, place them side by side along the
            # width axis, then group by row across views.
            tensor = rearrange(tensor, "(b v) (h w) c -> b v h w c", v=num_views, h=grid_height)
            tensor_0, tensor_1 = torch.chunk(tensor, dim=0, chunks=2)  # b v h w c
            tensor = torch.cat([tensor_0, tensor_1], dim=3)  # b v h 2w c
            return rearrange(tensor, "b v h w c -> (b h) (v w) c", v=num_views, h=grid_height)

        def _rows_across_views(tensor):
            # One attention sequence per image row, spanning all views.
            return rearrange(tensor, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=grid_height)

        regroup = _rows_with_joint_domains if cd_attention_mid else _rows_across_views
        key = regroup(key_raw)
        value = regroup(value_raw)
        query = regroup(query)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        # Undo the row-wise regrouping.
        if cd_attention_mid:
            hidden_states = rearrange(hidden_states, "(b h) (v w) c -> b v h w c", v=num_views, h=grid_height)
            hidden_states_0, hidden_states_1 = torch.chunk(hidden_states, dim=3, chunks=2)  # b v h w c
            hidden_states = torch.cat([hidden_states_0, hidden_states_1], dim=0)  # 2b v h w c
            hidden_states = rearrange(hidden_states, "b v h w c -> (b v) (h w) c", v=num_views, h=grid_height)
        else:
            hidden_states = rearrange(hidden_states, "(b h) (v w) c -> (b v) (h w) c", v=num_views, h=grid_height)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
866
+
867
+
868
class XFormersJointAttnProcessor:
    r"""
    Attention processor that attends jointly over two task branches through xformers.

    The projected query/key/value tensors are split in half along the batch axis and the two
    halves are concatenated along the token axis, so a single call to
    `xformers.ops.memory_efficient_attention` sees the tokens of both halves. After the output
    projection the token axis is split again and the halves are stacked back on the batch axis.
    The two halves are presumably the "normal" and "color" branches (going by the original
    local names) -- confirm against the caller.
    """

    def __call__(
        self,
        attn: Attention,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        num_tasks=2
    ):
        skip_connection = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            # Flatten the spatial grid into a token axis: (B, C, H, W) -> (B, H*W, C).
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        shape_source = hidden_states if encoder_hidden_states is None else encoder_hidden_states
        batch_size, sequence_length, _ = shape_source.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        # Original author note: on this code path attention_mask is None.
        if attention_mask is not None:
            # xformers does not broadcast the singleton query dimension of the bias, so expand
            # [batch*heads, 1, key_tokens] -> [batch*heads, query_tokens, key_tokens] explicitly.
            _, query_tokens, _ = hidden_states.shape
            attention_mask = attention_mask.expand(-1, query_tokens, -1)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            context = hidden_states
        elif attn.norm_cross:
            context = attn.norm_encoder_hidden_states(encoder_hidden_states)
        else:
            context = encoder_hidden_states

        key = attn.to_k(context)
        value = attn.to_v(context)

        assert num_tasks == 2  # only support two tasks now

        def merge_tasks(tensor):
            # (2*bv, hw, c) -> (bv, 2*hw, c): put both batch halves on the token axis.
            first_half, second_half = tensor.chunk(2, dim=0)
            return torch.cat((first_half, second_half), dim=1)

        key = merge_tasks(key)
        value = merge_tasks(value)
        query = merge_tasks(query)

        query = attn.head_to_batch_dim(query).contiguous()
        key = attn.head_to_batch_dim(key).contiguous()
        value = attn.head_to_batch_dim(value).contiguous()

        attended = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
        attended = attn.batch_to_head_dim(attended)

        # to_out[0] is the linear projection, to_out[1] the dropout.
        attended = attn.to_out[0](attended)
        attended = attn.to_out[1](attended)

        # Undo merge_tasks: (bv, 2*hw, c) -> (2*bv, hw, c).
        first_task, second_task = attended.chunk(2, dim=1)
        hidden_states = torch.cat((first_task, second_task), dim=0)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + skip_connection

        return hidden_states / attn.rescale_output_factor
957
+
958
+
959
class JointAttnProcessor:
    r"""
    Attention processor that attends jointly over two task branches (plain PyTorch path).

    The projected query/key/value tensors are split in half along the batch axis and the two
    halves are concatenated along the token axis, so every token attends to the tokens of both
    halves. After the output projection the token axis is split again and the halves are
    stacked back on the batch axis, so the output has the same shape as the input.

    Only `num_tasks == 2` is supported.
    """

    def __call__(
        self,
        attn: "Attention",
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        num_tasks=2
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            # Flatten the spatial grid into a token axis: (B, C, H, W) -> (B, H*W, C).
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        # NOTE(review): the mask is prepared for the shapes *before* the two tasks are merged;
        # presumably callers pass attention_mask=None here -- confirm before relying on a mask.
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        assert num_tasks == 2  # only support two tasks now

        def transpose(tensor):
            # (2*bv, hw, c) -> (bv, 2*hw, c): put both batch halves on the token axis.
            tensor_0, tensor_1 = torch.chunk(tensor, dim=0, chunks=2)
            return torch.cat([tensor_0, tensor_1], dim=1)

        key = transpose(key)
        value = transpose(value)
        query = transpose(query)

        query = attn.head_to_batch_dim(query).contiguous()
        key = attn.head_to_batch_dim(key).contiguous()
        value = attn.head_to_batch_dim(value).contiguous()

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        # Undo `transpose`: split the token axis back into its two halves and stack them on the
        # batch axis. Indexing `[:, 0]` / `[:, 1]` would pick two single tokens and drop the
        # token axis, so the halves must be taken with chunk along dim=1.
        hidden_states_0, hidden_states_1 = torch.chunk(hidden_states, dim=1, chunks=2)
        hidden_states = torch.cat([hidden_states_0, hidden_states_1], dim=0)  # 2bv hw c

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
mvdiffusion/models/unet_mv2d_blocks.py ADDED
@@ -0,0 +1,971 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Any, Dict, Optional, Tuple
15
+
16
+ import numpy as np
17
+ import torch
18
+ import torch.nn.functional as F
19
+ from torch import nn
20
+
21
+ from diffusers.utils import is_torch_version, logging
22
+ from diffusers.models.normalization import AdaGroupNorm
23
+ from diffusers.models.attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0
24
+ from diffusers.models.dual_transformer_2d import DualTransformer2DModel
25
+ from diffusers.models.resnet import Downsample2D, FirDownsample2D, FirUpsample2D, KDownsample2D, KUpsample2D, ResnetBlock2D, Upsample2D
26
+
27
+ from diffusers.models.unets.unet_2d_blocks import DownBlock2D, ResnetDownsampleBlock2D, AttnDownBlock2D, CrossAttnDownBlock2D, SimpleCrossAttnDownBlock2D, SkipDownBlock2D, AttnSkipDownBlock2D, DownEncoderBlock2D, AttnDownEncoderBlock2D, KDownBlock2D, KCrossAttnDownBlock2D
28
+ from diffusers.models.unets.unet_2d_blocks import UpBlock2D, ResnetUpsampleBlock2D, CrossAttnUpBlock2D, SimpleCrossAttnUpBlock2D, AttnUpBlock2D, SkipUpBlock2D, AttnSkipUpBlock2D, UpDecoderBlock2D, AttnUpDecoderBlock2D, KUpBlock2D, KCrossAttnUpBlock2D
29
+
30
+
31
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
32
+
33
+
34
def get_down_block(
    down_block_type,
    num_layers,
    in_channels,
    out_channels,
    temb_channels,
    add_downsample,
    resnet_eps,
    resnet_act_fn,
    transformer_layers_per_block=1,
    num_attention_heads=None,
    resnet_groups=None,
    cross_attention_dim=None,
    downsample_padding=None,
    dual_cross_attention=False,
    use_linear_projection=False,
    only_cross_attention=False,
    upcast_attention=False,
    resnet_time_scale_shift="default",
    resnet_skip_time_act=False,
    resnet_out_scale_factor=1.0,
    cross_attention_norm=None,
    attention_head_dim=None,
    downsample_type=None,
    num_views=1,
    cd_attention_last: bool = False,
    cd_attention_mid: bool = False,
    multiview_attention: bool = True,
    sparse_mv_attention: bool = False,
    selfattn_block: str = "custom",
    mvcd_attention: bool = False,
    use_dino: bool = False
):
    """Instantiate the down block class selected by `down_block_type`.

    A leading ``"UNetRes"`` prefix on `down_block_type` is stripped before matching. Each block
    class receives only the subset of arguments it is called with below; the multi-view
    arguments (`num_views`, `cd_attention_last`, `cd_attention_mid`, `multiview_attention`,
    `sparse_mv_attention`, `selfattn_block`, `mvcd_attention`, `use_dino`) are forwarded only to
    `CrossAttnDownBlockMV2D`.

    If `attention_head_dim` is None it falls back to `num_attention_heads` and a warning is
    logged.

    Returns:
        The constructed block instance.

    Raises:
        ValueError: if `down_block_type` is not one of the handled names, or if
            `cross_attention_dim` is None for `CrossAttnDownBlock2D`, `CrossAttnDownBlockMV2D`
            or `SimpleCrossAttnDownBlock2D`.
    """
    # If attn head dim is not defined, we default it to the number of heads
    if attention_head_dim is None:
        # `Logger.warn` is a deprecated alias; `warning` is the supported spelling.
        logger.warning(
            f"It is recommended to provide `attention_head_dim` when calling `get_down_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
        )
        attention_head_dim = num_attention_heads

    down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type

    # Keyword arguments that every down block below is constructed with.
    core = dict(
        num_layers=num_layers,
        in_channels=in_channels,
        out_channels=out_channels,
        resnet_eps=resnet_eps,
        resnet_act_fn=resnet_act_fn,
    )

    if down_block_type == "DownBlock2D":
        return DownBlock2D(
            **core,
            temb_channels=temb_channels,
            add_downsample=add_downsample,
            resnet_groups=resnet_groups,
            downsample_padding=downsample_padding,
            resnet_time_scale_shift=resnet_time_scale_shift,
        )
    if down_block_type == "ResnetDownsampleBlock2D":
        return ResnetDownsampleBlock2D(
            **core,
            temb_channels=temb_channels,
            add_downsample=add_downsample,
            resnet_groups=resnet_groups,
            resnet_time_scale_shift=resnet_time_scale_shift,
            skip_time_act=resnet_skip_time_act,
            output_scale_factor=resnet_out_scale_factor,
        )
    if down_block_type == "AttnDownBlock2D":
        # This block takes a downsample *type* instead of the add_downsample flag.
        if add_downsample is False:
            downsample_type = None
        else:
            downsample_type = downsample_type or "conv"  # default to 'conv'
        return AttnDownBlock2D(
            **core,
            temb_channels=temb_channels,
            resnet_groups=resnet_groups,
            downsample_padding=downsample_padding,
            attention_head_dim=attention_head_dim,
            resnet_time_scale_shift=resnet_time_scale_shift,
            downsample_type=downsample_type,
        )
    if down_block_type in ("CrossAttnDownBlock2D", "CrossAttnDownBlockMV2D"):
        if cross_attention_dim is None:
            raise ValueError(f"cross_attention_dim must be specified for {down_block_type}")
        # Both cross-attention variants share these arguments.
        cross_attn_kwargs = dict(
            transformer_layers_per_block=transformer_layers_per_block,
            temb_channels=temb_channels,
            add_downsample=add_downsample,
            resnet_groups=resnet_groups,
            downsample_padding=downsample_padding,
            cross_attention_dim=cross_attention_dim,
            num_attention_heads=num_attention_heads,
            dual_cross_attention=dual_cross_attention,
            use_linear_projection=use_linear_projection,
            only_cross_attention=only_cross_attention,
            upcast_attention=upcast_attention,
            resnet_time_scale_shift=resnet_time_scale_shift,
        )
        if down_block_type == "CrossAttnDownBlock2D":
            return CrossAttnDownBlock2D(**core, **cross_attn_kwargs)
        # custom MV2D attention block
        return CrossAttnDownBlockMV2D(
            **core,
            **cross_attn_kwargs,
            num_views=num_views,
            cd_attention_last=cd_attention_last,
            cd_attention_mid=cd_attention_mid,
            multiview_attention=multiview_attention,
            sparse_mv_attention=sparse_mv_attention,
            selfattn_block=selfattn_block,
            mvcd_attention=mvcd_attention,
            use_dino=use_dino,
        )
    if down_block_type == "SimpleCrossAttnDownBlock2D":
        if cross_attention_dim is None:
            raise ValueError("cross_attention_dim must be specified for SimpleCrossAttnDownBlock2D")
        return SimpleCrossAttnDownBlock2D(
            **core,
            temb_channels=temb_channels,
            add_downsample=add_downsample,
            resnet_groups=resnet_groups,
            cross_attention_dim=cross_attention_dim,
            attention_head_dim=attention_head_dim,
            resnet_time_scale_shift=resnet_time_scale_shift,
            skip_time_act=resnet_skip_time_act,
            output_scale_factor=resnet_out_scale_factor,
            only_cross_attention=only_cross_attention,
            cross_attention_norm=cross_attention_norm,
        )
    if down_block_type == "SkipDownBlock2D":
        return SkipDownBlock2D(
            **core,
            temb_channels=temb_channels,
            add_downsample=add_downsample,
            downsample_padding=downsample_padding,
            resnet_time_scale_shift=resnet_time_scale_shift,
        )
    if down_block_type == "AttnSkipDownBlock2D":
        return AttnSkipDownBlock2D(
            **core,
            temb_channels=temb_channels,
            add_downsample=add_downsample,
            attention_head_dim=attention_head_dim,
            resnet_time_scale_shift=resnet_time_scale_shift,
        )
    if down_block_type == "DownEncoderBlock2D":
        return DownEncoderBlock2D(
            **core,
            add_downsample=add_downsample,
            resnet_groups=resnet_groups,
            downsample_padding=downsample_padding,
            resnet_time_scale_shift=resnet_time_scale_shift,
        )
    if down_block_type == "AttnDownEncoderBlock2D":
        return AttnDownEncoderBlock2D(
            **core,
            add_downsample=add_downsample,
            resnet_groups=resnet_groups,
            downsample_padding=downsample_padding,
            attention_head_dim=attention_head_dim,
            resnet_time_scale_shift=resnet_time_scale_shift,
        )
    if down_block_type == "KDownBlock2D":
        return KDownBlock2D(
            **core,
            temb_channels=temb_channels,
            add_downsample=add_downsample,
        )
    if down_block_type == "KCrossAttnDownBlock2D":
        return KCrossAttnDownBlock2D(
            **core,
            temb_channels=temb_channels,
            add_downsample=add_downsample,
            cross_attention_dim=cross_attention_dim,
            attention_head_dim=attention_head_dim,
            # Self-attention is enabled exactly when the block does not downsample.
            add_self_attention=not add_downsample,
        )
    raise ValueError(f"{down_block_type} does not exist.")
266
+
267
+
268
def get_up_block(
    up_block_type,
    num_layers,
    in_channels,
    out_channels,
    prev_output_channel,
    temb_channels,
    add_upsample,
    resnet_eps,
    resnet_act_fn,
    transformer_layers_per_block=1,
    num_attention_heads=None,
    resnet_groups=None,
    cross_attention_dim=None,
    dual_cross_attention=False,
    use_linear_projection=False,
    only_cross_attention=False,
    upcast_attention=False,
    resnet_time_scale_shift="default",
    resnet_skip_time_act=False,
    resnet_out_scale_factor=1.0,
    cross_attention_norm=None,
    attention_head_dim=None,
    upsample_type=None,
    num_views=1,
    cd_attention_last: bool = False,
    cd_attention_mid: bool = False,
    multiview_attention: bool = True,
    sparse_mv_attention: bool = False,
    selfattn_block: str = "custom",
    mvcd_attention: bool = False,
    use_dino: bool = False
):
    """Instantiate the up block class selected by `up_block_type`.

    A leading ``"UNetRes"`` prefix on `up_block_type` is stripped before matching. Each block
    class receives only the subset of arguments it is called with below; the multi-view
    arguments (`num_views`, `cd_attention_last`, `cd_attention_mid`, `multiview_attention`,
    `sparse_mv_attention`, `selfattn_block`, `mvcd_attention`, `use_dino`) are forwarded only to
    `CrossAttnUpBlockMV2D`.

    If `attention_head_dim` is None it falls back to `num_attention_heads` and a warning is
    logged.

    Returns:
        The constructed block instance.

    Raises:
        ValueError: if `up_block_type` is not one of the handled names, or if
            `cross_attention_dim` is None for `CrossAttnUpBlock2D`, `CrossAttnUpBlockMV2D` or
            `SimpleCrossAttnUpBlock2D`.
    """
    # If attn head dim is not defined, we default it to the number of heads
    if attention_head_dim is None:
        # `Logger.warn` is a deprecated alias; `warning` is the supported spelling.
        logger.warning(
            f"It is recommended to provide `attention_head_dim` when calling `get_up_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
        )
        attention_head_dim = num_attention_heads

    up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type

    # Keyword arguments that every up block below is constructed with.
    core = dict(
        num_layers=num_layers,
        in_channels=in_channels,
        out_channels=out_channels,
        temb_channels=temb_channels,
        resnet_eps=resnet_eps,
        resnet_act_fn=resnet_act_fn,
    )

    if up_block_type == "UpBlock2D":
        return UpBlock2D(
            **core,
            prev_output_channel=prev_output_channel,
            add_upsample=add_upsample,
            resnet_groups=resnet_groups,
            resnet_time_scale_shift=resnet_time_scale_shift,
        )
    if up_block_type == "ResnetUpsampleBlock2D":
        return ResnetUpsampleBlock2D(
            **core,
            prev_output_channel=prev_output_channel,
            add_upsample=add_upsample,
            resnet_groups=resnet_groups,
            resnet_time_scale_shift=resnet_time_scale_shift,
            skip_time_act=resnet_skip_time_act,
            output_scale_factor=resnet_out_scale_factor,
        )
    if up_block_type in ("CrossAttnUpBlock2D", "CrossAttnUpBlockMV2D"):
        if cross_attention_dim is None:
            raise ValueError(f"cross_attention_dim must be specified for {up_block_type}")
        # Both cross-attention variants share these arguments.
        cross_attn_kwargs = dict(
            transformer_layers_per_block=transformer_layers_per_block,
            prev_output_channel=prev_output_channel,
            add_upsample=add_upsample,
            resnet_groups=resnet_groups,
            cross_attention_dim=cross_attention_dim,
            num_attention_heads=num_attention_heads,
            dual_cross_attention=dual_cross_attention,
            use_linear_projection=use_linear_projection,
            only_cross_attention=only_cross_attention,
            upcast_attention=upcast_attention,
            resnet_time_scale_shift=resnet_time_scale_shift,
        )
        if up_block_type == "CrossAttnUpBlock2D":
            return CrossAttnUpBlock2D(**core, **cross_attn_kwargs)
        # custom MV2D attention block
        return CrossAttnUpBlockMV2D(
            **core,
            **cross_attn_kwargs,
            num_views=num_views,
            cd_attention_last=cd_attention_last,
            cd_attention_mid=cd_attention_mid,
            multiview_attention=multiview_attention,
            sparse_mv_attention=sparse_mv_attention,
            selfattn_block=selfattn_block,
            mvcd_attention=mvcd_attention,
            use_dino=use_dino,
        )
    if up_block_type == "SimpleCrossAttnUpBlock2D":
        if cross_attention_dim is None:
            raise ValueError("cross_attention_dim must be specified for SimpleCrossAttnUpBlock2D")
        return SimpleCrossAttnUpBlock2D(
            **core,
            prev_output_channel=prev_output_channel,
            add_upsample=add_upsample,
            resnet_groups=resnet_groups,
            cross_attention_dim=cross_attention_dim,
            attention_head_dim=attention_head_dim,
            resnet_time_scale_shift=resnet_time_scale_shift,
            skip_time_act=resnet_skip_time_act,
            output_scale_factor=resnet_out_scale_factor,
            only_cross_attention=only_cross_attention,
            cross_attention_norm=cross_attention_norm,
        )
    if up_block_type == "AttnUpBlock2D":
        # This block takes an upsample *type* instead of the add_upsample flag.
        if add_upsample is False:
            upsample_type = None
        else:
            upsample_type = upsample_type or "conv"  # default to 'conv'
        return AttnUpBlock2D(
            **core,
            prev_output_channel=prev_output_channel,
            resnet_groups=resnet_groups,
            attention_head_dim=attention_head_dim,
            resnet_time_scale_shift=resnet_time_scale_shift,
            upsample_type=upsample_type,
        )
    if up_block_type == "SkipUpBlock2D":
        return SkipUpBlock2D(
            **core,
            prev_output_channel=prev_output_channel,
            add_upsample=add_upsample,
            resnet_time_scale_shift=resnet_time_scale_shift,
        )
    if up_block_type == "AttnSkipUpBlock2D":
        return AttnSkipUpBlock2D(
            **core,
            prev_output_channel=prev_output_channel,
            add_upsample=add_upsample,
            attention_head_dim=attention_head_dim,
            resnet_time_scale_shift=resnet_time_scale_shift,
        )
    if up_block_type == "UpDecoderBlock2D":
        return UpDecoderBlock2D(
            **core,
            add_upsample=add_upsample,
            resnet_groups=resnet_groups,
            resnet_time_scale_shift=resnet_time_scale_shift,
        )
    if up_block_type == "AttnUpDecoderBlock2D":
        return AttnUpDecoderBlock2D(
            **core,
            add_upsample=add_upsample,
            resnet_groups=resnet_groups,
            attention_head_dim=attention_head_dim,
            resnet_time_scale_shift=resnet_time_scale_shift,
        )
    if up_block_type == "KUpBlock2D":
        return KUpBlock2D(
            **core,
            add_upsample=add_upsample,
        )
    if up_block_type == "KCrossAttnUpBlock2D":
        return KCrossAttnUpBlock2D(
            **core,
            add_upsample=add_upsample,
            cross_attention_dim=cross_attention_dim,
            attention_head_dim=attention_head_dim,
        )

    raise ValueError(f"{up_block_type} does not exist.")
504
+
505
+
506
class UNetMidBlockMV2DCrossAttn(nn.Module):
    """UNet middle block alternating ResNet blocks with `TransformerMV2DModel` blocks.

    The block always starts with one `ResnetBlock2D`; each of the `num_layers` layers then adds
    one transformer followed by one more `ResnetBlock2D`. The transformer implementation is
    imported lazily according to `selfattn_block` ("custom", "rowwise" or "self_rowwise").

    Raises:
        NotImplementedError: for any other `selfattn_block` value, or when
            `dual_cross_attention` is truthy and at least one layer is requested.
    """

    def __init__(
        self,
        in_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        transformer_layers_per_block: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        num_attention_heads=1,
        output_scale_factor=1.0,
        cross_attention_dim=1280,
        dual_cross_attention=False,
        use_linear_projection=False,
        upcast_attention=False,
        num_views: int = 1,
        cd_attention_last: bool = False,
        cd_attention_mid: bool = False,
        multiview_attention: bool = True,
        sparse_mv_attention: bool = False,
        selfattn_block: str = "custom",
        mvcd_attention: bool = False,
        use_dino: bool = False
    ):
        super().__init__()

        self.has_cross_attention = True
        self.num_attention_heads = num_attention_heads

        if resnet_groups is None:
            resnet_groups = min(in_channels // 4, 32)

        # Resolve the transformer class lazily so only the selected variant is imported.
        if selfattn_block == "custom":
            from .transformer_mv2d import TransformerMV2DModel
        elif selfattn_block == "rowwise":
            from .transformer_mv2d_rowwise import TransformerMV2DModel
        elif selfattn_block == "self_rowwise":
            from .transformer_mv2d_self_rowwise import TransformerMV2DModel
        else:
            raise NotImplementedError

        def build_resnet():
            # Every resnet in this block maps in_channels -> in_channels with the same settings.
            return ResnetBlock2D(
                in_channels=in_channels,
                out_channels=in_channels,
                temb_channels=temb_channels,
                eps=resnet_eps,
                groups=resnet_groups,
                dropout=dropout,
                time_embedding_norm=resnet_time_scale_shift,
                non_linearity=resnet_act_fn,
                output_scale_factor=output_scale_factor,
                pre_norm=resnet_pre_norm,
            )

        # there is always at least one resnet
        resnet_blocks = [build_resnet()]
        attention_blocks = []

        for _ in range(num_layers):
            if dual_cross_attention:
                raise NotImplementedError
            attention_blocks.append(
                TransformerMV2DModel(
                    num_attention_heads,
                    in_channels // num_attention_heads,
                    in_channels=in_channels,
                    num_layers=transformer_layers_per_block,
                    cross_attention_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                    use_linear_projection=use_linear_projection,
                    upcast_attention=upcast_attention,
                    num_views=num_views,
                    cd_attention_last=cd_attention_last,
                    cd_attention_mid=cd_attention_mid,
                    multiview_attention=multiview_attention,
                    sparse_mv_attention=sparse_mv_attention,
                    mvcd_attention=mvcd_attention,
                    use_dino=use_dino
                )
            )
            resnet_blocks.append(build_resnet())

        self.attentions = nn.ModuleList(attention_blocks)
        self.resnets = nn.ModuleList(resnet_blocks)

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        temb: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        dino_feature: Optional[torch.FloatTensor] = None
    ) -> torch.FloatTensor:
        """Run the leading resnet, then each (transformer, resnet) pair in order."""
        leading_resnet, *following_resnets = self.resnets
        hidden_states = leading_resnet(hidden_states, temb)

        for transformer, resnet in zip(self.attentions, following_resnets):
            # With return_dict=False the transformer's output is indexed at position 0.
            transformer_out = transformer(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                cross_attention_kwargs=cross_attention_kwargs,
                attention_mask=attention_mask,
                encoder_attention_mask=encoder_attention_mask,
                dino_feature=dino_feature,
                return_dict=False,
            )
            hidden_states = resnet(transformer_out[0], temb)

        return hidden_states
630
+
631
+
632
class CrossAttnUpBlockMV2D(nn.Module):
    """UNet up-sampling stage built from (``ResnetBlock2D``, ``TransformerMV2DModel``) pairs.

    In ``forward`` every pair first concatenates the last remaining skip tensor onto the
    running activation along ``dim=1``, then applies the ResNet block followed by the
    transformer. When ``add_upsample`` is true a single ``Upsample2D`` runs after all pairs.

    Args:
        in_channels: Channel count of the skip tensor consumed by the last layer.
        out_channels: Channel count produced by every ResNet block and transformer.
        prev_output_channel: Channel count of the activation fed to the first layer.
        temb_channels: Passed to every ``ResnetBlock2D`` as ``temb_channels``.
        num_layers: Number of (ResNet, transformer) pairs.
        selfattn_block: Which ``TransformerMV2DModel`` implementation to import:
            ``"custom"``, ``"rowwise"`` or ``"self_rowwise"``.
        dual_cross_attention: Not supported; ``True`` raises ``NotImplementedError``
            while the first layer is being built.

    The remaining arguments are forwarded unchanged to ``ResnetBlock2D``,
    ``TransformerMV2DModel`` or ``Upsample2D``.

    Raises:
        NotImplementedError: for an unknown ``selfattn_block`` or when
            ``dual_cross_attention`` is true and ``num_layers > 0``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        prev_output_channel: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        transformer_layers_per_block: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        num_attention_heads=1,
        cross_attention_dim=1280,
        output_scale_factor=1.0,
        add_upsample=True,
        dual_cross_attention=False,
        use_linear_projection=False,
        only_cross_attention=False,
        upcast_attention=False,
        num_views: int = 1,
        cd_attention_last: bool = False,
        cd_attention_mid: bool = False,
        multiview_attention: bool = True,
        sparse_mv_attention: bool = False,
        selfattn_block: str = "custom",
        mvcd_attention: bool = False,
        use_dino: bool = False
    ):
        super().__init__()

        self.has_cross_attention = True
        self.num_attention_heads = num_attention_heads

        # Only the requested transformer variant is imported.
        if selfattn_block == "custom":
            from .transformer_mv2d import TransformerMV2DModel
        elif selfattn_block == "rowwise":
            from .transformer_mv2d_rowwise import TransformerMV2DModel
        elif selfattn_block == "self_rowwise":
            from .transformer_mv2d_self_rowwise import TransformerMV2DModel
        else:
            raise NotImplementedError

        # Keyword arguments that are identical for every layer of this stage.
        resnet_shared = dict(
            out_channels=out_channels,
            temb_channels=temb_channels,
            eps=resnet_eps,
            groups=resnet_groups,
            dropout=dropout,
            time_embedding_norm=resnet_time_scale_shift,
            non_linearity=resnet_act_fn,
            output_scale_factor=output_scale_factor,
            pre_norm=resnet_pre_norm,
        )
        attn_shared = dict(
            in_channels=out_channels,
            num_layers=transformer_layers_per_block,
            cross_attention_dim=cross_attention_dim,
            norm_num_groups=resnet_groups,
            use_linear_projection=use_linear_projection,
            only_cross_attention=only_cross_attention,
            upcast_attention=upcast_attention,
            num_views=num_views,
            cd_attention_last=cd_attention_last,
            cd_attention_mid=cd_attention_mid,
            multiview_attention=multiview_attention,
            sparse_mv_attention=sparse_mv_attention,
            mvcd_attention=mvcd_attention,
            use_dino=use_dino,
        )

        resnet_layers = []
        attn_layers = []
        final_layer = num_layers - 1
        for layer_idx in range(num_layers):
            # The ResNet input is the concatenation of the running activation and a skip
            # tensor, so its width is the sum of both channel counts.
            if layer_idx == final_layer:
                skip_channels = in_channels
            else:
                skip_channels = out_channels
            if layer_idx == 0:
                main_channels = prev_output_channel
            else:
                main_channels = out_channels

            resnet_layers.append(
                ResnetBlock2D(in_channels=main_channels + skip_channels, **resnet_shared)
            )

            if dual_cross_attention:
                raise NotImplementedError

            attn_layers.append(
                TransformerMV2DModel(
                    num_attention_heads,
                    out_channels // num_attention_heads,
                    **attn_shared,
                )
            )

        self.attentions = nn.ModuleList(attn_layers)
        self.resnets = nn.ModuleList(resnet_layers)

        self.upsamplers = (
            nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
            if add_upsample
            else None
        )

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
        temb: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        upsample_size: Optional[int] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        dino_feature: Optional[torch.FloatTensor] = None
    ):
        """Run all (ResNet, transformer) pairs and the optional upsampler.

        Skip tensors are taken from the end of ``res_hidden_states_tuple``, one per layer.
        ``encoder_hidden_states``, ``dino_feature``, ``cross_attention_kwargs`` and both
        masks are handed to every transformer unchanged. Returns the final hidden states.
        """

        def wrap_for_checkpoint(module, return_dict=None):
            # torch.utils.checkpoint only forwards positional inputs; `return_dict` is
            # bound here so it can still reach the wrapped module.
            def run(*inputs):
                if return_dict is None:
                    return module(*inputs)
                return module(*inputs, return_dict=return_dict)

            return run

        use_checkpoint = self.training and self.gradient_checkpointing

        for resnet, attn in zip(self.resnets, self.attentions):
            # Take the most recent skip connection off the end of the tuple.
            skip_states = res_hidden_states_tuple[-1]
            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
            hidden_states = torch.cat([hidden_states, skip_states], dim=1)

            if not use_checkpoint:
                hidden_states = resnet(hidden_states, temb)
                hidden_states = attn(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                    attention_mask=attention_mask,
                    encoder_attention_mask=encoder_attention_mask,
                    dino_feature=dino_feature,
                    return_dict=False,
                )[0]
                continue

            ckpt_kwargs: Dict[str, Any]
            if is_torch_version(">=", "1.11.0"):
                ckpt_kwargs = {"use_reentrant": False}
            else:
                ckpt_kwargs = {}

            hidden_states = torch.utils.checkpoint.checkpoint(
                wrap_for_checkpoint(resnet),
                hidden_states,
                temb,
                **ckpt_kwargs,
            )
            # Positional order expected by the transformer's forward.
            attn_inputs = (
                hidden_states,
                encoder_hidden_states,
                dino_feature,
                None,  # timestep
                None,  # class_labels
                cross_attention_kwargs,
                attention_mask,
                encoder_attention_mask,
            )
            hidden_states = torch.utils.checkpoint.checkpoint(
                wrap_for_checkpoint(attn, return_dict=False),
                *attn_inputs,
                **ckpt_kwargs,
            )[0]

        if self.upsamplers is None:
            return hidden_states

        for upsampler in self.upsamplers:
            hidden_states = upsampler(hidden_states, upsample_size)

        return hidden_states
796
+
797
+
798
class CrossAttnDownBlockMV2D(nn.Module):
    """UNet down-sampling stage built from (``ResnetBlock2D``, ``TransformerMV2DModel``) pairs.

    ``forward`` applies each pair in order, records every pair's output, and then applies a
    single ``Downsample2D`` when ``add_downsample`` is true, recording that output as well.

    Args:
        in_channels: Input channel count of the first ResNet block.
        out_channels: Channel count produced by every ResNet block and transformer.
        temb_channels: Passed to every ``ResnetBlock2D`` as ``temb_channels``.
        num_layers: Number of (ResNet, transformer) pairs.
        selfattn_block: Which ``TransformerMV2DModel`` implementation to import:
            ``"custom"``, ``"rowwise"`` or ``"self_rowwise"``.
        dual_cross_attention: Not supported; ``True`` raises ``NotImplementedError``
            while the first layer is being built.

    The remaining arguments are forwarded unchanged to ``ResnetBlock2D``,
    ``TransformerMV2DModel`` or ``Downsample2D``.

    Raises:
        NotImplementedError: for an unknown ``selfattn_block`` or when
            ``dual_cross_attention`` is true and ``num_layers > 0``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        transformer_layers_per_block: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        num_attention_heads=1,
        cross_attention_dim=1280,
        output_scale_factor=1.0,
        downsample_padding=1,
        add_downsample=True,
        dual_cross_attention=False,
        use_linear_projection=False,
        only_cross_attention=False,
        upcast_attention=False,
        num_views: int = 1,
        cd_attention_last: bool = False,
        cd_attention_mid: bool = False,
        multiview_attention: bool = True,
        sparse_mv_attention: bool = False,
        selfattn_block: str = "custom",
        mvcd_attention: bool = False,
        use_dino: bool = False
    ):
        super().__init__()

        self.has_cross_attention = True
        self.num_attention_heads = num_attention_heads

        # Only the requested transformer variant is imported.
        if selfattn_block == "custom":
            from .transformer_mv2d import TransformerMV2DModel
        elif selfattn_block == "rowwise":
            from .transformer_mv2d_rowwise import TransformerMV2DModel
        elif selfattn_block == "self_rowwise":
            from .transformer_mv2d_self_rowwise import TransformerMV2DModel
        else:
            raise NotImplementedError

        # Keyword arguments that are identical for every layer of this stage.
        resnet_shared = dict(
            out_channels=out_channels,
            temb_channels=temb_channels,
            eps=resnet_eps,
            groups=resnet_groups,
            dropout=dropout,
            time_embedding_norm=resnet_time_scale_shift,
            non_linearity=resnet_act_fn,
            output_scale_factor=output_scale_factor,
            pre_norm=resnet_pre_norm,
        )
        attn_shared = dict(
            in_channels=out_channels,
            num_layers=transformer_layers_per_block,
            cross_attention_dim=cross_attention_dim,
            norm_num_groups=resnet_groups,
            use_linear_projection=use_linear_projection,
            only_cross_attention=only_cross_attention,
            upcast_attention=upcast_attention,
            num_views=num_views,
            cd_attention_last=cd_attention_last,
            cd_attention_mid=cd_attention_mid,
            multiview_attention=multiview_attention,
            sparse_mv_attention=sparse_mv_attention,
            mvcd_attention=mvcd_attention,
            use_dino=use_dino,
        )

        resnet_layers = []
        attn_layers = []
        for layer_idx in range(num_layers):
            # Only the first ResNet block sees the stage's input width.
            if layer_idx == 0:
                resnet_in = in_channels
            else:
                resnet_in = out_channels

            resnet_layers.append(ResnetBlock2D(in_channels=resnet_in, **resnet_shared))

            if dual_cross_attention:
                raise NotImplementedError

            attn_layers.append(
                TransformerMV2DModel(
                    num_attention_heads,
                    out_channels // num_attention_heads,
                    **attn_shared,
                )
            )

        self.attentions = nn.ModuleList(attn_layers)
        self.resnets = nn.ModuleList(resnet_layers)

        if not add_downsample:
            self.downsamplers = None
        else:
            downsample = Downsample2D(
                out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
            )
            self.downsamplers = nn.ModuleList([downsample])

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        temb: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        dino_feature: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        additional_residuals=None,
    ):
        """Run all (ResNet, transformer) pairs and the optional downsampler.

        ``additional_residuals``, when given, is added to the output of the last pair only.

        Returns:
            A tuple ``(hidden_states, output_states)`` where ``output_states`` holds the
            output of every pair followed, if downsamplers exist, by the downsampled output.
        """

        def wrap_for_checkpoint(module, return_dict=None):
            # torch.utils.checkpoint only forwards positional inputs; `return_dict` is
            # bound here so it can still reach the wrapped module.
            def run(*inputs):
                if return_dict is None:
                    return module(*inputs)
                return module(*inputs, return_dict=return_dict)

            return run

        use_checkpoint = self.training and self.gradient_checkpointing

        stages = list(zip(self.resnets, self.attentions))
        final_idx = len(stages) - 1
        output_states = ()

        for idx, (resnet, attn) in enumerate(stages):
            if use_checkpoint:
                ckpt_kwargs: Dict[str, Any]
                if is_torch_version(">=", "1.11.0"):
                    ckpt_kwargs = {"use_reentrant": False}
                else:
                    ckpt_kwargs = {}

                hidden_states = torch.utils.checkpoint.checkpoint(
                    wrap_for_checkpoint(resnet),
                    hidden_states,
                    temb,
                    **ckpt_kwargs,
                )
                # Positional order expected by the transformer's forward.
                attn_inputs = (
                    hidden_states,
                    encoder_hidden_states,
                    dino_feature,
                    None,  # timestep
                    None,  # class_labels
                    cross_attention_kwargs,
                    attention_mask,
                    encoder_attention_mask,
                )
                hidden_states = torch.utils.checkpoint.checkpoint(
                    wrap_for_checkpoint(attn, return_dict=False),
                    *attn_inputs,
                    **ckpt_kwargs,
                )[0]
            else:
                hidden_states = resnet(hidden_states, temb)
                hidden_states = attn(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    dino_feature=dino_feature,
                    cross_attention_kwargs=cross_attention_kwargs,
                    attention_mask=attention_mask,
                    encoder_attention_mask=encoder_attention_mask,
                    return_dict=False,
                )[0]

            # apply additional residuals to the output of the last pair of resnet and attention blocks
            if additional_residuals is not None and idx == final_idx:
                hidden_states = hidden_states + additional_residuals

            output_states += (hidden_states,)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)

            output_states += (hidden_states,)

        return hidden_states, output_states
971
+
mvdiffusion/models/unet_mv2d_condition.py ADDED
@@ -0,0 +1,1686 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import Any, Dict, List, Optional, Tuple, Union
16
+ import os
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.utils.checkpoint
21
+
22
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
23
+ from diffusers.loaders import UNet2DConditionLoadersMixin
24
+ from diffusers.utils import BaseOutput, logging
25
+ from diffusers.models.activations import get_activation
26
+ from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor
27
+ from diffusers.models.embeddings import (
28
+ GaussianFourierProjection,
29
+ ImageHintTimeEmbedding,
30
+ ImageProjection,
31
+ ImageTimeEmbedding,
32
+ TextImageProjection,
33
+ TextImageTimeEmbedding,
34
+ TextTimeEmbedding,
35
+ TimestepEmbedding,
36
+ Timesteps,
37
+ )
38
+ from diffusers.models.modeling_utils import ModelMixin, load_state_dict, _load_state_dict_into_model
39
+ from diffusers.models.unet_2d_blocks import (
40
+ CrossAttnDownBlock2D,
41
+ CrossAttnUpBlock2D,
42
+ DownBlock2D,
43
+ UNetMidBlock2DCrossAttn,
44
+ UNetMidBlock2DSimpleCrossAttn,
45
+ UpBlock2D,
46
+ )
47
+ from diffusers.utils import (
48
+ CONFIG_NAME,
49
+ FLAX_WEIGHTS_NAME,
50
+ SAFETENSORS_WEIGHTS_NAME,
51
+ WEIGHTS_NAME,
52
+ _add_variant,
53
+ _get_model_file,
54
+ deprecate,
55
+ is_torch_version,
56
+ logging,
57
+ )
58
+ from diffusers.utils.import_utils import is_accelerate_available
59
+ from diffusers.utils.hub_utils import HF_HUB_OFFLINE
60
+ from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
61
+ DIFFUSERS_CACHE = HUGGINGFACE_HUB_CACHE
62
+
63
+ from diffusers import __version__
64
+ from .unet_mv2d_blocks import (
65
+ CrossAttnDownBlockMV2D,
66
+ CrossAttnUpBlockMV2D,
67
+ UNetMidBlockMV2DCrossAttn,
68
+ get_down_block,
69
+ get_up_block,
70
+ )
71
+ from einops import rearrange, repeat
72
+
73
+ from diffusers import __version__
74
+ from mvdiffusion.models.unet_mv2d_blocks import (
75
+ CrossAttnDownBlockMV2D,
76
+ CrossAttnUpBlockMV2D,
77
+ UNetMidBlockMV2DCrossAttn,
78
+ get_down_block,
79
+ get_up_block,
80
+ )
81
+
82
+
83
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
84
+
85
+
86
@dataclass
class UNetMV2DConditionOutput(BaseOutput):
    """
    Output container for [`UNetMV2DConditionModel`].

    Args:
        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
    """

    # Defaults to None so the dataclass can be created without a value.
    sample: torch.FloatTensor = None
97
+
98
+
99
class ResidualBlock(nn.Module):
    """Two same-width linear layers with a skip connection and SiLU activations.

    Computes ``SiLU(linear2(SiLU(linear1(x))) + x)``.

    Args:
        dim: Size of the last input dimension; both linear layers map ``dim -> dim``.
    """

    def __init__(self, dim):
        super().__init__()
        self.linear1 = nn.Linear(dim, dim)
        self.activation = nn.SiLU()
        self.linear2 = nn.Linear(dim, dim)

    def forward(self, x):
        hidden = self.activation(self.linear1(x))
        hidden = self.linear2(hidden)
        # Skip connection: added in place on the freshly created linear2 output,
        # so the caller's tensor is left untouched.
        hidden += x
        return self.activation(hidden)
114
+
115
class ResidualLiner(nn.Module):
    """MLP made of an input projection, a stack of ``ResidualBlock``s and an output projection.

    Args:
        in_features: Size of the last input dimension.
        out_features: Size of the last output dimension.
        dim: Hidden width used by the input projection and every residual block.
        act: Optional callable applied to the final output; skipped when ``None``.
        num_block: Number of ``ResidualBlock`` modules between the two projections.
    """

    def __init__(self, in_features, out_features, dim, act=None, num_block=1):
        super().__init__()
        self.linear_in = nn.Sequential(nn.Linear(in_features, dim), nn.SiLU())
        self.blocks = nn.ModuleList([ResidualBlock(dim) for _ in range(num_block)])
        self.linear_out = nn.Linear(dim, out_features)
        self.act = act

    def forward(self, x):
        hidden = self.linear_in(x)
        for residual in self.blocks:
            hidden = residual(hidden)
        projected = self.linear_out(hidden)
        if self.act is None:
            return projected
        return self.act(projected)
136
+
137
class BasicConvBlock(nn.Module):
    """Residual block of two 3x3 convolutions, each followed by GroupNorm, with SiLU.

    Computes ``SiLU(norm2(conv2(SiLU(norm1(conv1(x))))) + shortcut(x))``. The shortcut is
    the identity unless ``stride != 1`` or the channel counts differ, in which case it is
    a strided 1x1 convolution followed by GroupNorm.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by the block. Must be divisible by 8,
            the fixed GroupNorm group count.
        stride: Stride of the first convolution and of the shortcut projection.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        # Every GroupNorm below normalizes a convolution output, so it has to be sized by
        # `out_channels`; sizing it by `in_channels` fails whenever the two differ.
        self.norm1 = nn.GroupNorm(num_groups=8, num_channels=out_channels, affine=True)
        self.act = nn.SiLU()
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.norm2 = nn.GroupNorm(num_groups=8, num_channels=out_channels, affine=True)
        # An empty Sequential acts as the identity shortcut.
        self.downsample = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            # Project the input so it matches the main branch in channels and resolution.
            self.downsample = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.GroupNorm(num_groups=8, num_channels=out_channels, affine=True)
            )

    def forward(self, x):
        identity = x
        out = self.conv1(x)
        out = self.norm1(out)
        out = self.act(out)
        out = self.conv2(out)
        out = self.norm2(out)
        # In-place add on the freshly computed branch output; `x` itself is not modified.
        out += self.downsample(identity)
        out = self.act(out)
        return out
162
+
163
+ class UNetMV2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
164
+ r"""
165
+ A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
166
+ shaped output.
167
+
168
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
169
+ for all models (such as downloading or saving).
170
+
171
+ Parameters:
172
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
173
+ Height and width of input/output sample.
174
+ in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
175
+ out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
176
+ center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
177
+ flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
178
+ Whether to flip the sin to cos in the time embedding.
179
+ freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
180
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
181
+ The tuple of downsample blocks to use.
182
+ mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
183
+ Block type for middle of UNet, it can be either `UNetMidBlock2DCrossAttn` or
184
+ `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
185
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
186
+ The tuple of upsample blocks to use.
187
+ only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`):
188
+ Whether to include self-attention in the basic transformer blocks, see
189
+ [`~models.attention.BasicTransformerBlock`].
190
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
191
+ The tuple of output channels for each block.
192
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
193
+ downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
194
+ mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
195
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
196
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
197
+ If `None`, normalization and activation layers is skipped in post-processing.
198
+ norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
199
+ cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
200
+ The dimension of the cross attention features.
201
+ transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
202
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
203
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
204
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
205
+ encoder_hid_dim (`int`, *optional*, defaults to None):
206
+ If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
207
+ dimension to `cross_attention_dim`.
208
+ encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
209
+ If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
210
+ embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
211
+ attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
212
+ num_attention_heads (`int`, *optional*):
213
+ The number of attention heads. If not defined, defaults to `attention_head_dim`
214
+ resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
215
+ for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
216
+ class_embed_type (`str`, *optional*, defaults to `None`):
217
+ The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
218
+ `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
219
+ addition_embed_type (`str`, *optional*, defaults to `None`):
220
+ Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
221
+ "text". "text" will use the `TextTimeEmbedding` layer.
222
+ addition_time_embed_dim: (`int`, *optional*, defaults to `None`):
223
+ Dimension for the timestep embeddings.
224
+ num_class_embeds (`int`, *optional*, defaults to `None`):
225
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
226
+ class conditioning with `class_embed_type` equal to `None`.
227
+ time_embedding_type (`str`, *optional*, defaults to `positional`):
228
+ The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
229
+ time_embedding_dim (`int`, *optional*, defaults to `None`):
230
+ An optional override for the dimension of the projected time embedding.
231
+ time_embedding_act_fn (`str`, *optional*, defaults to `None`):
232
+ Optional activation function to use only once on the time embeddings before they are passed to the rest of
233
+ the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
234
+ timestep_post_act (`str`, *optional*, defaults to `None`):
235
+ The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
236
+ time_cond_proj_dim (`int`, *optional*, defaults to `None`):
237
+ The dimension of `cond_proj` layer in the timestep embedding.
238
+ conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer.
239
+ conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer.
240
+ projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
241
+ `class_embed_type="projection"`. Required when `class_embed_type="projection"`.
242
+ class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
243
+ embeddings with the class embeddings.
244
+ mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
245
+ Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
246
+ `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
247
+ `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False`
248
+ otherwise.
249
+ """
250
+
251
+ _supports_gradient_checkpointing = True
252
+
253
+ @register_to_config
254
+ def __init__(
255
+ self,
256
+ sample_size: Optional[int] = None,
257
+ in_channels: int = 4,
258
+ out_channels: int = 4,
259
+ center_input_sample: bool = False,
260
+ flip_sin_to_cos: bool = True,
261
+ freq_shift: int = 0,
262
+ down_block_types: Tuple[str] = (
263
+ "CrossAttnDownBlockMV2D",
264
+ "CrossAttnDownBlockMV2D",
265
+ "CrossAttnDownBlockMV2D",
266
+ "DownBlock2D",
267
+ ),
268
+ mid_block_type: Optional[str] = "UNetMidBlockMV2DCrossAttn",
269
+ up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlockMV2D", "CrossAttnUpBlockMV2D", "CrossAttnUpBlockMV2D"),
270
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
271
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
272
+ layers_per_block: Union[int, Tuple[int]] = 2,
273
+ downsample_padding: int = 1,
274
+ mid_block_scale_factor: float = 1,
275
+ act_fn: str = "silu",
276
+ norm_num_groups: Optional[int] = 32,
277
+ norm_eps: float = 1e-5,
278
+ cross_attention_dim: Union[int, Tuple[int]] = 1280,
279
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
280
+ encoder_hid_dim: Optional[int] = None,
281
+ encoder_hid_dim_type: Optional[str] = None,
282
+ attention_head_dim: Union[int, Tuple[int]] = 8,
283
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
284
+ dual_cross_attention: bool = False,
285
+ use_linear_projection: bool = False,
286
+ class_embed_type: Optional[str] = None,
287
+ addition_embed_type: Optional[str] = None,
288
+ addition_time_embed_dim: Optional[int] = None,
289
+ num_class_embeds: Optional[int] = None,
290
+ upcast_attention: bool = False,
291
+ resnet_time_scale_shift: str = "default",
292
+ resnet_skip_time_act: bool = False,
293
+ resnet_out_scale_factor: int = 1.0,
294
+ time_embedding_type: str = "positional",
295
+ time_embedding_dim: Optional[int] = None,
296
+ time_embedding_act_fn: Optional[str] = None,
297
+ timestep_post_act: Optional[str] = None,
298
+ time_cond_proj_dim: Optional[int] = None,
299
+ conv_in_kernel: int = 3,
300
+ conv_out_kernel: int = 3,
301
+ projection_class_embeddings_input_dim: Optional[int] = None,
302
+ projection_camera_embeddings_input_dim: Optional[int] = None,
303
+ class_embeddings_concat: bool = False,
304
+ mid_block_only_cross_attention: Optional[bool] = None,
305
+ cross_attention_norm: Optional[str] = None,
306
+ addition_embed_type_num_heads=64,
307
+ num_views: int = 1,
308
+ cd_attention_last: bool = False,
309
+ cd_attention_mid: bool = False,
310
+ multiview_attention: bool = True,
311
+ sparse_mv_attention: bool = False,
312
+ selfattn_block: str = "custom",
313
+ mvcd_attention: bool = False,
314
+ regress_elevation: bool = False,
315
+ regress_focal_length: bool = False,
316
+ num_regress_blocks: int = 4,
317
+ use_dino: bool = False,
318
+ addition_downsample: bool = False,
319
+ addition_channels: Optional[Tuple[int]] = (1280, 1280, 1280),
320
+ ):
321
+ super().__init__()
322
+
323
+ self.sample_size = sample_size
324
+ self.num_views = num_views
325
+ self.mvcd_attention = mvcd_attention
326
+ if num_attention_heads is not None:
327
+ raise ValueError(
328
+ "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
329
+ )
330
+
331
+ # If `num_attention_heads` is not defined (which is the case for most models)
332
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
333
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
334
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
335
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
336
+ # which is why we correct for the naming here.
337
+ num_attention_heads = num_attention_heads or attention_head_dim
338
+
339
+ # Check inputs
340
+ if len(down_block_types) != len(up_block_types):
341
+ raise ValueError(
342
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
343
+ )
344
+
345
+ if len(block_out_channels) != len(down_block_types):
346
+ raise ValueError(
347
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
348
+ )
349
+
350
+ if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
351
+ raise ValueError(
352
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
353
+ )
354
+
355
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
356
+ raise ValueError(
357
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
358
+ )
359
+
360
+ if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
361
+ raise ValueError(
362
+ f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
363
+ )
364
+
365
+ if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
366
+ raise ValueError(
367
+ f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
368
+ )
369
+
370
+ if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
371
+ raise ValueError(
372
+ f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
373
+ )
374
+
375
+ # input
376
+ conv_in_padding = (conv_in_kernel - 1) // 2
377
+ self.conv_in = nn.Conv2d(
378
+ in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
379
+ )
380
+
381
+ # time
382
+ if time_embedding_type == "fourier":
383
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
384
+ if time_embed_dim % 2 != 0:
385
+ raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
386
+ self.time_proj = GaussianFourierProjection(
387
+ time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
388
+ )
389
+ timestep_input_dim = time_embed_dim
390
+ elif time_embedding_type == "positional":
391
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
392
+
393
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
394
+ timestep_input_dim = block_out_channels[0]
395
+ else:
396
+ raise ValueError(
397
+ f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
398
+ )
399
+
400
+ self.time_embedding = TimestepEmbedding(
401
+ timestep_input_dim,
402
+ time_embed_dim,
403
+ act_fn=act_fn,
404
+ post_act_fn=timestep_post_act,
405
+ cond_proj_dim=time_cond_proj_dim,
406
+ )
407
+
408
+ if encoder_hid_dim_type is None and encoder_hid_dim is not None:
409
+ encoder_hid_dim_type = "text_proj"
410
+ self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
411
+ logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
412
+
413
+ if encoder_hid_dim is None and encoder_hid_dim_type is not None:
414
+ raise ValueError(
415
+ f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
416
+ )
417
+
418
+ if encoder_hid_dim_type == "text_proj":
419
+ self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
420
+ elif encoder_hid_dim_type == "text_image_proj":
421
+ # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
422
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
423
+ # case when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1)
424
+ self.encoder_hid_proj = TextImageProjection(
425
+ text_embed_dim=encoder_hid_dim,
426
+ image_embed_dim=cross_attention_dim,
427
+ cross_attention_dim=cross_attention_dim,
428
+ )
429
+ elif encoder_hid_dim_type == "image_proj":
430
+ # Kandinsky 2.2
431
+ self.encoder_hid_proj = ImageProjection(
432
+ image_embed_dim=encoder_hid_dim,
433
+ cross_attention_dim=cross_attention_dim,
434
+ )
435
+ elif encoder_hid_dim_type is not None:
436
+ raise ValueError(
437
+ f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
438
+ )
439
+ else:
440
+ self.encoder_hid_proj = None
441
+
442
+ # class embedding
443
+ if class_embed_type is None and num_class_embeds is not None:
444
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
445
+ elif class_embed_type == "timestep":
446
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
447
+ elif class_embed_type == "identity":
448
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
449
+ elif class_embed_type == "projection":
450
+ if projection_class_embeddings_input_dim is None:
451
+ raise ValueError(
452
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
453
+ )
454
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
455
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
456
+ # 2. it projects from an arbitrary input dimension.
457
+ #
458
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
459
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
460
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
461
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
462
+ elif class_embed_type == "simple_projection":
463
+ if projection_class_embeddings_input_dim is None:
464
+ raise ValueError(
465
+ "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
466
+ )
467
+ self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
468
+ else:
469
+ self.class_embedding = None
470
+
471
+ if addition_embed_type == "text":
472
+ if encoder_hid_dim is not None:
473
+ text_time_embedding_from_dim = encoder_hid_dim
474
+ else:
475
+ text_time_embedding_from_dim = cross_attention_dim
476
+
477
+ self.add_embedding = TextTimeEmbedding(
478
+ text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
479
+ )
480
+ elif addition_embed_type == "text_image":
481
+ # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
482
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
483
+ # case when `addition_embed_type == "text_image"` (Kandinsky 2.1)
484
+ self.add_embedding = TextImageTimeEmbedding(
485
+ text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
486
+ )
487
+ elif addition_embed_type == "text_time":
488
+ self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
489
+ self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
490
+ elif addition_embed_type == "image":
491
+ # Kandinsky 2.2
492
+ self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
493
+ elif addition_embed_type == "image_hint":
494
+ # Kandinsky 2.2 ControlNet
495
+ self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
496
+ elif addition_embed_type is not None:
497
+ raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
498
+
499
+ if time_embedding_act_fn is None:
500
+ self.time_embed_act = None
501
+ else:
502
+ self.time_embed_act = get_activation(time_embedding_act_fn)
503
+
504
+ self.down_blocks = nn.ModuleList([])
505
+ self.up_blocks = nn.ModuleList([])
506
+
507
+ if isinstance(only_cross_attention, bool):
508
+ if mid_block_only_cross_attention is None:
509
+ mid_block_only_cross_attention = only_cross_attention
510
+
511
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
512
+
513
+ if mid_block_only_cross_attention is None:
514
+ mid_block_only_cross_attention = False
515
+
516
+ if isinstance(num_attention_heads, int):
517
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
518
+
519
+ if isinstance(attention_head_dim, int):
520
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
521
+
522
+ if isinstance(cross_attention_dim, int):
523
+ cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
524
+
525
+ if isinstance(layers_per_block, int):
526
+ layers_per_block = [layers_per_block] * len(down_block_types)
527
+
528
+ if isinstance(transformer_layers_per_block, int):
529
+ transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
530
+
531
+ if class_embeddings_concat:
532
+ # The time embeddings are concatenated with the class embeddings. The dimension of the
533
+ # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
534
+ # regular time embeddings
535
+ blocks_time_embed_dim = time_embed_dim * 2
536
+ else:
537
+ blocks_time_embed_dim = time_embed_dim
538
+
539
+ # down
540
+ output_channel = block_out_channels[0]
541
+ for i, down_block_type in enumerate(down_block_types):
542
+ input_channel = output_channel
543
+ output_channel = block_out_channels[i]
544
+ is_final_block = i == len(block_out_channels) - 1
545
+
546
+ down_block = get_down_block(
547
+ down_block_type,
548
+ num_layers=layers_per_block[i],
549
+ transformer_layers_per_block=transformer_layers_per_block[i],
550
+ in_channels=input_channel,
551
+ out_channels=output_channel,
552
+ temb_channels=blocks_time_embed_dim,
553
+ add_downsample=not is_final_block,
554
+ resnet_eps=norm_eps,
555
+ resnet_act_fn=act_fn,
556
+ resnet_groups=norm_num_groups,
557
+ cross_attention_dim=cross_attention_dim[i],
558
+ num_attention_heads=num_attention_heads[i],
559
+ downsample_padding=downsample_padding,
560
+ dual_cross_attention=dual_cross_attention,
561
+ use_linear_projection=use_linear_projection,
562
+ only_cross_attention=only_cross_attention[i],
563
+ upcast_attention=upcast_attention,
564
+ resnet_time_scale_shift=resnet_time_scale_shift,
565
+ resnet_skip_time_act=resnet_skip_time_act,
566
+ resnet_out_scale_factor=resnet_out_scale_factor,
567
+ cross_attention_norm=cross_attention_norm,
568
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
569
+ num_views=num_views,
570
+ cd_attention_last=cd_attention_last,
571
+ cd_attention_mid=cd_attention_mid,
572
+ multiview_attention=multiview_attention,
573
+ sparse_mv_attention=sparse_mv_attention,
574
+ selfattn_block=selfattn_block,
575
+ mvcd_attention=mvcd_attention,
576
+ use_dino=use_dino
577
+ )
578
+ self.down_blocks.append(down_block)
579
+
580
+ # mid
581
+ if mid_block_type == "UNetMidBlock2DCrossAttn":
582
+ self.mid_block = UNetMidBlock2DCrossAttn(
583
+ transformer_layers_per_block=transformer_layers_per_block[-1],
584
+ in_channels=block_out_channels[-1],
585
+ temb_channels=blocks_time_embed_dim,
586
+ resnet_eps=norm_eps,
587
+ resnet_act_fn=act_fn,
588
+ output_scale_factor=mid_block_scale_factor,
589
+ resnet_time_scale_shift=resnet_time_scale_shift,
590
+ cross_attention_dim=cross_attention_dim[-1],
591
+ num_attention_heads=num_attention_heads[-1],
592
+ resnet_groups=norm_num_groups,
593
+ dual_cross_attention=dual_cross_attention,
594
+ use_linear_projection=use_linear_projection,
595
+ upcast_attention=upcast_attention,
596
+ )
597
+ # custom MV2D attention block
598
+ elif mid_block_type == "UNetMidBlockMV2DCrossAttn":
599
+ self.mid_block = UNetMidBlockMV2DCrossAttn(
600
+ transformer_layers_per_block=transformer_layers_per_block[-1],
601
+ in_channels=block_out_channels[-1],
602
+ temb_channels=blocks_time_embed_dim,
603
+ resnet_eps=norm_eps,
604
+ resnet_act_fn=act_fn,
605
+ output_scale_factor=mid_block_scale_factor,
606
+ resnet_time_scale_shift=resnet_time_scale_shift,
607
+ cross_attention_dim=cross_attention_dim[-1],
608
+ num_attention_heads=num_attention_heads[-1],
609
+ resnet_groups=norm_num_groups,
610
+ dual_cross_attention=dual_cross_attention,
611
+ use_linear_projection=use_linear_projection,
612
+ upcast_attention=upcast_attention,
613
+ num_views=num_views,
614
+ cd_attention_last=cd_attention_last,
615
+ cd_attention_mid=cd_attention_mid,
616
+ multiview_attention=multiview_attention,
617
+ sparse_mv_attention=sparse_mv_attention,
618
+ selfattn_block=selfattn_block,
619
+ mvcd_attention=mvcd_attention,
620
+ use_dino=use_dino
621
+ )
622
+ elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
623
+ self.mid_block = UNetMidBlock2DSimpleCrossAttn(
624
+ in_channels=block_out_channels[-1],
625
+ temb_channels=blocks_time_embed_dim,
626
+ resnet_eps=norm_eps,
627
+ resnet_act_fn=act_fn,
628
+ output_scale_factor=mid_block_scale_factor,
629
+ cross_attention_dim=cross_attention_dim[-1],
630
+ attention_head_dim=attention_head_dim[-1],
631
+ resnet_groups=norm_num_groups,
632
+ resnet_time_scale_shift=resnet_time_scale_shift,
633
+ skip_time_act=resnet_skip_time_act,
634
+ only_cross_attention=mid_block_only_cross_attention,
635
+ cross_attention_norm=cross_attention_norm,
636
+ )
637
+ elif mid_block_type is None:
638
+ self.mid_block = None
639
+ else:
640
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
641
+
642
+ self.addition_downsample = addition_downsample
643
+ if self.addition_downsample:
644
+ inc = block_out_channels[-1]
645
+ self.downsample = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
646
+ self.conv_block = nn.ModuleList()
647
+ self.conv_block.append(BasicConvBlock(inc, addition_channels[0], stride=1))
648
+ for dim_ in addition_channels[1:-1]:
649
+ self.conv_block.append(BasicConvBlock(dim_, dim_, stride=1))
650
+ self.conv_block.append(BasicConvBlock(dim_, inc))
651
+ self.addition_conv_out = nn.Conv2d(inc, inc, kernel_size=1, bias=False)
652
+ nn.init.zeros_(self.addition_conv_out.weight.data)
653
+ self.addition_act_out = nn.SiLU()
654
+ self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
655
+
656
+ self.regress_elevation = regress_elevation
657
+ self.regress_focal_length = regress_focal_length
658
+ if regress_elevation or regress_focal_length:
659
+ self.pool = nn.AdaptiveAvgPool2d((1, 1))
660
+ self.camera_embedding = TimestepEmbedding(projection_camera_embeddings_input_dim, time_embed_dim=time_embed_dim)
661
+
662
+ regress_in_dim = block_out_channels[-1]*2 if mvcd_attention else block_out_channels
663
+
664
+ if regress_elevation:
665
+ self.elevation_regressor = ResidualLiner(regress_in_dim, 1, 1280, act=None, num_block=num_regress_blocks)
666
+ if regress_focal_length:
667
+ self.focal_regressor = ResidualLiner(regress_in_dim, 1, 1280, act=None, num_block=num_regress_blocks)
668
+ '''
669
+ self.regress_elevation = regress_elevation
670
+ self.regress_focal_length = regress_focal_length
671
+ if regress_elevation and (not regress_focal_length):
672
+ print("Regressing elevation")
673
+ cam_dim = 1
674
+ elif regress_focal_length and (not regress_elevation):
675
+ print("Regressing focal length")
676
+ cam_dim = 6
677
+ elif regress_elevation and regress_focal_length:
678
+ print("Regressing both elevation and focal length")
679
+ cam_dim = 7
680
+ else:
681
+ cam_dim = 0
682
+ assert projection_camera_embeddings_input_dim == 2*cam_dim, "projection_camera_embeddings_input_dim should be 2*cam_dim"
683
+ if regress_elevation or regress_focal_length:
684
+ self.elevation_regressor = nn.ModuleList([
685
+ nn.Linear(block_out_channels[-1], 1280),
686
+ nn.SiLU(),
687
+ nn.Linear(1280, 1280),
688
+ nn.SiLU(),
689
+ nn.Linear(1280, cam_dim)
690
+ ])
691
+ self.pool = nn.AdaptiveAvgPool2d((1, 1))
692
+ self.focal_act = nn.Softmax(dim=-1)
693
+ self.camera_embedding = TimestepEmbedding(projection_camera_embeddings_input_dim, time_embed_dim=time_embed_dim)
694
+ '''
695
+
696
+ # count how many layers upsample the images
697
+ self.num_upsamplers = 0
698
+
699
+ # up
700
+ reversed_block_out_channels = list(reversed(block_out_channels))
701
+ reversed_num_attention_heads = list(reversed(num_attention_heads))
702
+ reversed_layers_per_block = list(reversed(layers_per_block))
703
+ reversed_cross_attention_dim = list(reversed(cross_attention_dim))
704
+ reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
705
+ only_cross_attention = list(reversed(only_cross_attention))
706
+
707
+ output_channel = reversed_block_out_channels[0]
708
+ for i, up_block_type in enumerate(up_block_types):
709
+ is_final_block = i == len(block_out_channels) - 1
710
+
711
+ prev_output_channel = output_channel
712
+ output_channel = reversed_block_out_channels[i]
713
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
714
+
715
+ # add upsample block for all BUT final layer
716
+ if not is_final_block:
717
+ add_upsample = True
718
+ self.num_upsamplers += 1
719
+ else:
720
+ add_upsample = False
721
+
722
+ up_block = get_up_block(
723
+ up_block_type,
724
+ num_layers=reversed_layers_per_block[i] + 1,
725
+ transformer_layers_per_block=reversed_transformer_layers_per_block[i],
726
+ in_channels=input_channel,
727
+ out_channels=output_channel,
728
+ prev_output_channel=prev_output_channel,
729
+ temb_channels=blocks_time_embed_dim,
730
+ add_upsample=add_upsample,
731
+ resnet_eps=norm_eps,
732
+ resnet_act_fn=act_fn,
733
+ resnet_groups=norm_num_groups,
734
+ cross_attention_dim=reversed_cross_attention_dim[i],
735
+ num_attention_heads=reversed_num_attention_heads[i],
736
+ dual_cross_attention=dual_cross_attention,
737
+ use_linear_projection=use_linear_projection,
738
+ only_cross_attention=only_cross_attention[i],
739
+ upcast_attention=upcast_attention,
740
+ resnet_time_scale_shift=resnet_time_scale_shift,
741
+ resnet_skip_time_act=resnet_skip_time_act,
742
+ resnet_out_scale_factor=resnet_out_scale_factor,
743
+ cross_attention_norm=cross_attention_norm,
744
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
745
+ num_views=num_views,
746
+ cd_attention_last=cd_attention_last,
747
+ cd_attention_mid=cd_attention_mid,
748
+ multiview_attention=multiview_attention,
749
+ sparse_mv_attention=sparse_mv_attention,
750
+ selfattn_block=selfattn_block,
751
+ mvcd_attention=mvcd_attention,
752
+ use_dino=use_dino
753
+ )
754
+ self.up_blocks.append(up_block)
755
+ prev_output_channel = output_channel
756
+
757
+ # out
758
+ if norm_num_groups is not None:
759
+ self.conv_norm_out = nn.GroupNorm(
760
+ num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
761
+ )
762
+
763
+ self.conv_act = get_activation(act_fn)
764
+
765
+ else:
766
+ self.conv_norm_out = None
767
+ self.conv_act = None
768
+
769
+ conv_out_padding = (conv_out_kernel - 1) // 2
770
+ self.conv_out = nn.Conv2d(
771
+ block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
772
+ )
773
+
774
@property
def attn_processors(self) -> Dict[str, AttentionProcessor]:
    r"""
    Collect the attention processors attached to the sub-modules of this model.

    Returns:
        `dict`: One entry per sub-module that exposes a `set_processor` method. The key is the dotted path of
        that sub-module followed by `.processor`; the value is the sub-module's `processor` attribute. Entries
        appear in depth-first order (a module before its children, siblings in registration order).
    """
    found = {}

    # Walk the module tree with an explicit stack instead of recursion. Entries are pushed in reverse so
    # that popping from the end visits siblings in their registration order.
    stack = list(self.named_children())[::-1]
    while stack:
        path, node = stack.pop()

        if hasattr(node, "set_processor"):
            found[f"{path}.processor"] = node.processor

        stack.extend(
            (f"{path}.{child_name}", child) for child_name, child in reversed(list(node.named_children()))
        )

    return found
797
+
798
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
    r"""
    Sets the attention processor to use to compute attention.

    Parameters:
        processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
            The instantiated processor class or a dictionary of processor classes that will be set as the processor
            for **all** `Attention` layers.

            If `processor` is a dict, each key needs to be the path to the corresponding attention layer followed
            by `.processor` (the same keys that `attn_processors` returns). This is strongly recommended when
            setting trainable attention processors. The dict is only read; it is not modified by this call.

    Raises:
        ValueError: If `processor` is a dict whose size differs from the number of attention layers.
        KeyError: If `processor` is a dict that has no entry for one of the attention layers.
    """
    count = len(self.attn_processors)
    per_layer = isinstance(processor, dict)

    if per_layer and len(processor) != count:
        raise ValueError(
            f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
            f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
        )

    def fn_recursive_attn_processor(name: str, module: torch.nn.Module):
        if hasattr(module, "set_processor"):
            # Look the entry up instead of popping it, so the caller's dict stays intact and can be reused.
            module.set_processor(processor[f"{name}.processor"] if per_layer else processor)

        for sub_name, child in module.named_children():
            fn_recursive_attn_processor(f"{name}.{sub_name}", child)

    for name, module in self.named_children():
        fn_recursive_attn_processor(name, module)
831
+
832
def set_default_attn_processor(self):
    """
    Replace whatever attention processors are currently installed with a fresh `AttnProcessor` instance,
    i.e. switch back to the default attention implementation.
    """
    default_processor = AttnProcessor()
    self.set_attn_processor(default_processor)
837
+
838
def set_attention_slice(self, slice_size: Union[str, int, List[int]] = "auto"):
    r"""
    Enable sliced attention computation.

    When this option is enabled, the attention module splits the input tensor in slices to compute attention in
    several steps. This is useful for saving some memory in exchange for a small decrease in speed.

    Args:
        slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
            When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
            `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
            provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
            must be a multiple of `slice_size`. A list supplies one value per sliceable layer, in module
            traversal order; the list itself is not modified.

    Raises:
        ValueError: If a list is passed whose length differs from the number of sliceable layers, or if any
            slice size is larger than the `sliceable_head_dim` of the layer it is applied to.
    """
    sliceable_head_dims = []

    def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
        if hasattr(module, "set_attention_slice"):
            sliceable_head_dims.append(module.sliceable_head_dim)

        for child in module.children():
            fn_recursive_retrieve_sliceable_dims(child)

    # retrieve number of attention layers
    for module in self.children():
        fn_recursive_retrieve_sliceable_dims(module)

    num_sliceable_layers = len(sliceable_head_dims)

    if slice_size == "auto":
        # half the attention head size is usually a good trade-off between
        # speed and memory
        slice_size = [dim // 2 for dim in sliceable_head_dims]
    elif slice_size == "max":
        # make smallest slice possible
        slice_size = num_sliceable_layers * [1]

    # a single value applies to every sliceable layer
    if not isinstance(slice_size, list):
        slice_size = num_sliceable_layers * [slice_size]

    if len(slice_size) != len(sliceable_head_dims):
        raise ValueError(
            f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
            f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
        )

    for size, dim in zip(slice_size, sliceable_head_dims):
        if size is not None and size > dim:
            raise ValueError(f"size {size} has to be smaller or equal to {dim}.")

    # Recursively walk through all the children.
    # Any children which exposes the set_attention_slice method
    # gets the message
    def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
        if hasattr(module, "set_attention_slice"):
            module.set_attention_slice(slice_size.pop())

        for child in module.children():
            fn_recursive_set_attention_slice(child, slice_size)

    # Work on a reversed copy: popping from its end hands out sizes in traversal order
    # and leaves a caller-supplied list untouched.
    reversed_slice_size = list(reversed(slice_size))
    for module in self.children():
        fn_recursive_set_attention_slice(module, reversed_slice_size)
902
+
903
def _set_gradient_checkpointing(self, module, value=False):
    """
    Set `module.gradient_checkpointing` to `value` when `module` is one of the down/up block classes listed
    below; any other module is left untouched.
    """
    checkpointable_blocks = (
        CrossAttnDownBlock2D,
        CrossAttnDownBlockMV2D,
        DownBlock2D,
        CrossAttnUpBlock2D,
        CrossAttnUpBlockMV2D,
        UpBlock2D,
    )
    if not isinstance(module, checkpointable_blocks):
        return
    module.gradient_checkpointing = value
906
+
907
+ def forward(
908
+ self,
909
+ sample: torch.FloatTensor,
910
+ timestep: Union[torch.Tensor, float, int],
911
+ encoder_hidden_states: torch.Tensor,
912
+ class_labels: Optional[torch.Tensor] = None,
913
+ timestep_cond: Optional[torch.Tensor] = None,
914
+ attention_mask: Optional[torch.Tensor] = None,
915
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
916
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
917
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
918
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
919
+ encoder_attention_mask: Optional[torch.Tensor] = None,
920
+ dino_feature: Optional[torch.Tensor] = None,
921
+ return_dict: bool = True,
922
+ vis_max_min: bool = False,
923
+ ) -> Union[UNetMV2DConditionOutput, Tuple]:
924
+ r"""
925
+ The [`UNet2DConditionModel`] forward method.
926
+
927
+ Args:
928
+ sample (`torch.FloatTensor`):
929
+ The noisy input tensor with the following shape `(batch, channel, height, width)`.
930
+ timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
931
+ encoder_hidden_states (`torch.FloatTensor`):
932
+ The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
933
+ encoder_attention_mask (`torch.Tensor`):
934
+ A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
935
+ `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
936
+ which adds large negative values to the attention scores corresponding to "discard" tokens.
937
+ return_dict (`bool`, *optional*, defaults to `True`):
938
+ Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
939
+ tuple.
940
+ cross_attention_kwargs (`dict`, *optional*):
941
+ A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
942
+ added_cond_kwargs: (`dict`, *optional*):
943
+ A kwargs dictionary containin additional embeddings that if specified are added to the embeddings that
944
+ are passed along to the UNet blocks.
945
+
946
+ Returns:
947
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
948
+ If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
949
+ a `tuple` is returned where the first element is the sample tensor.
950
+ """
951
+ record_max_min = {}
952
+ # By default samples have to be AT least a multiple of the overall upsampling factor.
953
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
954
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
955
+ # on the fly if necessary.
956
+ default_overall_up_factor = 2**self.num_upsamplers
957
+
958
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
959
+ forward_upsample_size = False
960
+ upsample_size = None
961
+
962
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
963
+ logger.info("Forward upsample size to force interpolation output size.")
964
+ forward_upsample_size = True
965
+
966
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
967
+ # expects mask of shape:
968
+ # [batch, key_tokens]
969
+ # adds singleton query_tokens dimension:
970
+ # [batch, 1, key_tokens]
971
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
972
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
973
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
974
+ if attention_mask is not None:
975
+ # assume that mask is expressed as:
976
+ # (1 = keep, 0 = discard)
977
+ # convert mask into a bias that can be added to attention scores:
978
+ # (keep = +0, discard = -10000.0)
979
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
980
+ attention_mask = attention_mask.unsqueeze(1)
981
+
982
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
983
+ if encoder_attention_mask is not None:
984
+ encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
985
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
986
+
987
+ # 0. center input if necessary
988
+ if self.config.center_input_sample:
989
+ sample = 2 * sample - 1.0
990
+ # 1. time
991
+ timesteps = timestep
992
+ if not torch.is_tensor(timesteps):
993
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
994
+ # This would be a good case for the `match` statement (Python 3.10+)
995
+ is_mps = sample.device.type == "mps"
996
+ if isinstance(timestep, float):
997
+ dtype = torch.float32 if is_mps else torch.float64
998
+ else:
999
+ dtype = torch.int32 if is_mps else torch.int64
1000
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
1001
+ elif len(timesteps.shape) == 0:
1002
+ timesteps = timesteps[None].to(sample.device)
1003
+
1004
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
1005
+ timesteps = timesteps.expand(sample.shape[0])
1006
+
1007
+ t_emb = self.time_proj(timesteps)
1008
+
1009
+ # `Timesteps` does not contain any weights and will always return f32 tensors
1010
+ # but time_embedding might actually be running in fp16. so we need to cast here.
1011
+ # there might be better ways to encapsulate this.
1012
+ t_emb = t_emb.to(dtype=sample.dtype)
1013
+
1014
+ emb = self.time_embedding(t_emb, timestep_cond)
1015
+ aug_emb = None
1016
+ if self.class_embedding is not None:
1017
+ if class_labels is None:
1018
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
1019
+
1020
+ if self.config.class_embed_type == "timestep":
1021
+ class_labels = self.time_proj(class_labels)
1022
+
1023
+ # `Timesteps` does not contain any weights and will always return f32 tensors
1024
+ # there might be better ways to encapsulate this.
1025
+ class_labels = class_labels.to(dtype=sample.dtype)
1026
+
1027
+ class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
1028
+ if self.config.class_embeddings_concat:
1029
+ emb = torch.cat([emb, class_emb], dim=-1)
1030
+ else:
1031
+ emb = emb + class_emb
1032
+
1033
+ if self.config.addition_embed_type == "text":
1034
+ aug_emb = self.add_embedding(encoder_hidden_states)
1035
+ elif self.config.addition_embed_type == "text_image":
1036
+ # Kandinsky 2.1 - style
1037
+ if "image_embeds" not in added_cond_kwargs:
1038
+ raise ValueError(
1039
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
1040
+ )
1041
+
1042
+ image_embs = added_cond_kwargs.get("image_embeds")
1043
+ text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
1044
+ aug_emb = self.add_embedding(text_embs, image_embs)
1045
+ elif self.config.addition_embed_type == "text_time":
1046
+ # SDXL - style
1047
+ if "text_embeds" not in added_cond_kwargs:
1048
+ raise ValueError(
1049
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
1050
+ )
1051
+ text_embeds = added_cond_kwargs.get("text_embeds")
1052
+ if "time_ids" not in added_cond_kwargs:
1053
+ raise ValueError(
1054
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
1055
+ )
1056
+ time_ids = added_cond_kwargs.get("time_ids")
1057
+ time_embeds = self.add_time_proj(time_ids.flatten())
1058
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
1059
+
1060
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
1061
+ add_embeds = add_embeds.to(emb.dtype)
1062
+ aug_emb = self.add_embedding(add_embeds)
1063
+ elif self.config.addition_embed_type == "image":
1064
+ # Kandinsky 2.2 - style
1065
+ if "image_embeds" not in added_cond_kwargs:
1066
+ raise ValueError(
1067
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
1068
+ )
1069
+ image_embs = added_cond_kwargs.get("image_embeds")
1070
+ aug_emb = self.add_embedding(image_embs)
1071
+ elif self.config.addition_embed_type == "image_hint":
1072
+ # Kandinsky 2.2 - style
1073
+ if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
1074
+ raise ValueError(
1075
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
1076
+ )
1077
+ image_embs = added_cond_kwargs.get("image_embeds")
1078
+ hint = added_cond_kwargs.get("hint")
1079
+ aug_emb, hint = self.add_embedding(image_embs, hint)
1080
+ sample = torch.cat([sample, hint], dim=1)
1081
+
1082
+ emb = emb + aug_emb if aug_emb is not None else emb
1083
+ emb_pre_act = emb
1084
+ if self.time_embed_act is not None:
1085
+ emb = self.time_embed_act(emb)
1086
+
1087
+ if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
1088
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
1089
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
1090
+ # Kadinsky 2.1 - style
1091
+ if "image_embeds" not in added_cond_kwargs:
1092
+ raise ValueError(
1093
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
1094
+ )
1095
+
1096
+ image_embeds = added_cond_kwargs.get("image_embeds")
1097
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
1098
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
1099
+ # Kandinsky 2.2 - style
1100
+ if "image_embeds" not in added_cond_kwargs:
1101
+ raise ValueError(
1102
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
1103
+ )
1104
+ image_embeds = added_cond_kwargs.get("image_embeds")
1105
+ encoder_hidden_states = self.encoder_hid_proj(image_embeds)
1106
+ # 2. pre-process
1107
+ sample = self.conv_in(sample)
1108
+ # 3. down
1109
+
1110
+ is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
1111
+ is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None
1112
+
1113
+ down_block_res_samples = (sample,)
1114
+ for i, downsample_block in enumerate(self.down_blocks):
1115
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
1116
+ # For t2i-adapter CrossAttnDownBlock2D
1117
+ additional_residuals = {}
1118
+ if is_adapter and len(down_block_additional_residuals) > 0:
1119
+ additional_residuals["additional_residuals"] = down_block_additional_residuals.pop(0)
1120
+
1121
+ sample, res_samples = downsample_block(
1122
+ hidden_states=sample,
1123
+ temb=emb,
1124
+ encoder_hidden_states=encoder_hidden_states,
1125
+ dino_feature=dino_feature,
1126
+ attention_mask=attention_mask,
1127
+ cross_attention_kwargs=cross_attention_kwargs,
1128
+ encoder_attention_mask=encoder_attention_mask,
1129
+ **additional_residuals,
1130
+ )
1131
+ else:
1132
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
1133
+
1134
+ if is_adapter and len(down_block_additional_residuals) > 0:
1135
+ sample += down_block_additional_residuals.pop(0)
1136
+
1137
+ down_block_res_samples += res_samples
1138
+
1139
+ if is_controlnet:
1140
+ new_down_block_res_samples = ()
1141
+
1142
+ for down_block_res_sample, down_block_additional_residual in zip(
1143
+ down_block_res_samples, down_block_additional_residuals
1144
+ ):
1145
+ down_block_res_sample = down_block_res_sample + down_block_additional_residual
1146
+ new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
1147
+
1148
+ down_block_res_samples = new_down_block_res_samples
1149
+
1150
+ if self.addition_downsample:
1151
+ global_sample = sample
1152
+ global_sample = self.downsample(global_sample)
1153
+ for layer in self.conv_block:
1154
+ global_sample = layer(global_sample)
1155
+ global_sample = self.addition_act_out(self.addition_conv_out(global_sample))
1156
+ global_sample = self.upsample(global_sample)
1157
+ # 4. mid
1158
+ if self.mid_block is not None:
1159
+ sample = self.mid_block(
1160
+ sample,
1161
+ emb,
1162
+ encoder_hidden_states=encoder_hidden_states,
1163
+ dino_feature=dino_feature,
1164
+ attention_mask=attention_mask,
1165
+ cross_attention_kwargs=cross_attention_kwargs,
1166
+ encoder_attention_mask=encoder_attention_mask,
1167
+ )
1168
+ # 4.1 regress elevation and focal length
1169
+ # # predict elevation -> embed -> projection -> add to time emb
1170
+ if self.regress_elevation or self.regress_focal_length:
1171
+ pool_embeds = self.pool(sample.detach()).squeeze(-1).squeeze(-1) # (2B, C)
1172
+ if self.mvcd_attention:
1173
+ pool_embeds_normal, pool_embeds_color = torch.chunk(pool_embeds, 2, dim=0)
1174
+ pool_embeds = torch.cat([pool_embeds_normal, pool_embeds_color], dim=-1) # (B, 2C)
1175
+ pose_pred = []
1176
+ if self.regress_elevation:
1177
+ ele_pred = self.elevation_regressor(pool_embeds)
1178
+ ele_pred = rearrange(ele_pred, '(b v) c -> b v c', v=self.num_views)
1179
+ ele_pred = torch.mean(ele_pred, dim=1)
1180
+ pose_pred.append(ele_pred) # b, c
1181
+
1182
+ if self.regress_focal_length:
1183
+ focal_pred = self.focal_regressor(pool_embeds)
1184
+ focal_pred = rearrange(focal_pred, '(b v) c -> b v c', v=self.num_views)
1185
+ focal_pred = torch.mean(focal_pred, dim=1)
1186
+ pose_pred.append(focal_pred)
1187
+ pose_pred = torch.cat(pose_pred, dim=-1)
1188
+ # 'e_de_da_sincos', (B, 2)
1189
+ pose_embeds = torch.cat([
1190
+ torch.sin(pose_pred),
1191
+ torch.cos(pose_pred)
1192
+ ], dim=-1)
1193
+ pose_embeds = self.camera_embedding(pose_embeds)
1194
+ pose_embeds = torch.repeat_interleave(pose_embeds, self.num_views, 0)
1195
+ if self.mvcd_attention:
1196
+ pose_embeds = torch.cat([pose_embeds,] * 2, dim=0)
1197
+
1198
+ emb = pose_embeds + emb_pre_act
1199
+ if self.time_embed_act is not None:
1200
+ emb = self.time_embed_act(emb)
1201
+
1202
+ if is_controlnet:
1203
+ sample = sample + mid_block_additional_residual
1204
+
1205
+ if self.addition_downsample:
1206
+ sample = sample + global_sample
1207
+
1208
+ # 5. up
1209
+ for i, upsample_block in enumerate(self.up_blocks):
1210
+ is_final_block = i == len(self.up_blocks) - 1
1211
+
1212
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
1213
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
1214
+
1215
+ # if we have not reached the final block and need to forward the
1216
+ # upsample size, we do it here
1217
+ if not is_final_block and forward_upsample_size:
1218
+ upsample_size = down_block_res_samples[-1].shape[2:]
1219
+
1220
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
1221
+ sample = upsample_block(
1222
+ hidden_states=sample,
1223
+ temb=emb,
1224
+ res_hidden_states_tuple=res_samples,
1225
+ encoder_hidden_states=encoder_hidden_states,
1226
+ dino_feature=dino_feature,
1227
+ cross_attention_kwargs=cross_attention_kwargs,
1228
+ upsample_size=upsample_size,
1229
+ attention_mask=attention_mask,
1230
+ encoder_attention_mask=encoder_attention_mask,
1231
+ )
1232
+ else:
1233
+ sample = upsample_block(
1234
+ hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
1235
+ )
1236
+ if torch.isnan(sample).any() or torch.isinf(sample).any():
1237
+ print("NAN in sample, stop training.")
1238
+ exit()
1239
+ # 6. post-process
1240
+ if self.conv_norm_out:
1241
+ sample = self.conv_norm_out(sample)
1242
+ sample = self.conv_act(sample)
1243
+ sample = self.conv_out(sample)
1244
+ if not return_dict:
1245
+ return (sample, pose_pred)
1246
+ if self.regress_elevation or self.regress_focal_length:
1247
+ return UNetMV2DConditionOutput(sample=sample), pose_pred
1248
+ else:
1249
+ return UNetMV2DConditionOutput(sample=sample)
1250
+
1251
+
1252
@classmethod
def from_pretrained_2d(
    cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
    camera_embedding_type: str, num_views: int, sample_size: int,
    zero_init_conv_in: bool = True, zero_init_camera_projection: bool = False,
    projection_camera_embeddings_input_dim: int=2,
    cd_attention_last: bool = False, num_regress_blocks: int = 4,
    cd_attention_mid: bool = False, multiview_attention: bool = True,
    sparse_mv_attention: bool = False, selfattn_block: str = 'custom', mvcd_attention: bool = False,
    in_channels: int = 8, out_channels: int = 4, unclip: bool = False, regress_elevation: bool = False, regress_focal_length: bool = False,
    init_mvattn_with_selfattn: bool= False, use_dino: bool = False, addition_downsample: bool = False,
    **kwargs
):
    r"""
    Instantiate this multi-view model from the config and weights of a pretrained checkpoint.

    The checkpoint config is read with `cls.load_config`, overridden with the multi-view settings passed
    here (the down/mid/up block types are replaced by their `*MV2D` variants), and the model is built with
    `cls.from_config`. The checkpoint weights are then loaded with shape mismatches always tolerated, and
    `conv_in` / `conv_out` are partially re-initialised from the checkpoint when their shapes differ.

    The model is set in evaluation mode - `model.eval()` - before it is returned. To train the model, set
    it back in training mode with `model.train()`.

    Parameters:
        pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
            Can be either:

                - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model
                  hosted on the Hub.
                - A path to a *directory* (for example `./my_model_directory`) containing the model weights
                  saved with [`~ModelMixin.save_pretrained`].
        camera_embedding_type (`str`):
            Accepted for interface compatibility; this method does not read it.
        num_views (`int`):
            Written to `config['num_views']`.
        sample_size (`int`):
            Written to `config['sample_size']` (the training resolution).
        zero_init_conv_in (`bool`, defaults to `True`):
            When the checkpoint `conv_in.weight` shape differs from the model's, zero every input channel
            after the first 4 (which are copied from the checkpoint).
        zero_init_camera_projection (`bool`, defaults to `False`):
            Zero the last parameter yielded by `model.camera_embedding.parameters()`.
        projection_camera_embeddings_input_dim (`int`, defaults to 2):
            Written to the config (2 for elevation and 10 for focal_length).
        cd_attention_last, cd_attention_mid, multiview_attention, sparse_mv_attention (`bool`),
        selfattn_block (`str`), mvcd_attention (`bool`):
            Attention options, written to the config under the same names.
        num_regress_blocks (`int`, defaults to 4):
            Written to `config['num_regress_blocks']`.
        in_channels (`int`, defaults to 8), out_channels (`int`, defaults to 4):
            Override the checkpoint's `in_channels` / `out_channels` config entries.
        unclip (`bool`, defaults to `False`):
            Accepted for interface compatibility; this method does not read it.
        regress_elevation, regress_focal_length (`bool`, default to `False`):
            Written to the config under the same names.
        init_mvattn_with_selfattn (`bool`, defaults to `False`):
            Initialise every `attn_mv` weight from the matching `attn1` weight (with `to_out.0.weight`
            zeroed) and every transformer-block `norm_mv` weight from the matching `norm1` weight.
        use_dino, addition_downsample (`bool`, default to `False`):
            Written to the config under the same names.
        cache_dir (`Union[str, os.PathLike]`, *optional*):
            Path to a directory where a downloaded pretrained model configuration is cached if the standard
            cache is not used.
        ignore_mismatched_sizes (`bool`, *optional*):
            Accepted and discarded; shape mismatches are always tolerated by this method.
        force_download (`bool`, *optional*, defaults to `False`):
            Whether or not to force the (re-)download of the model weights and configuration files,
            overriding the cached versions if they exist.
        resume_download (`bool`, *optional*, defaults to `False`):
            Whether or not to resume downloading the model weights and configuration files. If set to
            `False`, any incompletely downloaded files are deleted.
        proxies (`Dict[str, str]`, *optional*):
            A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http':
            'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
        output_loading_info (`bool`, *optional*, defaults to `False`):
            Whether or not to also return a dictionary containing missing keys, unexpected keys,
            mismatched keys and error messages.
        local_files_only (`bool`, *optional*):
            Whether to only load local model weights and configuration files or not.
        use_auth_token (`str` or *bool*, *optional*):
            The token to use as HTTP bearer authorization for remote files.
        revision (`str`, *optional*):
            The specific model version to use. It can be a branch name, a tag name, a commit id, or any
            identifier allowed by Git.
        torch_dtype (`torch.dtype`, *optional*):
            If given, the loaded model is cast to this dtype.
        from_flax (`bool`, *optional*, defaults to `False`):
            Not supported; `True` raises `NotImplementedError`.
        subfolder (`str`, *optional*):
            The subfolder location of a model file within a larger model repository on the Hub or locally.
        device_map, max_memory, offload_folder, offload_state_dict (*optional*):
            Forwarded to `cls.load_config`. `device_map` additionally requires `accelerate` and
            torch >= 1.9.0.
        variant (`str`, *optional*):
            Load weights from a specified `variant` filename such as `"fp16"` or `"ema"`.
        use_safetensors (`bool`, *optional*, defaults to `None`):
            If `None`, `safetensors` weights are tried first when the `safetensors` library is installed,
            falling back to the pickle weights if they cannot be fetched. If `True`, the model is forcibly
            loaded from `safetensors` weights. If `False`, `safetensors` weights are not loaded.
        kwargs:
            Any remaining keyword arguments are forwarded to `cls.load_config`.

    Returns:
        The loaded model, or a `(model, loading_info)` tuple when `output_loading_info=True`.

    Raises:
        ValueError: If `use_safetensors=True` but `safetensors` is not installed, or if `torch_dtype` is
            given but is not a `torch.dtype`.
        NotImplementedError: If `from_flax=True`, or if `device_map` is given without `accelerate` or with
            torch < 1.9.0.

    <Tip>

    To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with
    `huggingface-cli login`. You can also activate the special
    ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a
    firewalled environment.

    </Tip>
    """
    # Local import: only needed to probe for the optional `safetensors` package.
    import importlib.util

    cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
    # Popped so it is not forwarded to `load_config`; mismatches are always tolerated below.
    kwargs.pop("ignore_mismatched_sizes", False)
    force_download = kwargs.pop("force_download", False)
    from_flax = kwargs.pop("from_flax", False)
    resume_download = kwargs.pop("resume_download", False)
    proxies = kwargs.pop("proxies", None)
    output_loading_info = kwargs.pop("output_loading_info", False)
    local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
    use_auth_token = kwargs.pop("use_auth_token", None)
    revision = kwargs.pop("revision", None)
    torch_dtype = kwargs.pop("torch_dtype", None)
    subfolder = kwargs.pop("subfolder", None)
    device_map = kwargs.pop("device_map", None)
    max_memory = kwargs.pop("max_memory", None)
    offload_folder = kwargs.pop("offload_folder", None)
    offload_state_dict = kwargs.pop("offload_state_dict", False)
    variant = kwargs.pop("variant", None)
    use_safetensors = kwargs.pop("use_safetensors", None)

    # Only complain when safetensors was explicitly requested AND is really missing.
    safetensors_installed = importlib.util.find_spec("safetensors") is not None
    if use_safetensors and not safetensors_installed:
        raise ValueError(
            "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors"
        )

    allow_pickle = False
    if use_safetensors is None:
        # No explicit choice: prefer safetensors when usable, but allow the pickle fallback.
        use_safetensors = safetensors_installed
        allow_pickle = True

    if device_map is not None and not is_accelerate_available():
        raise NotImplementedError(
            "Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set"
            " `device_map=None`. You can install accelerate with `pip install accelerate`."
        )

    # Check if we can handle device_map and dispatching the weights
    if device_map is not None and not is_torch_version(">=", "1.9.0"):
        raise NotImplementedError(
            "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
            " `device_map=None`."
        )

    # Load config if we don't provide a configuration
    config_path = pretrained_model_name_or_path

    user_agent = {
        "diffusers": __version__,
        "file_type": "model",
        "framework": "pytorch",
    }

    # load config
    config, unused_kwargs, commit_hash = cls.load_config(
        config_path,
        cache_dir=cache_dir,
        return_unused_kwargs=True,
        return_commit_hash=True,
        force_download=force_download,
        resume_download=resume_download,
        proxies=proxies,
        local_files_only=local_files_only,
        use_auth_token=use_auth_token,
        revision=revision,
        subfolder=subfolder,
        device_map=device_map,
        max_memory=max_memory,
        offload_folder=offload_folder,
        offload_state_dict=offload_state_dict,
        user_agent=user_agent,
        **kwargs,
    )

    # modify config
    config["_class_name"] = cls.__name__
    config['in_channels'] = in_channels
    config['out_channels'] = out_channels
    config['sample_size'] = sample_size  # training resolution
    config['num_views'] = num_views
    config['cd_attention_last'] = cd_attention_last
    config['cd_attention_mid'] = cd_attention_mid
    config['multiview_attention'] = multiview_attention
    config['sparse_mv_attention'] = sparse_mv_attention
    config['selfattn_block'] = selfattn_block
    config['mvcd_attention'] = mvcd_attention
    config["down_block_types"] = [
        "CrossAttnDownBlockMV2D",
        "CrossAttnDownBlockMV2D",
        "CrossAttnDownBlockMV2D",
        "DownBlock2D"
    ]
    config['mid_block_type'] = "UNetMidBlockMV2DCrossAttn"
    config["up_block_types"] = [
        "UpBlock2D",
        "CrossAttnUpBlockMV2D",
        "CrossAttnUpBlockMV2D",
        "CrossAttnUpBlockMV2D"
    ]

    config['regress_elevation'] = regress_elevation
    config['regress_focal_length'] = regress_focal_length
    config['projection_camera_embeddings_input_dim'] = projection_camera_embeddings_input_dim  # 2 for elevation and 10 for focal_length
    config['use_dino'] = use_dino
    config['num_regress_blocks'] = num_regress_blocks
    config['addition_downsample'] = addition_downsample

    # load model
    model_file = None
    if from_flax:
        raise NotImplementedError
    else:
        if use_safetensors:
            try:
                model_file = _get_model_file(
                    pretrained_model_name_or_path,
                    weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
                    cache_dir=cache_dir,
                    force_download=force_download,
                    resume_download=resume_download,
                    proxies=proxies,
                    local_files_only=local_files_only,
                    use_auth_token=use_auth_token,
                    revision=revision,
                    subfolder=subfolder,
                    user_agent=user_agent,
                    commit_hash=commit_hash,
                )
            except IOError:
                # Without an explicit request for safetensors, fall through to the pickle weights.
                if not allow_pickle:
                    raise
        if model_file is None:
            model_file = _get_model_file(
                pretrained_model_name_or_path,
                weights_name=_add_variant(WEIGHTS_NAME, variant),
                cache_dir=cache_dir,
                force_download=force_download,
                resume_download=resume_download,
                proxies=proxies,
                local_files_only=local_files_only,
                use_auth_token=use_auth_token,
                revision=revision,
                subfolder=subfolder,
                user_agent=user_agent,
                commit_hash=commit_hash,
            )

        model = cls.from_config(config, **unused_kwargs)
        state_dict_pretrain = load_state_dict(model_file, variant=variant)
        # Shallow copy: the pretrained tensors are never modified in place, so the working dict can
        # share them instead of duplicating the whole checkpoint in memory.
        state_dict = dict(state_dict_pretrain)

        if init_mvattn_with_selfattn:
            for key in state_dict_pretrain:
                if 'attn1' in key:
                    key_mv = key.replace('attn1', 'attn_mv')
                    if 'to_out.0.weight' in key:
                        # The multi-view attention output projection starts at zero.
                        state_dict[key_mv] = torch.zeros_like(state_dict_pretrain[key])
                    else:
                        state_dict[key_mv] = state_dict_pretrain[key]
                if 'transformer_blocks' in key and 'norm1' in key:  # in case that initialize the norm layer in resnet block
                    key_mv = key.replace('norm1', 'norm_mv')
                    state_dict[key_mv] = state_dict_pretrain[key]

        model._convert_deprecated_attention_blocks(state_dict)

        # Keep references now: the loader removes shape-mismatched entries from `state_dict`.
        conv_in_weight = state_dict['conv_in.weight']
        conv_out_weight = state_dict['conv_out.weight']
        model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model_2d(
            model,
            state_dict,
            model_file,
            pretrained_model_name_or_path,
            ignore_mismatched_sizes=True,
        )
        mismatched_names = {key for key, _, _ in mismatched_keys}

        if 'conv_in.weight' in mismatched_names:
            # initialize from the original SD structure: the first 4 input channels (dim 1)
            model.conv_in.weight.data[:, :4] = conv_in_weight

            # whether to place all zero to new layers?
            if zero_init_conv_in:
                model.conv_in.weight.data[:, 4:] = 0.

        if 'conv_out.weight' in mismatched_names:
            # initialize from the original SD structure. Conv2d weights are laid out as
            # (out_channels, in_channels, kH, kW), so output filters are selected along dim 0.
            model.conv_out.weight.data[:4] = conv_out_weight
            if out_channels == 8:  # copy for the last 4 channels
                model.conv_out.weight.data[4:] = conv_out_weight
            # NOTE(review): a mismatched `conv_out.bias` is not restored here and keeps its fresh
            # initialisation - confirm whether that is intended.

        if zero_init_camera_projection:
            params = [p for p in model.camera_embedding.parameters()]
            torch.nn.init.zeros_(params[-1].data)

        loading_info = {
            "missing_keys": missing_keys,
            "unexpected_keys": unexpected_keys,
            "mismatched_keys": mismatched_keys,
            "error_msgs": error_msgs,
        }

    if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
        raise ValueError(
            f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
        )
    elif torch_dtype is not None:
        model = model.to(torch_dtype)

    model.register_to_config(_name_or_path=pretrained_model_name_or_path)

    # Set model in evaluation mode to deactivate DropOut modules by default
    model.eval()
    if output_loading_info:
        return model, loading_info
    return model
1582
+
1583
@classmethod
def _load_pretrained_model_2d(
    cls,
    model,
    state_dict,
    resolved_archive_file,
    pretrained_model_name_or_path,
    ignore_mismatched_sizes=False,
):
    """
    Copy `state_dict` into `model` and report how the two key sets line up.

    When `ignore_mismatched_sizes` is true, every checkpoint entry whose shape differs from the
    model's entry of the same name is recorded and deleted from `state_dict` (the dict is modified
    in place) before loading. Missing, unexpected and mismatched keys are reported through `logger`.

    Args:
        model: Module to load the weights into.
        state_dict: Checkpoint mapping of parameter name to tensor.
        resolved_archive_file: Not read by this method.
        pretrained_model_name_or_path: Only used in log messages.
        ignore_mismatched_sizes: Drop shape-mismatched entries instead of letting the load fail.

    Returns:
        `(model, missing_keys, unexpected_keys, mismatched_keys, error_msgs)`, where each
        `mismatched_keys` item is `(key, checkpoint_shape, model_shape)`.

    Raises:
        RuntimeError: If `_load_state_dict_into_model` reports any error message.
    """
    target_state = model.state_dict()
    checkpoint_keys = list(state_dict.keys())
    target_keys = list(target_state.keys())

    # Key bookkeeping: what the model wants but did not get, and what the checkpoint offers unused.
    missing_keys = list(set(target_keys) - set(checkpoint_keys))
    unexpected_keys = list(set(checkpoint_keys) - set(target_keys))

    if state_dict is not None:
        # Whole checkpoint: collect (and drop) entries whose shapes disagree with the model.
        mismatched_keys = []
        if ignore_mismatched_sizes:
            for name in checkpoint_keys:
                if name not in target_state:
                    continue
                if state_dict[name].shape != target_state[name].shape:
                    mismatched_keys.append((name, state_dict[name].shape, target_state[name].shape))
                    del state_dict[name]
        error_msgs = _load_state_dict_into_model(model, state_dict)

    if len(error_msgs) > 0:
        error_msg = "\n\t".join(error_msgs)
        if "size mismatch" in error_msg:
            error_msg += (
                "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method."
            )
        raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")

    if unexpected_keys:
        logger.warning(
            f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
            f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
            f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task"
            " or with another architecture (e.g. initializing a BertForSequenceClassification model from a"
            " BertForPreTraining model).\n- This IS NOT expected if you are initializing"
            f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly"
            " identical (initializing a BertForSequenceClassification model from a"
            " BertForSequenceClassification model)."
        )
    else:
        logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")

    if missing_keys:
        logger.warning(
            f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
            f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
            " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
        )
    elif not mismatched_keys:
        logger.info(
            f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
            f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the"
            f" checkpoint was trained on, you can already use {model.__class__.__name__} for predictions"
            " without further training."
        )

    if mismatched_keys:
        mismatched_warning = "\n".join(
            f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
            for key, shape1, shape2 in mismatched_keys
        )
        logger.warning(
            f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
            f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
            f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be"
            " able to use it for predictions and inference."
        )

    return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs
1686
+
mvdiffusion/pipelines/pipeline_mvdiffusion_unclip.py ADDED
@@ -0,0 +1,633 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import inspect
2
+ import warnings
3
+ from typing import Callable, List, Optional, Union, Dict, Any
4
+ import PIL
5
+ import torch
6
+ from packaging import version
7
+ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, CLIPFeatureExtractor, CLIPTokenizer, CLIPTextModel
8
+ from diffusers.utils.import_utils import is_accelerate_available
9
+ from diffusers.configuration_utils import FrozenDict
10
+ from diffusers.image_processor import VaeImageProcessor
11
+ from diffusers.models import AutoencoderKL, UNet2DConditionModel
12
+ from diffusers.models.embeddings import get_timestep_embedding
13
+ from diffusers.schedulers import KarrasDiffusionSchedulers
14
+ from diffusers.utils import deprecate, logging
15
+ from diffusers.utils.torch_utils import randn_tensor
16
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
17
+ from diffusers.pipelines.stable_diffusion.stable_unclip_image_normalizer import StableUnCLIPImageNormalizer
18
+ import os
19
+ import torchvision.transforms.functional as TF
20
+ from einops import rearrange
21
+ logger = logging.get_logger(__name__)
22
+
23
+ class StableUnCLIPImg2ImgPipeline(DiffusionPipeline):
24
+ """
25
+ Pipeline for text-guided image to image generation using stable unCLIP.
26
+
27
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
28
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
29
+
30
+ Args:
31
+ feature_extractor ([`CLIPFeatureExtractor`]):
32
+ Feature extractor for image pre-processing before being encoded.
33
+ image_encoder ([`CLIPVisionModelWithProjection`]):
34
+ CLIP vision model for encoding images.
35
+ image_normalizer ([`StableUnCLIPImageNormalizer`]):
36
+ Used to normalize the predicted image embeddings before the noise is applied and un-normalize the image
37
+ embeddings after the noise has been applied.
38
+ image_noising_scheduler ([`KarrasDiffusionSchedulers`]):
39
+ Noise schedule for adding noise to the predicted image embeddings. The amount of noise to add is determined
40
+ by `noise_level` in `StableUnCLIPPipeline.__call__`.
41
+ tokenizer (`CLIPTokenizer`):
42
+ Tokenizer of class
43
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
44
+ text_encoder ([`CLIPTextModel`]):
45
+ Frozen text-encoder.
46
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
47
+ scheduler ([`KarrasDiffusionSchedulers`]):
48
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents.
49
+ vae ([`AutoencoderKL`]):
50
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
51
+ """
52
+ # image encoding components
53
+ feature_extractor: CLIPFeatureExtractor
54
+ image_encoder: CLIPVisionModelWithProjection
55
+ # image noising components
56
+ image_normalizer: StableUnCLIPImageNormalizer
57
+ image_noising_scheduler: KarrasDiffusionSchedulers
58
+ # regular denoising components
59
+ tokenizer: CLIPTokenizer
60
+ text_encoder: CLIPTextModel
61
+ unet: UNet2DConditionModel
62
+ scheduler: KarrasDiffusionSchedulers
63
+ vae: AutoencoderKL
64
+
65
def __init__(
    self,
    # image encoding components
    feature_extractor: CLIPFeatureExtractor,
    image_encoder: CLIPVisionModelWithProjection,
    # image noising components
    image_normalizer: StableUnCLIPImageNormalizer,
    image_noising_scheduler: KarrasDiffusionSchedulers,
    # regular denoising components
    tokenizer: CLIPTokenizer,
    text_encoder: CLIPTextModel,
    unet: UNet2DConditionModel,
    scheduler: KarrasDiffusionSchedulers,
    # vae
    vae: AutoencoderKL,
    num_views: int = 4,
):
    """
    Register the pipeline components and derive the VAE-dependent helpers.

    All model/scheduler arguments are handed to `register_modules` under their own names. `num_views` is stored
    on the instance and later used by `__call__` to size the batch.
    """
    super().__init__()

    components = {
        "feature_extractor": feature_extractor,
        "image_encoder": image_encoder,
        "image_normalizer": image_normalizer,
        "image_noising_scheduler": image_noising_scheduler,
        "tokenizer": tokenizer,
        "text_encoder": text_encoder,
        "unet": unet,
        "scheduler": scheduler,
        "vae": vae,
    }
    self.register_modules(**components)

    # The scale factor is 2 raised to (number of VAE block output channels entries - 1).
    vae_stage_count = len(self.vae.config.block_out_channels)
    self.vae_scale_factor = 2 ** (vae_stage_count - 1)
    self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
    self.num_views: int = num_views
98
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
99
def enable_vae_slicing(self):
    r"""
    Turn on sliced VAE decoding by delegating to `self.vae.enable_slicing()`.

    With slicing enabled the VAE splits its input tensor into slices and decodes them in several steps, which
    saves memory and allows larger batch sizes.
    """
    vae = self.vae
    vae.enable_slicing()
107
+
108
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
109
def disable_vae_slicing(self):
    r"""
    Turn off sliced VAE decoding by delegating to `self.vae.disable_slicing()`.

    If `enable_vae_slicing` was called earlier, decoding goes back to being computed in one step.
    """
    vae = self.vae
    vae.disable_slicing()
115
+
116
def enable_sequential_cpu_offload(self, gpu_id=0):
    r"""
    Register accelerate's `cpu_offload` on the image encoder, text encoder, UNet and VAE.

    Each non-`None` component is offloaded with `cuda:{gpu_id}` as its execution device. Raises `ImportError`
    when accelerate is not available.
    """
    if not is_accelerate_available():
        raise ImportError("Please install accelerate via `pip install accelerate`")
    from accelerate import cpu_offload

    target_device = torch.device(f"cuda:{gpu_id}")

    # TODO: self.image_normalizer.{scale,unscale} are not covered by the offload hooks, so they fail if added
    # to the components below.
    offload_candidates = (self.image_encoder, self.text_encoder, self.unet, self.vae)
    for component in offload_candidates:
        if component is None:
            continue
        cpu_offload(component, target_device)
139
+
140
+ @property
141
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
142
+ def _execution_device(self):
143
+ r"""
144
+ Returns the device on which the pipeline's models will be executed. After calling
145
+ `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
146
+ hooks.
147
+ """
148
+ if not hasattr(self.unet, "_hf_hook"):
149
+ return self.device
150
+ for module in self.unet.modules():
151
+ if (
152
+ hasattr(module, "_hf_hook")
153
+ and hasattr(module._hf_hook, "execution_device")
154
+ and module._hf_hook.execution_device is not None
155
+ ):
156
+ return torch.device(module._hf_hook.execution_device)
157
+ return self.device
158
+
159
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
160
+ def _encode_prompt(
161
+ self,
162
+ prompt,
163
+ device,
164
+ num_images_per_prompt,
165
+ do_classifier_free_guidance,
166
+ negative_prompt=None,
167
+ prompt_embeds: Optional[torch.FloatTensor] = None,
168
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
169
+ lora_scale: Optional[float] = None,
170
+ ):
171
+ r"""
172
+ Encodes the prompt into text encoder hidden states.
173
+
174
+ Args:
175
+ prompt (`str` or `List[str]`, *optional*):
176
+ prompt to be encoded
177
+ device: (`torch.device`):
178
+ torch device
179
+ num_images_per_prompt (`int`):
180
+ number of images that should be generated per prompt
181
+ do_classifier_free_guidance (`bool`):
182
+ whether to use classifier free guidance or not
183
+ negative_prompt (`str` or `List[str]`, *optional*):
184
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
185
+ `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead.
186
+ Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
187
+ prompt_embeds (`torch.FloatTensor`, *optional*):
188
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
189
+ provided, text embeddings will be generated from `prompt` input argument.
190
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
191
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
192
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
193
+ argument.
194
+ """
195
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
196
+
197
+ if do_classifier_free_guidance:
198
+ # For classifier free guidance, we need to do two forward passes.
199
+ # Here we concatenate the unconditional and text embeddings into a single batch
200
+ # to avoid doing two forward passes
201
+ normal_prompt_embeds, color_prompt_embeds = torch.chunk(prompt_embeds, 2, dim=0)
202
+
203
+ prompt_embeds = torch.cat([normal_prompt_embeds, normal_prompt_embeds, color_prompt_embeds, color_prompt_embeds], 0)
204
+
205
+ return prompt_embeds
206
+
207
def _encode_image(
    self,
    image_pil,
    device,
    num_images_per_prompt,
    do_classifier_free_guidance,
    noise_level: int = 0,
    generator: Optional[torch.Generator] = None,
):
    """
    Encode the conditioning images into CLIP embeddings and VAE latents.

    Two conditioning signals are produced from the same images:

    1. CLIP image embeddings from `feature_extractor` + `image_encoder`, passed through
       `noise_image_embeddings` with `noise_level`.
    2. VAE latents: the images are converted to tensors, rescaled from [0, 1] to [-1, 1], encoded, and the mode
       of the latent distribution is multiplied by `vae.config.scaling_factor`.

    Both results are repeated `num_images_per_prompt` times along dim 0. With classifier-free guidance each is
    split into two halves (called `normal` and `color` in the code) and zero tensors are inserted as the
    unconditional entries, giving the layout `[zeros, normal, zeros, color]`.

    Args:
        image_pil: images accepted by both `feature_extractor` and `torchvision`'s `to_tensor`; iterated as a
            sequence.
        device: device the tensors are moved to.
        num_images_per_prompt: repeat factor along the batch dimension.
        do_classifier_free_guidance: whether to add the zeroed unconditional entries.
        noise_level: forwarded to `noise_image_embeddings`.
        generator: forwarded to `noise_image_embeddings`.

    Returns:
        Tuple `(image_embeds, image_latents)`.
    """
    dtype = next(self.image_encoder.parameters()).dtype

    # ______________________________clip image embedding______________________________
    image = self.feature_extractor(images=image_pil, return_tensors="pt").pixel_values
    image = image.to(device=device, dtype=dtype)
    image_embeds = self.image_encoder(image).image_embeds

    image_embeds = self.noise_image_embeddings(
        image_embeds=image_embeds,
        noise_level=noise_level,
        generator=generator,
    )
    # note: the condition input is the same for every repetition
    image_embeds = image_embeds.repeat(num_images_per_prompt, 1)

    if do_classifier_free_guidance:
        normal_image_embeds, color_image_embeds = torch.chunk(image_embeds, 2, dim=0)
        negative_prompt_embeds = torch.zeros_like(normal_image_embeds)

        # Concatenate unconditional (zero) and conditional embeddings into one batch so guidance needs only a
        # single forward pass.
        image_embeds = torch.cat(
            [negative_prompt_embeds, normal_image_embeds, negative_prompt_embeds, color_image_embeds], 0
        )

    # _____________________________vae input latents__________________________________
    image_pt = torch.stack([TF.to_tensor(img) for img in image_pil], dim=0).to(device)
    image_pt = image_pt * 2.0 - 1.0
    # `to_tensor` yields float32; cast to the VAE's dtype so encoding also works when the VAE runs in reduced
    # precision (no-op for a float32 VAE).
    image_pt = image_pt.to(dtype=self.vae.dtype)
    image_latents = self.vae.encode(image_pt).latent_dist.mode() * self.vae.config.scaling_factor
    # Note: repeat differently from official pipelines
    image_latents = image_latents.repeat(num_images_per_prompt, 1, 1, 1)

    if do_classifier_free_guidance:
        normal_image_latents, color_image_latents = torch.chunk(image_latents, 2, dim=0)
        image_latents = torch.cat(
            [
                torch.zeros_like(normal_image_latents),
                normal_image_latents,
                torch.zeros_like(color_image_latents),
                color_image_latents,
            ],
            0,
        )

    return image_embeds, image_latents
254
+
255
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
256
def decode_latents(self, latents):
    """
    Decode latents with the VAE and return the images as a float32 NumPy array.

    The latents are multiplied by the reciprocal of `vae.config.scaling_factor`, decoded, mapped through
    `x / 2 + 0.5`, clamped to [0, 1], moved to the CPU and permuted from (0, 1, 2, 3) to (0, 2, 3, 1).
    """
    inverse_scale = 1 / self.vae.config.scaling_factor
    decoded = self.vae.decode(inverse_scale * latents).sample
    pixels = (decoded / 2 + 0.5).clamp(0, 1)
    # Always cast to float32: the overhead is negligible and it stays compatible with bfloat16.
    channels_last = pixels.cpu().permute(0, 2, 3, 1)
    return channels_last.float().numpy()
263
+
264
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
265
def prepare_extra_step_kwargs(self, generator, eta):
    """
    Build the optional keyword arguments for `self.scheduler.step`.

    Not all schedulers share the same `step` signature, so `eta` and `generator` are only included when the
    scheduler's `step` declares a parameter of that name. eta (η) corresponds to η in the DDIM paper
    (https://arxiv.org/abs/2010.02502) and should be between [0, 1].
    """
    step_parameters = set(inspect.signature(self.scheduler.step).parameters.keys())
    optional_arguments = (("eta", eta), ("generator", generator))
    return {name: value for name, value in optional_arguments if name in step_parameters}
281
+
282
def check_inputs(
    self,
    prompt,
    image,
    height,
    width,
    callback_steps,
    noise_level,
):
    """
    Validate the arguments of `__call__`, raising `ValueError` on the first violation.

    Checks, in order: `height` and `width` are divisible by 8; `callback_steps` is a positive `int`; `prompt`
    is `None`, a `str` or a `list`; `noise_level` lies in `[0, num_train_timesteps)` of the image noising
    scheduler. `image` is accepted but not inspected.
    """
    if height % 8 != 0 or width % 8 != 0:
        raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

    # `None` fails the isinstance test, so it needs no separate branch.
    callback_steps_valid = isinstance(callback_steps, int) and callback_steps > 0
    if not callback_steps_valid:
        raise ValueError(
            f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
            f" {type(callback_steps)}."
        )

    if prompt is not None and not isinstance(prompt, (str, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    noise_level_limit = self.image_noising_scheduler.config.num_train_timesteps
    if noise_level < 0 or noise_level >= noise_level_limit:
        raise ValueError(
            f"`noise_level` must be between 0 and {self.image_noising_scheduler.config.num_train_timesteps - 1}, inclusive."
        )
310
+
311
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
312
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
    """
    Return the initial latents for the denoising loop.

    When `latents` is `None`, a tensor of shape
    `(batch_size, num_channels_latents, height // vae_scale_factor, width // vae_scale_factor)` is drawn with
    `randn_tensor`; otherwise the given latents are moved to `device`. In both cases the result is multiplied
    by `self.scheduler.init_noise_sigma`.

    Raises:
        ValueError: if `generator` is a list whose length differs from `batch_size`.
    """
    generator_count_mismatch = isinstance(generator, list) and len(generator) != batch_size
    if generator_count_mismatch:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )

    if latents is not None:
        initial = latents.to(device)
    else:
        latent_height = height // self.vae_scale_factor
        latent_width = width // self.vae_scale_factor
        shape = (batch_size, num_channels_latents, latent_height, latent_width)
        initial = randn_tensor(shape, generator=generator, device=device, dtype=dtype)

    # scale the initial noise by the standard deviation required by the scheduler
    return initial * self.scheduler.init_noise_sigma
328
+
329
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_unclip.StableUnCLIPPipeline.noise_image_embeddings
330
def noise_image_embeddings(
    self,
    image_embeds: torch.Tensor,
    noise_level: int,
    noise: Optional[torch.FloatTensor] = None,
    generator: Optional[torch.Generator] = None,
):
    """
    Noise the image embeddings and append an embedding of the noise level.

    `noise_level` controls the amount of noise in two ways:
        1. The image noising scheduler adds noise to the embeddings, using `noise_level` as the timestep for
           every batch entry.
        2. A sinusoidal timestep embedding of `noise_level` is concatenated to the result along dim 1.

    The embeddings are scaled by the image normalizer before the noise is added and unscaled afterwards. When
    `noise` is `None` it is sampled with `randn_tensor` using `generator`.
    """
    if noise is None:
        noise = randn_tensor(
            image_embeds.shape, generator=generator, device=image_embeds.device, dtype=image_embeds.dtype
        )

    batch_entries = image_embeds.shape[0]
    level_per_entry = torch.tensor([noise_level] * batch_entries, device=image_embeds.device)

    normalized = self.image_normalizer.scale(image_embeds)
    noised = self.image_noising_scheduler.add_noise(normalized, timesteps=level_per_entry, noise=noise)
    restored = self.image_normalizer.unscale(noised)

    level_embedding = get_timestep_embedding(
        timesteps=level_per_entry,
        embedding_dim=restored.shape[-1],
        flip_sin_to_cos=True,
        downscale_freq_shift=0,
    )
    # `get_timestep_embedding` has no weights and always returns f32 tensors, but we might be running in
    # fp16, so cast to the embeddings' dtype before concatenating.
    level_embedding = level_embedding.to(restored.dtype)

    return torch.cat((restored, level_embedding), 1)
374
+
375
@torch.no_grad()
def __call__(
    self,
    image: Union[torch.FloatTensor, PIL.Image.Image],
    prompt: Union[str, List[str]],
    prompt_embeds: torch.FloatTensor = None,
    dino_feature: torch.FloatTensor = None,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 20,
    guidance_scale: float = 10,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    eta: float = 0.0,
    generator: Optional[torch.Generator] = None,
    latents: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
    callback_steps: int = 1,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    noise_level: int = 0,
    image_embeds: Optional[torch.FloatTensor] = None,
    return_elevation_focal: Optional[bool] = False,
    gt_img_in: Optional[torch.FloatTensor] = None,
):
    r"""
    Function invoked when calling the pipeline for generation.

    Args:
        image (`PIL.Image.Image`, `List` of images, or `torch.Tensor`):
            Conditioning image(s). A single PIL image is replicated `num_views * 2` times; a tensor batch must
            have a batch size that is a positive multiple of `num_views`. The images are encoded both to CLIP
            embeddings (passed to the unet as `class_labels`) and to VAE latents, which are concatenated to the
            noisy latents along the channel dimension at every denoising step.
        prompt (`str` or `List[str]`):
            Only validated and forwarded to `_encode_prompt`; a `str` is replicated `num_views * 2` times.
        prompt_embeds (`torch.FloatTensor`):
            Pre-generated text embeddings. Required despite the `None` default, because this pipeline does not
            run the text encoder on `prompt`.
        dino_feature (`torch.FloatTensor`, *optional*):
            Forwarded unchanged to the unet as `dino_feature`.
        height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
            The height in pixels of the generated image.
        width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
            The width in pixels of the generated image.
        num_inference_steps (`int`, *optional*, defaults to 20):
            The number of denoising steps.
        guidance_scale (`float`, *optional*, defaults to 10.0):
            Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
            Classifier-free guidance is enabled whenever `guidance_scale != 1.0`.
        negative_prompt (`str` or `List[str]`, *optional*):
            Forwarded to `_encode_prompt`.
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            Repeat factor applied to the image conditioning in `_encode_image`.
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Passed to the
            scheduler's `step` only if it accepts an `eta` argument.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
            to make generation deterministic.
        latents (`torch.FloatTensor`, *optional*):
            Pre-generated noisy latents. If not provided, a latents tensor will be generated by sampling using
            the supplied random `generator`. Ignored when `gt_img_in` is given.
        negative_prompt_embeds (`torch.FloatTensor`, *optional*):
            Forwarded to `_encode_prompt`.
        output_type (`str`, *optional*, defaults to `"pil"`):
            Output format handed to the image processor. With `"latent"` the VAE decoding step is skipped.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether to return an [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
        callback (`Callable`, *optional*):
            Called every `callback_steps` steps as `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
        callback_steps (`int`, *optional*, defaults to 1):
            The frequency at which the `callback` function will be called.
        cross_attention_kwargs (`dict`, *optional*):
            Passed to the unet; its `"scale"` entry, if any, is also forwarded to `_encode_prompt` as
            `lora_scale`.
        noise_level (`int`, *optional*, defaults to `0`):
            The amount of noise to add to the image embeddings. See `noise_image_embeddings` for details.
        image_embeds (`torch.FloatTensor`, *optional*):
            Accepted for interface compatibility; the value is replaced by the embeddings computed from `image`.
        return_elevation_focal (`bool`, *optional*, defaults to `False`):
            When `True`, the second unet output is read at every step; its columns 0 and 1 are collected as
            per-step `eles` and `focals` NumPy arrays.
        gt_img_in (`torch.FloatTensor`, *optional*):
            When given, the initial latents are `gt_img_in * scheduler.init_noise_sigma` instead of the result
            of `prepare_latents`.

    Returns:
        `(image,)` if `return_dict` is `False` (regardless of `return_elevation_focal`). Otherwise an
        [`~pipelines.ImagePipelineOutput`], or the tuple `(ImagePipelineOutput, eles, focals)` when
        `return_elevation_focal` is `True`.

    Raises:
        TypeError: if `image` is not a list, a tensor or a PIL image.
        ValueError: if the inputs fail `check_inputs`, if a tensor `image` has a batch size that is not a
            positive multiple of `num_views`, or if `prompt_embeds` is `None`.
    """
    # 0. Default height and width to unet
    height = height or self.unet.config.sample_size * self.vae_scale_factor
    width = width or self.unet.config.sample_size * self.vae_scale_factor

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(
        prompt=prompt,
        image=image,
        height=height,
        width=width,
        callback_steps=callback_steps,
        noise_level=noise_level,
    )
    if prompt_embeds is None:
        raise ValueError("`prompt_embeds` must be provided: this pipeline does not encode `prompt` itself.")

    # 2. Define call parameters
    if isinstance(image, list):
        batch_size = len(image)
    elif isinstance(image, torch.Tensor):
        batch_size = image.shape[0]
        if batch_size < self.num_views or batch_size % self.num_views != 0:
            raise ValueError(
                f"The batch size of a tensor `image` must be a positive multiple of `num_views`"
                f" ({self.num_views}) but is {batch_size}."
            )
    elif isinstance(image, PIL.Image.Image):
        image = [image] * self.num_views * 2
        batch_size = self.num_views * 2
    else:
        # Fail here rather than with an unbound-variable error further down.
        raise TypeError(f"`image` has to be a list, a `torch.Tensor` or a `PIL.Image.Image` but is {type(image)}")

    if isinstance(prompt, str):
        prompt = [prompt] * self.num_views * 2

    device = self._execution_device

    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    do_classifier_free_guidance = guidance_scale != 1.0

    # 3. Encode input prompt
    text_encoder_lora_scale = (
        cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
    )
    prompt_embeds = self._encode_prompt(
        prompt=prompt,
        device=device,
        num_images_per_prompt=num_images_per_prompt,
        do_classifier_free_guidance=do_classifier_free_guidance,
        negative_prompt=negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        lora_scale=text_encoder_lora_scale,
    )

    # 4. Encode input image (a single PIL image was already turned into a list in step 2)
    if isinstance(image, list):
        image_pil = image
    else:
        image_pil = [TF.to_pil_image(image[i]) for i in range(image.shape[0])]
    noise_level = torch.tensor([noise_level], device=device)
    image_embeds, image_latents = self._encode_image(
        image_pil=image_pil,
        device=device,
        num_images_per_prompt=num_images_per_prompt,
        do_classifier_free_guidance=do_classifier_free_guidance,
        noise_level=noise_level,
        generator=generator,
    )

    # 5. Prepare timesteps
    self.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps = self.scheduler.timesteps

    # 6. Prepare latent variables
    num_channels_latents = self.unet.config.out_channels
    if gt_img_in is not None:
        latents = gt_img_in * self.scheduler.init_noise_sigma
    else:
        latents = self.prepare_latents(
            batch_size=batch_size,
            num_channels_latents=num_channels_latents,
            height=height,
            width=width,
            dtype=prompt_embeds.dtype,
            device=device,
            generator=generator,
            latents=latents,
        )

    # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    eles, focals = [], []
    # 8. Denoising loop
    for i, t in enumerate(self.progress_bar(timesteps)):
        if do_classifier_free_guidance:
            # Duplicate each half so the batch matches the [uncond, cond, uncond, cond] conditioning layout.
            normal_latents, color_latents = torch.chunk(latents, 2, dim=0)
            latent_model_input = torch.cat([normal_latents, normal_latents, color_latents, color_latents], 0)
        else:
            latent_model_input = latents
        # The VAE latents of the conditioning image are appended along the channel dimension.
        latent_model_input = torch.cat([latent_model_input, image_latents], dim=1)
        latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

        # predict the noise residual
        unet_out = self.unet(
            latent_model_input,
            t,
            encoder_hidden_states=prompt_embeds,
            dino_feature=dino_feature,
            class_labels=image_embeds,
            cross_attention_kwargs=cross_attention_kwargs,
            return_dict=False,
        )

        noise_pred = unet_out[0]
        if return_elevation_focal:
            # NOTE(review): the noise prediction below is split into four groups while this pose output is
            # split into two -- confirm `unet_out[1]` really is laid out as [uncond, cond].
            uncond_pose, pose = torch.chunk(unet_out[1], 2, 0)
            pose = uncond_pose + guidance_scale * (pose - uncond_pose)
            ele = pose[:, 0].detach().cpu().numpy()  # b
            eles.append(ele)
            focal = pose[:, 1].detach().cpu().numpy()
            focals.append(focal)

        # perform guidance
        if do_classifier_free_guidance:
            (
                normal_noise_pred_uncond,
                normal_noise_pred_text,
                color_noise_pred_uncond,
                color_noise_pred_text,
            ) = torch.chunk(noise_pred, 4, dim=0)

            noise_pred_uncond = torch.cat([normal_noise_pred_uncond, color_noise_pred_uncond], 0)
            noise_pred_text = torch.cat([normal_noise_pred_text, color_noise_pred_text], 0)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        # compute the previous noisy sample x_t -> x_t-1
        latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

        if callback is not None and i % callback_steps == 0:
            callback(i, t, latents)

    # 9. Post-processing
    if not output_type == "latent":
        if num_channels_latents == 8:
            # Two 4-channel latents are packed along the channel dimension; unpack them along the batch.
            latents = torch.cat([latents[:, :4], latents[:, 4:]], dim=0)
        # Gradients are already disabled by the `torch.no_grad()` decorator on this method.
        image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
    else:
        image = latents

    image = self.image_processor.postprocess(image, output_type=output_type)

    if not return_dict:
        return (image,)
    if return_elevation_focal:
        return ImagePipelineOutput(images=image), eles, focals
    return ImagePipelineOutput(images=image)