pengHTYX committed on
Commit 00732de
1 Parent(s): 00e3192
mvdiffusion/models/transformer_mv2d_image.py DELETED
@@ -1,1029 +0,0 @@
1
- # Copyright 2023 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- from dataclasses import dataclass
15
- from typing import Any, Dict, Optional
16
-
17
- import torch
18
- import torch.nn.functional as F
19
- from torch import nn
20
-
21
- from diffusers.configuration_utils import ConfigMixin, register_to_config
22
- from diffusers.models.embeddings import ImagePositionalEmbeddings
23
- from diffusers.utils import BaseOutput, deprecate
24
- from diffusers.utils.torch_utils import maybe_allow_in_graph
25
- from diffusers.models.attention import FeedForward, AdaLayerNorm, AdaLayerNormZero, Attention
26
- from diffusers.models.embeddings import PatchEmbed
27
- from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
28
- from diffusers.models.modeling_utils import ModelMixin
29
- from diffusers.utils.import_utils import is_xformers_available
30
-
31
- from einops import rearrange, repeat
32
- import pdb
33
- import random
34
-
35
-
36
- if is_xformers_available():
37
- import xformers
38
- import xformers.ops
39
- else:
40
- xformers = None
41
-
42
- def my_repeat(tensor, num_repeats):
43
- """
44
-     Repeat a tensor along its batch dimension (each batch entry is repeated num_repeats times).
45
- """
46
- if len(tensor.shape) == 3:
47
- return repeat(tensor, "b d c -> (b v) d c", v=num_repeats)
48
- elif len(tensor.shape) == 4:
49
- return repeat(tensor, "a b d c -> (a v) b d c", v=num_repeats)
50
-
51
-
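Note: `my_repeat` wraps `einops.repeat` to tile a tensor along its batch axis, handling both 3D and 4D inputs. A minimal shape sketch (assuming only `torch` and `einops` are available; the sizes are illustrative):

import torch
from einops import repeat

x = torch.randn(2, 1024, 320)              # (batch, tokens, channels)
y = repeat(x, "b d c -> (b v) d c", v=4)   # each batch entry tiled 4 times
print(y.shape)                             # torch.Size([8, 1024, 320])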
52
- @dataclass
53
- class TransformerMV2DModelOutput(BaseOutput):
54
- """
55
-     The output of [`TransformerMV2DModel`].
56
-
57
- Args:
58
- sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
59
- The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
60
- distributions for the unnoised latent pixels.
61
- """
62
-
63
- sample: torch.FloatTensor
64
-
65
-
66
- class TransformerMV2DModel(ModelMixin, ConfigMixin):
67
- """
68
- A 2D Transformer model for image-like data.
69
-
70
- Parameters:
71
- num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
72
- attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
73
- in_channels (`int`, *optional*):
74
- The number of channels in the input and output (specify if the input is **continuous**).
75
- num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
76
- dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
77
- cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
78
- sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
79
- This is fixed during training since it is used to learn a number of position embeddings.
80
- num_vector_embeds (`int`, *optional*):
81
- The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
82
- Includes the class for the masked latent pixel.
83
- activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
84
- num_embeds_ada_norm ( `int`, *optional*):
85
- The number of diffusion steps used during training. Pass if at least one of the norm_layers is
86
- `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
87
- added to the hidden states.
88
-
89
- During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`.
90
- attention_bias (`bool`, *optional*):
91
- Configure if the `TransformerBlocks` attention should contain a bias parameter.
92
- """
93
-
94
- @register_to_config
95
- def __init__(
96
- self,
97
- num_attention_heads: int = 16,
98
- attention_head_dim: int = 88,
99
- in_channels: Optional[int] = None,
100
- out_channels: Optional[int] = None,
101
- num_layers: int = 1,
102
- dropout: float = 0.0,
103
- norm_num_groups: int = 32,
104
- cross_attention_dim: Optional[int] = None,
105
- attention_bias: bool = False,
106
- sample_size: Optional[int] = None,
107
- num_vector_embeds: Optional[int] = None,
108
- patch_size: Optional[int] = None,
109
- activation_fn: str = "geglu",
110
- num_embeds_ada_norm: Optional[int] = None,
111
- use_linear_projection: bool = False,
112
- only_cross_attention: bool = False,
113
- upcast_attention: bool = False,
114
- norm_type: str = "layer_norm",
115
- norm_elementwise_affine: bool = True,
116
- num_views: int = 1,
117
- cd_attention_last: bool=False,
118
- cd_attention_mid: bool=False,
119
- multiview_attention: bool=True,
120
- sparse_mv_attention: bool = False,
121
- mvcd_attention: bool=False
122
- ):
123
- super().__init__()
124
- self.use_linear_projection = use_linear_projection
125
- self.num_attention_heads = num_attention_heads
126
- self.attention_head_dim = attention_head_dim
127
- inner_dim = num_attention_heads * attention_head_dim
128
-
129
- # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
130
- # Define whether input is continuous or discrete depending on configuration
131
- self.is_input_continuous = (in_channels is not None) and (patch_size is None)
132
- self.is_input_vectorized = num_vector_embeds is not None
133
- self.is_input_patches = in_channels is not None and patch_size is not None
134
-
135
- if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
136
- deprecation_message = (
137
- f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
138
-                 " incorrectly set to `'layer_norm'`. Make sure to set `norm_type` to `'ada_norm'` in the config."
139
-                 " Please make sure to update the config accordingly, as leaving `norm_type` as it is might lead to incorrect"
140
- " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
141
- " would be very nice if you could open a Pull request for the `transformer/config.json` file"
142
- )
143
- deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False)
144
- norm_type = "ada_norm"
145
-
146
- if self.is_input_continuous and self.is_input_vectorized:
147
- raise ValueError(
148
- f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
149
- " sure that either `in_channels` or `num_vector_embeds` is None."
150
- )
151
- elif self.is_input_vectorized and self.is_input_patches:
152
- raise ValueError(
153
- f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
154
- " sure that either `num_vector_embeds` or `num_patches` is None."
155
- )
156
- elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches:
157
- raise ValueError(
158
- f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:"
159
- f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
160
- )
161
-
162
- # 2. Define input layers
163
- if self.is_input_continuous:
164
- self.in_channels = in_channels
165
-
166
- self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
167
- if use_linear_projection:
168
- self.proj_in = LoRACompatibleLinear(in_channels, inner_dim)
169
- else:
170
- self.proj_in = LoRACompatibleConv(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
171
- elif self.is_input_vectorized:
172
- assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
173
- assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
174
-
175
- self.height = sample_size
176
- self.width = sample_size
177
- self.num_vector_embeds = num_vector_embeds
178
- self.num_latent_pixels = self.height * self.width
179
-
180
- self.latent_image_embedding = ImagePositionalEmbeddings(
181
- num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
182
- )
183
- elif self.is_input_patches:
184
- assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size"
185
-
186
- self.height = sample_size
187
- self.width = sample_size
188
-
189
- self.patch_size = patch_size
190
- self.pos_embed = PatchEmbed(
191
- height=sample_size,
192
- width=sample_size,
193
- patch_size=patch_size,
194
- in_channels=in_channels,
195
- embed_dim=inner_dim,
196
- )
197
-
198
- # 3. Define transformers blocks
199
- self.transformer_blocks = nn.ModuleList(
200
- [
201
- BasicMVTransformerBlock(
202
- inner_dim,
203
- num_attention_heads,
204
- attention_head_dim,
205
- dropout=dropout,
206
- cross_attention_dim=cross_attention_dim,
207
- activation_fn=activation_fn,
208
- num_embeds_ada_norm=num_embeds_ada_norm,
209
- attention_bias=attention_bias,
210
- only_cross_attention=only_cross_attention,
211
- upcast_attention=upcast_attention,
212
- norm_type=norm_type,
213
- norm_elementwise_affine=norm_elementwise_affine,
214
- num_views=num_views,
215
- cd_attention_last=cd_attention_last,
216
- cd_attention_mid=cd_attention_mid,
217
- multiview_attention=multiview_attention,
218
- sparse_mv_attention=sparse_mv_attention,
219
- mvcd_attention=mvcd_attention
220
- )
221
- for d in range(num_layers)
222
- ]
223
- )
224
-
225
- # 4. Define output layers
226
- self.out_channels = in_channels if out_channels is None else out_channels
227
- if self.is_input_continuous:
228
- # TODO: should use out_channels for continuous projections
229
- if use_linear_projection:
230
- self.proj_out = LoRACompatibleLinear(inner_dim, in_channels)
231
- else:
232
- self.proj_out = LoRACompatibleConv(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
233
- elif self.is_input_vectorized:
234
- self.norm_out = nn.LayerNorm(inner_dim)
235
- self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
236
- elif self.is_input_patches:
237
- self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
238
- self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
239
- self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
240
-
241
- def forward(
242
- self,
243
- hidden_states: torch.Tensor,
244
- encoder_hidden_states: Optional[torch.Tensor] = None,
245
- timestep: Optional[torch.LongTensor] = None,
246
- class_labels: Optional[torch.LongTensor] = None,
247
- cross_attention_kwargs: Dict[str, Any] = None,
248
- attention_mask: Optional[torch.Tensor] = None,
249
- encoder_attention_mask: Optional[torch.Tensor] = None,
250
- return_dict: bool = True,
251
- ):
252
- """
253
- The [`Transformer2DModel`] forward method.
254
-
255
- Args:
256
- hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
257
- Input `hidden_states`.
258
- encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
259
- Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
260
- self-attention.
261
- timestep ( `torch.LongTensor`, *optional*):
262
- Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
263
- class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
264
- Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
265
-                 `AdaLayerNormZero`.
266
- encoder_attention_mask ( `torch.Tensor`, *optional*):
267
- Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
268
-
269
- * Mask `(batch, sequence_length)` True = keep, False = discard.
270
- * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
271
-
272
- If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
273
- above. This bias will be added to the cross-attention scores.
274
- return_dict (`bool`, *optional*, defaults to `True`):
275
- Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
276
- tuple.
277
-
278
- Returns:
279
- If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
280
- `tuple` where the first element is the sample tensor.
281
- """
282
- # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
283
- # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
284
- # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
285
- # expects mask of shape:
286
- # [batch, key_tokens]
287
- # adds singleton query_tokens dimension:
288
- # [batch, 1, key_tokens]
289
- # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
290
- # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
291
- # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
292
- if attention_mask is not None and attention_mask.ndim == 2:
293
- # assume that mask is expressed as:
294
- # (1 = keep, 0 = discard)
295
- # convert mask into a bias that can be added to attention scores:
296
- # (keep = +0, discard = -10000.0)
297
- attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
298
- attention_mask = attention_mask.unsqueeze(1)
299
-
300
- # convert encoder_attention_mask to a bias the same way we do for attention_mask
301
- if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
302
- encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
303
- encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
304
-
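A standalone sketch of the mask-to-bias conversion performed above; the values are illustrative:

import torch

mask = torch.tensor([[1, 1, 0, 0]])      # (batch, key_tokens), 1 = keep, 0 = discard
bias = (1 - mask.float()) * -10000.0     # keep -> 0.0, discard -> -10000.0
bias = bias.unsqueeze(1)                 # (batch, 1, key_tokens), broadcasts over query tokens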
305
- # 1. Input
306
- if self.is_input_continuous:
307
- batch, _, height, width = hidden_states.shape
308
- residual = hidden_states
309
-
310
- hidden_states = self.norm(hidden_states)
311
- if not self.use_linear_projection:
312
- hidden_states = self.proj_in(hidden_states)
313
- inner_dim = hidden_states.shape[1]
314
- hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
315
- else:
316
- inner_dim = hidden_states.shape[1]
317
- hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
318
- hidden_states = self.proj_in(hidden_states)
319
- elif self.is_input_vectorized:
320
- hidden_states = self.latent_image_embedding(hidden_states)
321
- elif self.is_input_patches:
322
- hidden_states = self.pos_embed(hidden_states)
323
-
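For the continuous branch above, the NCHW feature map is flattened into a token sequence before the transformer blocks; a shape walk-through with illustrative sizes:

import torch

x = torch.randn(2, 320, 32, 32)                           # (batch, inner_dim, height, width)
tokens = x.permute(0, 2, 3, 1).reshape(2, 32 * 32, 320)   # (batch, height * width, inner_dim)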
324
- # 2. Blocks
325
- for block in self.transformer_blocks:
326
- hidden_states = block(
327
- hidden_states,
328
- attention_mask=attention_mask,
329
- encoder_hidden_states=encoder_hidden_states,
330
- encoder_attention_mask=encoder_attention_mask,
331
- timestep=timestep,
332
- cross_attention_kwargs=cross_attention_kwargs,
333
- class_labels=class_labels,
334
- )
335
-
336
- # 3. Output
337
- if self.is_input_continuous:
338
- if not self.use_linear_projection:
339
- hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
340
- hidden_states = self.proj_out(hidden_states)
341
- else:
342
- hidden_states = self.proj_out(hidden_states)
343
- hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
344
-
345
- output = hidden_states + residual
346
- elif self.is_input_vectorized:
347
- hidden_states = self.norm_out(hidden_states)
348
- logits = self.out(hidden_states)
349
- # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
350
- logits = logits.permute(0, 2, 1)
351
-
352
- # log(p(x_0))
353
- output = F.log_softmax(logits.double(), dim=1).float()
354
- elif self.is_input_patches:
355
- # TODO: cleanup!
356
- conditioning = self.transformer_blocks[0].norm1.emb(
357
- timestep, class_labels, hidden_dtype=hidden_states.dtype
358
- )
359
- shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
360
- hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
361
- hidden_states = self.proj_out_2(hidden_states)
362
-
363
- # unpatchify
364
- height = width = int(hidden_states.shape[1] ** 0.5)
365
- hidden_states = hidden_states.reshape(
366
- shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
367
- )
368
- hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
369
- output = hidden_states.reshape(
370
- shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
371
- )
372
-
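The unpatchify step above maps patch tokens back to an image layout; a self-contained sketch with hypothetical sizes:

import torch

p, c, h, w = 2, 4, 16, 16                 # patch_size, out_channels, patches per side
x = torch.randn(1, h * w, p * p * c)      # (batch, num_patches, patch_size**2 * out_channels)
x = x.reshape(-1, h, w, p, p, c)
x = torch.einsum("nhwpqc->nchpwq", x)
img = x.reshape(-1, c, h * p, w * p)      # (batch, out_channels, h * patch_size, w * patch_size)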
373
- if not return_dict:
374
- return (output,)
375
-
376
- return TransformerMV2DModelOutput(sample=output)
377
-
378
-
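A hypothetical instantiation of the model defined above; the sizes are illustrative, not the shipped checkpoint's config, and the exact call behavior depends on the installed diffusers version:

import torch

model = TransformerMV2DModel(
    num_attention_heads=8,
    attention_head_dim=40,       # inner_dim = 8 * 40 = 320
    in_channels=320,
    num_layers=1,
    cross_attention_dim=768,
    num_views=4,
)
latents = torch.randn(4, 320, 32, 32)    # 4 views stacked along the batch axis
cond = torch.randn(4, 77, 768)           # per-view conditioning embeddings
out = model(latents, encoder_hidden_states=cond).sample   # same shape as latents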
379
- @maybe_allow_in_graph
380
- class BasicMVTransformerBlock(nn.Module):
381
- r"""
382
- A basic Transformer block.
383
-
384
- Parameters:
385
- dim (`int`): The number of channels in the input and output.
386
- num_attention_heads (`int`): The number of heads to use for multi-head attention.
387
- attention_head_dim (`int`): The number of channels in each head.
388
- dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
389
- cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
390
- only_cross_attention (`bool`, *optional*):
391
- Whether to use only cross-attention layers. In this case two cross attention layers are used.
392
- double_self_attention (`bool`, *optional*):
393
- Whether to use two self-attention layers. In this case no cross attention layers are used.
394
- activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
395
- num_embeds_ada_norm (:
396
- obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
397
- attention_bias (:
398
- obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
399
- """
400
-
401
- def __init__(
402
- self,
403
- dim: int,
404
- num_attention_heads: int,
405
- attention_head_dim: int,
406
- dropout=0.0,
407
- cross_attention_dim: Optional[int] = None,
408
- activation_fn: str = "geglu",
409
- num_embeds_ada_norm: Optional[int] = None,
410
- attention_bias: bool = False,
411
- only_cross_attention: bool = False,
412
- double_self_attention: bool = False,
413
- upcast_attention: bool = False,
414
- norm_elementwise_affine: bool = True,
415
- norm_type: str = "layer_norm",
416
- final_dropout: bool = False,
417
- num_views: int = 1,
418
- cd_attention_last: bool = False,
419
- cd_attention_mid: bool = False,
420
- multiview_attention: bool = True,
421
- sparse_mv_attention: bool = False,
422
- mvcd_attention: bool = False
423
- ):
424
- super().__init__()
425
- self.only_cross_attention = only_cross_attention
426
-
427
- self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
428
- self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
429
-
430
- if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
431
- raise ValueError(
432
- f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
433
- f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
434
- )
435
-
436
- # Define 3 blocks. Each block has its own normalization layer.
437
- # 1. Self-Attn
438
- if self.use_ada_layer_norm:
439
- self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
440
- elif self.use_ada_layer_norm_zero:
441
- self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
442
- else:
443
- self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
444
-
445
- self.multiview_attention = multiview_attention
446
- self.sparse_mv_attention = sparse_mv_attention
447
- self.mvcd_attention = mvcd_attention
448
-
449
- self.attn1 = CustomAttention(
450
- query_dim=dim,
451
- heads=num_attention_heads,
452
- dim_head=attention_head_dim,
453
- dropout=dropout,
454
- bias=attention_bias,
455
- cross_attention_dim=cross_attention_dim if only_cross_attention else None,
456
- upcast_attention=upcast_attention,
457
- processor=MVAttnProcessor()
458
- )
459
-
460
- # 2. Cross-Attn
461
- if cross_attention_dim is not None or double_self_attention:
462
- # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
463
- # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
464
- # the second cross attention block.
465
- self.norm2 = (
466
- AdaLayerNorm(dim, num_embeds_ada_norm)
467
- if self.use_ada_layer_norm
468
- else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
469
- )
470
- self.attn2 = Attention(
471
- query_dim=dim,
472
- cross_attention_dim=cross_attention_dim if not double_self_attention else None,
473
- heads=num_attention_heads,
474
- dim_head=attention_head_dim,
475
- dropout=dropout,
476
- bias=attention_bias,
477
- upcast_attention=upcast_attention,
478
- ) # is self-attn if encoder_hidden_states is none
479
- else:
480
- self.norm2 = None
481
- self.attn2 = None
482
-
483
- # 3. Feed-forward
484
- self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
485
- self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
486
-
487
- # let chunk size default to None
488
- self._chunk_size = None
489
- self._chunk_dim = 0
490
-
491
- self.num_views = num_views
492
-
493
- self.cd_attention_last = cd_attention_last
494
-
495
- if self.cd_attention_last:
496
- # Joint task -Attn
497
- self.attn_joint_last = CustomJointAttention(
498
- query_dim=dim,
499
- heads=num_attention_heads,
500
- dim_head=attention_head_dim,
501
- dropout=dropout,
502
- bias=attention_bias,
503
- cross_attention_dim=cross_attention_dim if only_cross_attention else None,
504
- upcast_attention=upcast_attention,
505
- processor=JointAttnProcessor()
506
- )
507
- nn.init.zeros_(self.attn_joint_last.to_out[0].weight.data)
508
- self.norm_joint_last = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
509
-
510
-
511
- self.cd_attention_mid = cd_attention_mid
512
-
513
- if self.cd_attention_mid:
514
- print("cross-domain attn in the middle")
515
- # Joint task -Attn
516
- self.attn_joint_mid = CustomJointAttention(
517
- query_dim=dim,
518
- heads=num_attention_heads,
519
- dim_head=attention_head_dim,
520
- dropout=dropout,
521
- bias=attention_bias,
522
- cross_attention_dim=cross_attention_dim if only_cross_attention else None,
523
- upcast_attention=upcast_attention,
524
- processor=JointAttnProcessor()
525
- )
526
- nn.init.zeros_(self.attn_joint_mid.to_out[0].weight.data)
527
- self.norm_joint_mid = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
528
-
529
- def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
530
- # Sets chunk feed-forward
531
- self._chunk_size = chunk_size
532
- self._chunk_dim = dim
533
-
534
- def forward(
535
- self,
536
- hidden_states: torch.FloatTensor,
537
- attention_mask: Optional[torch.FloatTensor] = None,
538
- encoder_hidden_states: Optional[torch.FloatTensor] = None,
539
- encoder_attention_mask: Optional[torch.FloatTensor] = None,
540
- timestep: Optional[torch.LongTensor] = None,
541
- cross_attention_kwargs: Dict[str, Any] = None,
542
- class_labels: Optional[torch.LongTensor] = None,
543
- ):
544
- assert attention_mask is None # not supported yet
545
- # Notice that normalization is always applied before the real computation in the following blocks.
546
- # 1. Self-Attention
547
- if self.use_ada_layer_norm:
548
- norm_hidden_states = self.norm1(hidden_states, timestep)
549
- elif self.use_ada_layer_norm_zero:
550
- norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
551
- hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
552
- )
553
- else:
554
- norm_hidden_states = self.norm1(hidden_states)
555
-
556
- cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
557
-
558
- attn_output = self.attn1(
559
- norm_hidden_states,
560
- encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
561
- attention_mask=attention_mask,
562
- num_views=self.num_views,
563
- multiview_attention=self.multiview_attention,
564
- sparse_mv_attention=self.sparse_mv_attention,
565
- mvcd_attention=self.mvcd_attention,
566
- **cross_attention_kwargs,
567
- )
568
-
569
-
570
- if self.use_ada_layer_norm_zero:
571
- attn_output = gate_msa.unsqueeze(1) * attn_output
572
- hidden_states = attn_output + hidden_states
573
-
574
- # joint attention twice
575
- if self.cd_attention_mid:
576
- norm_hidden_states = (
577
- self.norm_joint_mid(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_joint_mid(hidden_states)
578
- )
579
- hidden_states = self.attn_joint_mid(norm_hidden_states) + hidden_states
580
-
581
- # 2. Cross-Attention
582
- if self.attn2 is not None:
583
- norm_hidden_states = (
584
- self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
585
- )
586
-
587
- attn_output = self.attn2(
588
- norm_hidden_states,
589
- encoder_hidden_states=encoder_hidden_states,
590
- attention_mask=encoder_attention_mask,
591
- **cross_attention_kwargs,
592
- )
593
- hidden_states = attn_output + hidden_states
594
-
595
- # 3. Feed-forward
596
- norm_hidden_states = self.norm3(hidden_states)
597
-
598
- if self.use_ada_layer_norm_zero:
599
- norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
600
-
601
- if self._chunk_size is not None:
602
- # "feed_forward_chunk_size" can be used to save memory
603
- if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
604
- raise ValueError(
605
- f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
606
- )
607
-
608
- num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
609
- ff_output = torch.cat(
610
- [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)],
611
- dim=self._chunk_dim,
612
- )
613
- else:
614
- ff_output = self.ff(norm_hidden_states)
615
-
616
- if self.use_ada_layer_norm_zero:
617
- ff_output = gate_mlp.unsqueeze(1) * ff_output
618
-
619
- hidden_states = ff_output + hidden_states
620
-
621
- if self.cd_attention_last:
622
- norm_hidden_states = (
623
- self.norm_joint_last(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_joint_last(hidden_states)
624
- )
625
- hidden_states = self.attn_joint_last(norm_hidden_states) + hidden_states
626
-
627
- return hidden_states
628
-
629
-
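The `_chunk_size` branch in the block's feed-forward bounds peak activation memory by running the MLP over slices of the token axis. A standalone equivalent (the `ff` module here is a hypothetical stand-in for `FeedForward`):

import torch
import torch.nn as nn

ff = nn.Sequential(nn.Linear(320, 1280), nn.GELU(), nn.Linear(1280, 320))
x = torch.randn(4, 1024, 320)            # (batch, tokens, dim); tokens must be divisible by chunk_size
chunk_size, chunk_dim = 256, 1
out = torch.cat(
    [ff(s) for s in x.chunk(x.shape[chunk_dim] // chunk_size, dim=chunk_dim)],
    dim=chunk_dim,
)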
630
- class CustomAttention(Attention):
631
- def set_use_memory_efficient_attention_xformers(
632
- self, use_memory_efficient_attention_xformers: bool, *args, **kwargs
633
- ):
634
- processor = XFormersMVAttnProcessor()
635
- self.set_processor(processor)
636
- # print("using xformers attention processor")
637
-
638
-
639
- class CustomJointAttention(Attention):
640
- def set_use_memory_efficient_attention_xformers(
641
- self, use_memory_efficient_attention_xformers: bool, *args, **kwargs
642
- ):
643
- processor = XFormersJointAttnProcessor()
644
- self.set_processor(processor)
645
- # print("using xformers attention processor")
646
-
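Both overrides above ignore the boolean flag: any call to `set_use_memory_efficient_attention_xformers` installs the corresponding xformers processor. A usage sketch (argument values are illustrative; actually running the xformers path assumes xformers is installed):

attn = CustomAttention(query_dim=320, heads=8, dim_head=40, processor=MVAttnProcessor())
attn.set_use_memory_efficient_attention_xformers(True)    # now routed through XFormersMVAttnProcessor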
647
- class MVAttnProcessor:
648
- r"""
649
- Default processor for performing attention-related computations.
650
- """
651
-
652
- def __call__(
653
- self,
654
- attn: Attention,
655
- hidden_states,
656
- encoder_hidden_states=None,
657
- attention_mask=None,
658
- temb=None,
659
- num_views=1,
660
-         multiview_attention=True,
-         # accepted for call-signature compatibility with BasicMVTransformerBlock; unused here
-         sparse_mv_attention=False,
-         mvcd_attention=False
661
- ):
662
- residual = hidden_states
663
-
664
- if attn.spatial_norm is not None:
665
- hidden_states = attn.spatial_norm(hidden_states, temb)
666
-
667
- input_ndim = hidden_states.ndim
668
-
669
- if input_ndim == 4:
670
- batch_size, channel, height, width = hidden_states.shape
671
- hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
672
-
673
- batch_size, sequence_length, _ = (
674
- hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
675
- )
676
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
677
-
678
- if attn.group_norm is not None:
679
- hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
680
-
681
- query = attn.to_q(hidden_states)
682
-
683
- if encoder_hidden_states is None:
684
- encoder_hidden_states = hidden_states
685
- elif attn.norm_cross:
686
- encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
687
-
688
- key = attn.to_k(encoder_hidden_states)
689
- value = attn.to_v(encoder_hidden_states)
690
-
691
- # print('query', query.shape, 'key', key.shape, 'value', value.shape)
692
- #([bx4, 1024, 320]) key torch.Size([bx4, 1024, 320]) value torch.Size([bx4, 1024, 320])
693
- # pdb.set_trace()
694
- # multi-view self-attention
695
- if multiview_attention:
696
- if num_views <= 6:
697
-                 # after switching to xformers, it is possible to train with 6 views
698
- key = rearrange(key, "(b t) d c -> b (t d) c", t=num_views).repeat_interleave(num_views, dim=0)
699
- value = rearrange(value, "(b t) d c -> b (t d) c", t=num_views).repeat_interleave(num_views, dim=0)
700
-             else:  # apply sparse attention
701
- pass
702
- # print("use sparse attention")
703
- # # seems that the sparse random sampling cause problems
704
- # # don't use random sampling, just fix the indexes
705
- # onekey = rearrange(key, "(b t) d c -> b t d c", t=num_views)
706
- # onevalue = rearrange(value, "(b t) d c -> b t d c", t=num_views)
707
- # allkeys = []
708
- # allvalues = []
709
- # all_indexes = {
710
- # 0 : [0, 2, 3, 4],
711
- # 1: [0, 1, 3, 5],
712
- # 2: [0, 2, 3, 4],
713
- # 3: [0, 2, 3, 4],
714
- # 4: [0, 2, 3, 4],
715
- # 5: [0, 1, 3, 5]
716
- # }
717
- # for jj in range(num_views):
718
- # # valid_index = [x for x in range(0, num_views) if x!= jj]
719
- # # indexes = random.sample(valid_index, 3) + [jj] + [0]
720
- # indexes = all_indexes[jj]
721
-
722
- # indexes = torch.tensor(indexes).long().to(key.device)
723
- # allkeys.append(onekey[:, indexes])
724
- # allvalues.append(onevalue[:, indexes])
725
- # keys = torch.stack(allkeys, dim=1) # checked, should be dim=1
726
- # values = torch.stack(allvalues, dim=1)
727
- # key = rearrange(keys, 'b t f d c -> (b t) (f d) c')
728
- # value = rearrange(values, 'b t f d c -> (b t) (f d) c')
729
-
730
-
731
- query = attn.head_to_batch_dim(query).contiguous()
732
- key = attn.head_to_batch_dim(key).contiguous()
733
- value = attn.head_to_batch_dim(value).contiguous()
734
-
735
- attention_probs = attn.get_attention_scores(query, key, attention_mask)
736
- hidden_states = torch.bmm(attention_probs, value)
737
- hidden_states = attn.batch_to_head_dim(hidden_states)
738
-
739
- # linear proj
740
- hidden_states = attn.to_out[0](hidden_states)
741
- # dropout
742
- hidden_states = attn.to_out[1](hidden_states)
743
-
744
- if input_ndim == 4:
745
- hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
746
-
747
- if attn.residual_connection:
748
- hidden_states = hidden_states + residual
749
-
750
- hidden_states = hidden_states / attn.rescale_output_factor
751
-
752
- return hidden_states
753
-
754
-
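The core of the dense multi-view attention above: keys/values from all views of a sample are concatenated along the token axis, then copied once per querying view. A shape sketch for 2 samples of 4 views with 1024 tokens each (illustrative only):

import torch
from einops import rearrange

num_views = 4
key = torch.randn(2 * num_views, 1024, 320)                   # (b*t, d, c), views stacked in the batch
key = rearrange(key, "(b t) d c -> b (t d) c", t=num_views)   # (2, 4096, 320): all views of each sample
key = key.repeat_interleave(num_views, dim=0)                 # (8, 4096, 320): one copy per querying view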
755
- class XFormersMVAttnProcessor:
756
- r"""
757
- Default processor for performing attention-related computations.
758
- """
759
-
760
- def __call__(
761
- self,
762
- attn: Attention,
763
- hidden_states,
764
- encoder_hidden_states=None,
765
- attention_mask=None,
766
- temb=None,
767
-         num_views=1,
768
- multiview_attention=True,
769
- sparse_mv_attention=False,
770
- mvcd_attention=False,
771
- ):
772
- residual = hidden_states
773
-
774
- if attn.spatial_norm is not None:
775
- hidden_states = attn.spatial_norm(hidden_states, temb)
776
-
777
- input_ndim = hidden_states.ndim
778
-
779
- if input_ndim == 4:
780
- batch_size, channel, height, width = hidden_states.shape
781
- hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
782
-
783
- batch_size, sequence_length, _ = (
784
- hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
785
- )
786
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
787
-
788
- # from yuancheng; here attention_mask is None
789
- if attention_mask is not None:
790
- # expand our mask's singleton query_tokens dimension:
791
- # [batch*heads, 1, key_tokens] ->
792
- # [batch*heads, query_tokens, key_tokens]
793
- # so that it can be added as a bias onto the attention scores that xformers computes:
794
- # [batch*heads, query_tokens, key_tokens]
795
- # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
796
- _, query_tokens, _ = hidden_states.shape
797
- attention_mask = attention_mask.expand(-1, query_tokens, -1)
798
-
799
- if attn.group_norm is not None:
800
- hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
801
-
802
- query = attn.to_q(hidden_states)
803
-
804
- if encoder_hidden_states is None:
805
- encoder_hidden_states = hidden_states
806
- elif attn.norm_cross:
807
- encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
808
-
809
- key_raw = attn.to_k(encoder_hidden_states)
810
- value_raw = attn.to_v(encoder_hidden_states)
811
-
812
- # print('query', query.shape, 'key', key.shape, 'value', value.shape)
813
- #([bx4, 1024, 320]) key torch.Size([bx4, 1024, 320]) value torch.Size([bx4, 1024, 320])
814
- # pdb.set_trace()
815
- # multi-view self-attention
816
- if multiview_attention:
817
- if not sparse_mv_attention:
818
- key = my_repeat(rearrange(key_raw, "(b t) d c -> b (t d) c", t=num_views), num_views)
819
- value = my_repeat(rearrange(value_raw, "(b t) d c -> b (t d) c", t=num_views), num_views)
820
- else:
821
- key_front = my_repeat(rearrange(key_raw, "(b t) d c -> b t d c", t=num_views)[:, 0, :, :], num_views) # [(b t), d, c]
822
- value_front = my_repeat(rearrange(value_raw, "(b t) d c -> b t d c", t=num_views)[:, 0, :, :], num_views)
823
- key = torch.cat([key_front, key_raw], dim=1) # shape (b t) (2 d) c
824
- value = torch.cat([value_front, value_raw], dim=1)
825
-
826
- if mvcd_attention:
827
- # memory efficient, cross domain attention
828
- key_0, key_1 = torch.chunk(key_raw, dim=0, chunks=2) # keys shape (b t) d c
829
- value_0, value_1 = torch.chunk(value_raw, dim=0, chunks=2)
830
- key_cross = torch.concat([key_1, key_0], dim=0)
831
- value_cross = torch.concat([value_1, value_0], dim=0) # shape (b t) d c
832
- key = torch.cat([key, key_cross], dim=1)
833
- value = torch.cat([value, value_cross], dim=1) # shape (b t) (t+1 d) c
834
- else:
835
- # print("don't use multiview attention.")
836
- key = key_raw
837
- value = value_raw
838
-
839
- query = attn.head_to_batch_dim(query)
840
- key = attn.head_to_batch_dim(key)
841
- value = attn.head_to_batch_dim(value)
842
-
843
- hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
844
- hidden_states = attn.batch_to_head_dim(hidden_states)
845
-
846
- # linear proj
847
- hidden_states = attn.to_out[0](hidden_states)
848
- # dropout
849
- hidden_states = attn.to_out[1](hidden_states)
850
-
851
- if input_ndim == 4:
852
- hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
853
-
854
- if attn.residual_connection:
855
- hidden_states = hidden_states + residual
856
-
857
- hidden_states = hidden_states / attn.rescale_output_factor
858
-
859
- return hidden_states
860
-
861
-
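The `mvcd_attention` branch above assumes the batch stacks two domains along dim 0 (e.g. color views followed by normal views) and lets each domain also attend to the other domain's tokens. A simplified sketch that skips the multi-view expansion step:

import torch

key_raw = torch.randn(8, 1024, 320)              # 2 domains x 4 views, (b*t, d, c)
key_0, key_1 = torch.chunk(key_raw, chunks=2, dim=0)
key_cross = torch.cat([key_1, key_0], dim=0)     # swap the halves so each row sees the other domain
key = torch.cat([key_raw, key_cross], dim=1)     # (8, 2048, 320): own tokens + cross-domain tokens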
862
-
863
- class XFormersJointAttnProcessor:
864
- r"""
865
- Default processor for performing attention-related computations.
866
- """
867
-
868
- def __call__(
869
- self,
870
- attn: Attention,
871
- hidden_states,
872
- encoder_hidden_states=None,
873
- attention_mask=None,
874
- temb=None,
875
- num_tasks=2
876
- ):
877
-
878
- residual = hidden_states
879
-
880
- if attn.spatial_norm is not None:
881
- hidden_states = attn.spatial_norm(hidden_states, temb)
882
-
883
- input_ndim = hidden_states.ndim
884
-
885
- if input_ndim == 4:
886
- batch_size, channel, height, width = hidden_states.shape
887
- hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
888
-
889
- batch_size, sequence_length, _ = (
890
- hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
891
- )
892
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
893
-
894
- # from yuancheng; here attention_mask is None
895
- if attention_mask is not None:
896
- # expand our mask's singleton query_tokens dimension:
897
- # [batch*heads, 1, key_tokens] ->
898
- # [batch*heads, query_tokens, key_tokens]
899
- # so that it can be added as a bias onto the attention scores that xformers computes:
900
- # [batch*heads, query_tokens, key_tokens]
901
- # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
902
- _, query_tokens, _ = hidden_states.shape
903
- attention_mask = attention_mask.expand(-1, query_tokens, -1)
904
-
905
- if attn.group_norm is not None:
906
- hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
907
-
908
- query = attn.to_q(hidden_states)
909
-
910
- if encoder_hidden_states is None:
911
- encoder_hidden_states = hidden_states
912
- elif attn.norm_cross:
913
- encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
914
-
915
- key = attn.to_k(encoder_hidden_states)
916
- value = attn.to_v(encoder_hidden_states)
917
-
918
- assert num_tasks == 2 # only support two tasks now
919
-
920
- key_0, key_1 = torch.chunk(key, dim=0, chunks=2) # keys shape (b t) d c
921
- value_0, value_1 = torch.chunk(value, dim=0, chunks=2)
922
- key = torch.cat([key_0, key_1], dim=1) # (b t) 2d c
923
- value = torch.cat([value_0, value_1], dim=1) # (b t) 2d c
924
- key = torch.cat([key]*2, dim=0) # ( 2 b t) 2d c
925
- value = torch.cat([value]*2, dim=0) # (2 b t) 2d c
926
-
927
-
928
- query = attn.head_to_batch_dim(query).contiguous()
929
- key = attn.head_to_batch_dim(key).contiguous()
930
- value = attn.head_to_batch_dim(value).contiguous()
931
-
932
- hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
933
- hidden_states = attn.batch_to_head_dim(hidden_states)
934
-
935
- # linear proj
936
- hidden_states = attn.to_out[0](hidden_states)
937
- # dropout
938
- hidden_states = attn.to_out[1](hidden_states)
939
-
940
- if input_ndim == 4:
941
- hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
942
-
943
- if attn.residual_connection:
944
- hidden_states = hidden_states + residual
945
-
946
- hidden_states = hidden_states / attn.rescale_output_factor
947
-
948
- return hidden_states
949
-
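The joint processor above (and the plain `JointAttnProcessor` below) builds keys/values the same way: the two tasks' tokens are concatenated along the token axis, and the result is duplicated so each task's queries attend over both domains. A shape sketch with illustrative sizes:

import torch

key = torch.randn(8, 1024, 320)                  # (2 tasks * b*t, d, c)
key_0, key_1 = torch.chunk(key, chunks=2, dim=0)
key = torch.cat([key_0, key_1], dim=1)           # (4, 2048, 320): both tasks' tokens side by side
key = torch.cat([key] * 2, dim=0)                # (8, 2048, 320): one copy for each task's queries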
950
-
951
- class JointAttnProcessor:
952
- r"""
953
- Default processor for performing attention-related computations.
954
- """
955
-
956
- def __call__(
957
- self,
958
- attn: Attention,
959
- hidden_states,
960
- encoder_hidden_states=None,
961
- attention_mask=None,
962
- temb=None,
963
- num_tasks=2
964
- ):
965
-
966
- residual = hidden_states
967
-
968
- if attn.spatial_norm is not None:
969
- hidden_states = attn.spatial_norm(hidden_states, temb)
970
-
971
- input_ndim = hidden_states.ndim
972
-
973
- if input_ndim == 4:
974
- batch_size, channel, height, width = hidden_states.shape
975
- hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
976
-
977
- batch_size, sequence_length, _ = (
978
- hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
979
- )
980
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
981
-
982
-
983
- if attn.group_norm is not None:
984
- hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
985
-
986
- query = attn.to_q(hidden_states)
987
-
988
- if encoder_hidden_states is None:
989
- encoder_hidden_states = hidden_states
990
- elif attn.norm_cross:
991
- encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
992
-
993
- key = attn.to_k(encoder_hidden_states)
994
- value = attn.to_v(encoder_hidden_states)
995
-
996
- assert num_tasks == 2 # only support two tasks now
997
-
998
- key_0, key_1 = torch.chunk(key, dim=0, chunks=2) # keys shape (b t) d c
999
- value_0, value_1 = torch.chunk(value, dim=0, chunks=2)
1000
- key = torch.cat([key_0, key_1], dim=1) # (b t) 2d c
1001
- value = torch.cat([value_0, value_1], dim=1) # (b t) 2d c
1002
- key = torch.cat([key]*2, dim=0) # ( 2 b t) 2d c
1003
- value = torch.cat([value]*2, dim=0) # (2 b t) 2d c
1004
-
1005
-
1006
- query = attn.head_to_batch_dim(query).contiguous()
1007
- key = attn.head_to_batch_dim(key).contiguous()
1008
- value = attn.head_to_batch_dim(value).contiguous()
1009
-
1010
- attention_probs = attn.get_attention_scores(query, key, attention_mask)
1011
- hidden_states = torch.bmm(attention_probs, value)
1012
- hidden_states = attn.batch_to_head_dim(hidden_states)
1013
-
1014
- # linear proj
1015
- hidden_states = attn.to_out[0](hidden_states)
1016
- # dropout
1017
- hidden_states = attn.to_out[1](hidden_states)
1018
-
1019
- if input_ndim == 4:
1020
- hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
1021
-
1022
- if attn.residual_connection:
1023
- hidden_states = hidden_states + residual
1024
-
1025
- hidden_states = hidden_states / attn.rescale_output_factor
1026
-
1027
- return hidden_states
1028
-
1029
-
mvdiffusion/models/transformer_mv2d_rowwise.py DELETED
@@ -1,978 +0,0 @@
1
- # Copyright 2023 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- from dataclasses import dataclass
15
- from typing import Any, Dict, Optional
16
-
17
- import torch
18
- import torch.nn.functional as F
19
- from torch import nn
20
-
21
- from diffusers.configuration_utils import ConfigMixin, register_to_config
22
- from diffusers.models.embeddings import ImagePositionalEmbeddings
23
- from diffusers.utils import BaseOutput, deprecate
24
- from diffusers.utils.torch_utils import maybe_allow_in_graph
25
- from diffusers.models.attention import FeedForward, AdaLayerNorm, AdaLayerNormZero, Attention
26
- from diffusers.models.embeddings import PatchEmbed
27
- from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
28
- from diffusers.models.modeling_utils import ModelMixin
29
- from diffusers.utils.import_utils import is_xformers_available
30
-
31
- from einops import rearrange
32
- import pdb
33
- import random
34
- import math
35
-
36
-
37
- if is_xformers_available():
38
- import xformers
39
- import xformers.ops
40
- else:
41
- xformers = None
42
-
43
-
44
- @dataclass
45
- class TransformerMV2DModelOutput(BaseOutput):
46
- """
47
-     The output of [`TransformerMV2DModel`].
48
-
49
- Args:
50
- sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
51
- The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
52
- distributions for the unnoised latent pixels.
53
- """
54
-
55
- sample: torch.FloatTensor
56
-
57
-
58
- class TransformerMV2DModel(ModelMixin, ConfigMixin):
59
- """
60
- A 2D Transformer model for image-like data.
61
-
62
- Parameters:
63
- num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
64
- attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
65
- in_channels (`int`, *optional*):
66
- The number of channels in the input and output (specify if the input is **continuous**).
67
- num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
68
- dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
69
- cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
70
- sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
71
- This is fixed during training since it is used to learn a number of position embeddings.
72
- num_vector_embeds (`int`, *optional*):
73
- The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
74
- Includes the class for the masked latent pixel.
75
- activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
76
- num_embeds_ada_norm ( `int`, *optional*):
77
- The number of diffusion steps used during training. Pass if at least one of the norm_layers is
78
- `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
79
- added to the hidden states.
80
-
81
- During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`.
82
- attention_bias (`bool`, *optional*):
83
- Configure if the `TransformerBlocks` attention should contain a bias parameter.
84
- """
85
-
86
- @register_to_config
87
- def __init__(
88
- self,
89
- num_attention_heads: int = 16,
90
- attention_head_dim: int = 88,
91
- in_channels: Optional[int] = None,
92
- out_channels: Optional[int] = None,
93
- num_layers: int = 1,
94
- dropout: float = 0.0,
95
- norm_num_groups: int = 32,
96
- cross_attention_dim: Optional[int] = None,
97
- attention_bias: bool = False,
98
- sample_size: Optional[int] = None,
99
- num_vector_embeds: Optional[int] = None,
100
- patch_size: Optional[int] = None,
101
- activation_fn: str = "geglu",
102
- num_embeds_ada_norm: Optional[int] = None,
103
- use_linear_projection: bool = False,
104
- only_cross_attention: bool = False,
105
- upcast_attention: bool = False,
106
- norm_type: str = "layer_norm",
107
- norm_elementwise_affine: bool = True,
108
- num_views: int = 1,
109
- cd_attention_last: bool=False,
110
- cd_attention_mid: bool=False,
111
- multiview_attention: bool=True,
112
- sparse_mv_attention: bool = True, # not used
113
- mvcd_attention: bool=False
114
- ):
115
- super().__init__()
116
- self.use_linear_projection = use_linear_projection
117
- self.num_attention_heads = num_attention_heads
118
- self.attention_head_dim = attention_head_dim
119
- inner_dim = num_attention_heads * attention_head_dim
120
-
121
- # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
122
- # Define whether input is continuous or discrete depending on configuration
123
- self.is_input_continuous = (in_channels is not None) and (patch_size is None)
124
- self.is_input_vectorized = num_vector_embeds is not None
125
- self.is_input_patches = in_channels is not None and patch_size is not None
126
-
127
- if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
128
- deprecation_message = (
129
-                 " incorrectly set to `'layer_norm'`. Make sure to set `norm_type` to `'ada_norm'` in the config."
130
-                 " Please make sure to update the config accordingly, as leaving `norm_type` as it is might lead to incorrect"
131
- " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect"
132
- " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
133
- " would be very nice if you could open a Pull request for the `transformer/config.json` file"
134
- )
135
- deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False)
136
- norm_type = "ada_norm"
137
-
138
- if self.is_input_continuous and self.is_input_vectorized:
139
- raise ValueError(
140
- f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
141
- " sure that either `in_channels` or `num_vector_embeds` is None."
142
- )
143
- elif self.is_input_vectorized and self.is_input_patches:
144
- raise ValueError(
145
- f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
146
- " sure that either `num_vector_embeds` or `num_patches` is None."
147
- )
148
- elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches:
149
- raise ValueError(
150
- f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:"
151
- f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
152
- )
153
-
154
- # 2. Define input layers
155
- if self.is_input_continuous:
156
- self.in_channels = in_channels
157
-
158
- self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
159
- if use_linear_projection:
160
- self.proj_in = LoRACompatibleLinear(in_channels, inner_dim)
161
- else:
162
- self.proj_in = LoRACompatibleConv(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
163
- elif self.is_input_vectorized:
164
- assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
165
- assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
166
-
167
- self.height = sample_size
168
- self.width = sample_size
169
- self.num_vector_embeds = num_vector_embeds
170
- self.num_latent_pixels = self.height * self.width
171
-
172
- self.latent_image_embedding = ImagePositionalEmbeddings(
173
- num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
174
- )
175
- elif self.is_input_patches:
176
- assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size"
177
-
178
- self.height = sample_size
179
- self.width = sample_size
180
-
181
- self.patch_size = patch_size
182
- self.pos_embed = PatchEmbed(
183
- height=sample_size,
184
- width=sample_size,
185
- patch_size=patch_size,
186
- in_channels=in_channels,
187
- embed_dim=inner_dim,
188
- )
189
-
190
- # 3. Define transformers blocks
191
- self.transformer_blocks = nn.ModuleList(
192
- [
193
- BasicMVTransformerBlock(
194
- inner_dim,
195
- num_attention_heads,
196
- attention_head_dim,
197
- dropout=dropout,
198
- cross_attention_dim=cross_attention_dim,
199
- activation_fn=activation_fn,
200
- num_embeds_ada_norm=num_embeds_ada_norm,
201
- attention_bias=attention_bias,
202
- only_cross_attention=only_cross_attention,
203
- upcast_attention=upcast_attention,
204
- norm_type=norm_type,
205
- norm_elementwise_affine=norm_elementwise_affine,
206
- num_views=num_views,
207
- cd_attention_last=cd_attention_last,
208
- cd_attention_mid=cd_attention_mid,
209
- multiview_attention=multiview_attention,
210
- mvcd_attention=mvcd_attention
211
- )
212
- for d in range(num_layers)
213
- ]
214
- )
215
-
216
- # 4. Define output layers
217
- self.out_channels = in_channels if out_channels is None else out_channels
218
- if self.is_input_continuous:
219
- # TODO: should use out_channels for continuous projections
220
- if use_linear_projection:
221
- self.proj_out = LoRACompatibleLinear(inner_dim, in_channels)
222
- else:
223
- self.proj_out = LoRACompatibleConv(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
224
- elif self.is_input_vectorized:
225
- self.norm_out = nn.LayerNorm(inner_dim)
226
- self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
227
- elif self.is_input_patches:
228
- self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
229
- self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
230
- self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
231
-
232
- def forward(
233
- self,
234
- hidden_states: torch.Tensor,
235
- encoder_hidden_states: Optional[torch.Tensor] = None,
236
- timestep: Optional[torch.LongTensor] = None,
237
- class_labels: Optional[torch.LongTensor] = None,
238
- cross_attention_kwargs: Dict[str, Any] = None,
239
- attention_mask: Optional[torch.Tensor] = None,
240
- encoder_attention_mask: Optional[torch.Tensor] = None,
241
- return_dict: bool = True,
242
- ):
243
- """
244
- The [`Transformer2DModel`] forward method.
245
-
246
- Args:
247
- hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
248
- Input `hidden_states`.
249
- encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
250
- Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
251
- self-attention.
252
- timestep ( `torch.LongTensor`, *optional*):
253
- Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
254
- class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
255
- Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
256
- `AdaLayerNormZero`.
257
- encoder_attention_mask ( `torch.Tensor`, *optional*):
258
- Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
259
-
260
- * Mask `(batch, sequence_length)` True = keep, False = discard.
261
- * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
262
-
263
- If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
264
- above. This bias will be added to the cross-attention scores.
265
- return_dict (`bool`, *optional*, defaults to `True`):
266
- Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
267
- tuple.
268
-
269
- Returns:
270
- If `return_dict` is True, a [`TransformerMV2DModelOutput`] is returned, otherwise a
271
- `tuple` where the first element is the sample tensor.
272
- """
273
- # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
274
- # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
275
- # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
276
- # expects mask of shape:
277
- # [batch, key_tokens]
278
- # adds singleton query_tokens dimension:
279
- # [batch, 1, key_tokens]
280
- # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
281
- # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
282
- # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
283
- if attention_mask is not None and attention_mask.ndim == 2:
284
- # assume that mask is expressed as:
285
- # (1 = keep, 0 = discard)
286
- # convert mask into a bias that can be added to attention scores:
287
- # (keep = +0, discard = -10000.0)
288
- attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
289
- attention_mask = attention_mask.unsqueeze(1)
290
-
291
- # convert encoder_attention_mask to a bias the same way we do for attention_mask
292
- if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
293
- encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
294
- encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
295
-
296
- # 1. Input
297
- if self.is_input_continuous:
298
- batch, _, height, width = hidden_states.shape
299
- residual = hidden_states
300
-
301
- hidden_states = self.norm(hidden_states)
302
- if not self.use_linear_projection:
303
- hidden_states = self.proj_in(hidden_states)
304
- inner_dim = hidden_states.shape[1]
305
- hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
306
- else:
307
- inner_dim = hidden_states.shape[1]
308
- hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
309
- hidden_states = self.proj_in(hidden_states)
310
- elif self.is_input_vectorized:
311
- hidden_states = self.latent_image_embedding(hidden_states)
312
- elif self.is_input_patches:
313
- hidden_states = self.pos_embed(hidden_states)
314
-
315
- # 2. Blocks
316
- for block in self.transformer_blocks:
317
- hidden_states = block(
318
- hidden_states,
319
- attention_mask=attention_mask,
320
- encoder_hidden_states=encoder_hidden_states,
321
- encoder_attention_mask=encoder_attention_mask,
322
- timestep=timestep,
323
- cross_attention_kwargs=cross_attention_kwargs,
324
- class_labels=class_labels,
325
- )
326
-
327
- # 3. Output
328
- if self.is_input_continuous:
329
- if not self.use_linear_projection:
330
- hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
331
- hidden_states = self.proj_out(hidden_states)
332
- else:
333
- hidden_states = self.proj_out(hidden_states)
334
- hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
335
-
336
- output = hidden_states + residual
337
- elif self.is_input_vectorized:
338
- hidden_states = self.norm_out(hidden_states)
339
- logits = self.out(hidden_states)
340
- # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
341
- logits = logits.permute(0, 2, 1)
342
-
343
- # log(p(x_0))
344
- output = F.log_softmax(logits.double(), dim=1).float()
345
- elif self.is_input_patches:
346
- # TODO: cleanup!
347
- conditioning = self.transformer_blocks[0].norm1.emb(
348
- timestep, class_labels, hidden_dtype=hidden_states.dtype
349
- )
350
- shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
351
- hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
352
- hidden_states = self.proj_out_2(hidden_states)
353
-
354
- # unpatchify
355
- height = width = int(hidden_states.shape[1] ** 0.5)
356
- hidden_states = hidden_states.reshape(
357
- shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
358
- )
359
- hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
360
- output = hidden_states.reshape(
361
- shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
362
- )
363
-
364
- if not return_dict:
365
- return (output,)
366
-
367
- return TransformerMV2DModelOutput(sample=output)
368
-
369
-
370
- @maybe_allow_in_graph
371
- class BasicMVTransformerBlock(nn.Module):
372
- r"""
373
- A basic Transformer block.
374
-
375
- Parameters:
376
- dim (`int`): The number of channels in the input and output.
377
- num_attention_heads (`int`): The number of heads to use for multi-head attention.
378
- attention_head_dim (`int`): The number of channels in each head.
379
- dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
380
- cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
381
- only_cross_attention (`bool`, *optional*):
382
- Whether to use only cross-attention layers. In this case two cross attention layers are used.
383
- double_self_attention (`bool`, *optional*):
384
- Whether to use two self-attention layers. In this case no cross attention layers are used.
385
- activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
386
- num_embeds_ada_norm (:
387
- obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
388
- attention_bias (:
389
- obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
390
- """
391
-
392
- def __init__(
393
- self,
394
- dim: int,
395
- num_attention_heads: int,
396
- attention_head_dim: int,
397
- dropout=0.0,
398
- cross_attention_dim: Optional[int] = None,
399
- activation_fn: str = "geglu",
400
- num_embeds_ada_norm: Optional[int] = None,
401
- attention_bias: bool = False,
402
- only_cross_attention: bool = False,
403
- double_self_attention: bool = False,
404
- upcast_attention: bool = False,
405
- norm_elementwise_affine: bool = True,
406
- norm_type: str = "layer_norm",
407
- final_dropout: bool = False,
408
- num_views: int = 1,
409
- cd_attention_last: bool = False,
410
- cd_attention_mid: bool = False,
411
- multiview_attention: bool = True,
412
- mvcd_attention: bool = False,
413
- rowwise_attention: bool = True
414
- ):
415
- super().__init__()
416
- self.only_cross_attention = only_cross_attention
417
-
418
- self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
419
- self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
420
-
421
- if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
422
- raise ValueError(
423
- f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
424
- f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
425
- )
426
-
427
- # Define 3 blocks. Each block has its own normalization layer.
428
- # 1. Self-Attn
429
- if self.use_ada_layer_norm:
430
- self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
431
- elif self.use_ada_layer_norm_zero:
432
- self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
433
- else:
434
- self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
435
-
436
- self.multiview_attention = multiview_attention
437
- self.mvcd_attention = mvcd_attention
438
- self.rowwise_attention = multiview_attention and rowwise_attention
439
-
440
- # rowwise multiview attention
441
-
442
- print('INFO: using row-wise attention...')
443
-
444
- self.attn1 = CustomAttention(
445
- query_dim=dim,
446
- heads=num_attention_heads,
447
- dim_head=attention_head_dim,
448
- dropout=dropout,
449
- bias=attention_bias,
450
- cross_attention_dim=cross_attention_dim if only_cross_attention else None,
451
- upcast_attention=upcast_attention,
452
- processor=MVAttnProcessor()
453
- )
454
-
455
- # 2. Cross-Attn
456
- if cross_attention_dim is not None or double_self_attention:
457
- # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
458
- # I.e. the number of returned modulation chunks from AdaLayerNormZero would not make sense if returned during
459
- # the second cross attention block.
460
- self.norm2 = (
461
- AdaLayerNorm(dim, num_embeds_ada_norm)
462
- if self.use_ada_layer_norm
463
- else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
464
- )
465
- self.attn2 = Attention(
466
- query_dim=dim,
467
- cross_attention_dim=cross_attention_dim if not double_self_attention else None,
468
- heads=num_attention_heads,
469
- dim_head=attention_head_dim,
470
- dropout=dropout,
471
- bias=attention_bias,
472
- upcast_attention=upcast_attention,
473
- ) # is self-attn if encoder_hidden_states is none
474
- else:
475
- self.norm2 = None
476
- self.attn2 = None
477
-
478
- # 3. Feed-forward
479
- self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
480
- self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
481
-
482
- # let chunk size default to None
483
- self._chunk_size = None
484
- self._chunk_dim = 0
485
-
486
- self.num_views = num_views
487
-
488
- self.cd_attention_last = cd_attention_last
489
-
490
- if self.cd_attention_last:
491
- # Joint task attention
492
- self.attn_joint = CustomJointAttention(
493
- query_dim=dim,
494
- heads=num_attention_heads,
495
- dim_head=attention_head_dim,
496
- dropout=dropout,
497
- bias=attention_bias,
498
- cross_attention_dim=cross_attention_dim if only_cross_attention else None,
499
- upcast_attention=upcast_attention,
500
- processor=JointAttnProcessor()
501
- )
502
- nn.init.zeros_(self.attn_joint.to_out[0].weight.data)
503
- self.norm_joint = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
504
-
505
-
506
- self.cd_attention_mid = cd_attention_mid
507
-
508
- if self.cd_attention_mid:
509
- print("joint twice")
510
- # Joint task attention
511
- self.attn_joint_twice = CustomJointAttention(
512
- query_dim=dim,
513
- heads=num_attention_heads,
514
- dim_head=attention_head_dim,
515
- dropout=dropout,
516
- bias=attention_bias,
517
- cross_attention_dim=cross_attention_dim if only_cross_attention else None,
518
- upcast_attention=upcast_attention,
519
- processor=JointAttnProcessor()
520
- )
521
- nn.init.zeros_(self.attn_joint_twice.to_out[0].weight.data)
522
- self.norm_joint_twice = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
523
-
524
- def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
525
- # Sets chunk feed-forward
526
- self._chunk_size = chunk_size
527
- self._chunk_dim = dim
528
-
529
- def forward(
530
- self,
531
- hidden_states: torch.FloatTensor,
532
- attention_mask: Optional[torch.FloatTensor] = None,
533
- encoder_hidden_states: Optional[torch.FloatTensor] = None,
534
- encoder_attention_mask: Optional[torch.FloatTensor] = None,
535
- timestep: Optional[torch.LongTensor] = None,
536
- cross_attention_kwargs: Dict[str, Any] = None,
537
- class_labels: Optional[torch.LongTensor] = None,
538
- ):
539
- assert attention_mask is None # not supported yet
540
- # Notice that normalization is always applied before the real computation in the following blocks.
541
- # 1. Self-Attention
542
- if self.use_ada_layer_norm:
543
- norm_hidden_states = self.norm1(hidden_states, timestep)
544
- elif self.use_ada_layer_norm_zero:
545
- norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
546
- hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
547
- )
548
- else:
549
- norm_hidden_states = self.norm1(hidden_states)
550
-
551
- cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
552
-
553
- attn_output = self.attn1(
554
- norm_hidden_states,
555
- encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
556
- attention_mask=attention_mask,
557
- multiview_attention=self.multiview_attention,
558
- mvcd_attention=self.mvcd_attention,
559
- num_views=self.num_views,
560
- **cross_attention_kwargs,
561
- )
562
-
563
- if self.use_ada_layer_norm_zero:
564
- attn_output = gate_msa.unsqueeze(1) * attn_output
565
- hidden_states = attn_output + hidden_states
566
-
567
- # joint attention twice
568
- if self.cd_attention_mid:
569
- norm_hidden_states = (
570
- self.norm_joint_twice(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_joint_twice(hidden_states)
571
- )
572
- hidden_states = self.attn_joint_twice(norm_hidden_states) + hidden_states
573
-
574
- # 2. Cross-Attention
575
- if self.attn2 is not None:
576
- norm_hidden_states = (
577
- self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
578
- )
579
-
580
- attn_output = self.attn2(
581
- norm_hidden_states,
582
- encoder_hidden_states=encoder_hidden_states,
583
- attention_mask=encoder_attention_mask,
584
- **cross_attention_kwargs,
585
- )
586
- hidden_states = attn_output + hidden_states
587
-
588
- # 3. Feed-forward
589
- norm_hidden_states = self.norm3(hidden_states)
590
-
591
- if self.use_ada_layer_norm_zero:
592
- norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
593
-
594
- if self._chunk_size is not None:
595
- # "feed_forward_chunk_size" can be used to save memory
596
- if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
597
- raise ValueError(
598
- f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
599
- )
600
-
601
- num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
602
- ff_output = torch.cat(
603
- [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)],
604
- dim=self._chunk_dim,
605
- )
606
- else:
607
- ff_output = self.ff(norm_hidden_states)
608
-
609
- if self.use_ada_layer_norm_zero:
610
- ff_output = gate_mlp.unsqueeze(1) * ff_output
611
-
612
- hidden_states = ff_output + hidden_states
613
-
614
- if self.cd_attention_last:
615
- norm_hidden_states = (
616
- self.norm_joint(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_joint(hidden_states)
617
- )
618
- hidden_states = self.attn_joint(norm_hidden_states) + hidden_states
619
-
620
- return hidden_states
621
-
622
-
623
- class CustomAttention(Attention):
624
- def set_use_memory_efficient_attention_xformers(
625
- self, use_memory_efficient_attention_xformers: bool, *args, **kwargs
626
- ):
627
- processor = XFormersMVAttnProcessor()
628
- self.set_processor(processor)
629
- # print("using xformers attention processor")
630
-
631
-
632
- class CustomJointAttention(Attention):
633
- def set_use_memory_efficient_attention_xformers(
634
- self, use_memory_efficient_attention_xformers: bool, *args, **kwargs
635
- ):
636
- processor = XFormersJointAttnProcessor()
637
- self.set_processor(processor)
638
- # print("using xformers attention processor")
639
-
640
- class MVAttnProcessor:
641
- r"""
642
- Default processor for performing attention-related computations.
643
- """
644
-
645
- def __call__(
646
- self,
647
- attn: Attention,
648
- hidden_states,
649
- encoder_hidden_states=None,
650
- attention_mask=None,
651
- temb=None,
652
- num_views=1,
653
- multiview_attention=True
654
- ):
655
- residual = hidden_states
656
-
657
- if attn.spatial_norm is not None:
658
- hidden_states = attn.spatial_norm(hidden_states, temb)
659
-
660
- input_ndim = hidden_states.ndim
661
-
662
- if input_ndim == 4:
663
- batch_size, channel, height, width = hidden_states.shape
664
- hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
665
-
666
- batch_size, sequence_length, _ = (
667
- hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
668
- )
669
- height = int(sequence_length ** 0.5)  # side of the square token grid (h == w)
670
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
671
-
672
- if attn.group_norm is not None:
673
- hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
674
-
675
- query = attn.to_q(hidden_states)
676
-
677
- if encoder_hidden_states is None:
678
- encoder_hidden_states = hidden_states
679
- elif attn.norm_cross:
680
- encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
681
-
682
- key = attn.to_k(encoder_hidden_states)
683
- value = attn.to_v(encoder_hidden_states)
684
-
685
- # print('query', query.shape, 'key', key.shape, 'value', value.shape)
686
- # query torch.Size([bx4, 1024, 320]) key torch.Size([bx4, 1024, 320]) value torch.Size([bx4, 1024, 320])
687
- # pdb.set_trace()
688
- # multi-view self-attention
689
- key = rearrange(key, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height)
690
- value = rearrange(value, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height)
691
- query = rearrange(query, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height) # torch.Size([192, 384, 320])
692
-
693
- query = attn.head_to_batch_dim(query).contiguous()
694
- key = attn.head_to_batch_dim(key).contiguous()
695
- value = attn.head_to_batch_dim(value).contiguous()
696
-
697
- attention_probs = attn.get_attention_scores(query, key, attention_mask)
698
- hidden_states = torch.bmm(attention_probs, value)
699
- hidden_states = attn.batch_to_head_dim(hidden_states)
700
-
701
- # linear proj
702
- hidden_states = attn.to_out[0](hidden_states)
703
- # dropout
704
- hidden_states = attn.to_out[1](hidden_states)
705
- hidden_states = rearrange(hidden_states, "(b h) (v w) c -> (b v) (h w) c", v=num_views, h=height)
706
- if input_ndim == 4:
707
- hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
708
-
709
- if attn.residual_connection:
710
- hidden_states = hidden_states + residual
711
-
712
- hidden_states = hidden_states / attn.rescale_output_factor
713
-
714
- return hidden_states
715
-
716
-
717
- class XFormersMVAttnProcessor:
718
- r"""
719
- Default processor for performing attention-related computations.
720
- """
721
-
722
- def __call__(
723
- self,
724
- attn: Attention,
725
- hidden_states,
726
- encoder_hidden_states=None,
727
- attention_mask=None,
728
- temb=None,
729
- num_views=1,
730
- multiview_attention=True,
731
- mvcd_attention=False,
732
- ):
733
- residual = hidden_states
734
-
735
- if attn.spatial_norm is not None:
736
- hidden_states = attn.spatial_norm(hidden_states, temb)
737
-
738
- input_ndim = hidden_states.ndim
739
-
740
- if input_ndim == 4:
741
- batch_size, channel, height, width = hidden_states.shape
742
- hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
743
-
744
- batch_size, sequence_length, _ = (
745
- hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
746
- )
747
- height = int(sequence_length ** 0.5)  # side of the square token grid (h == w)
748
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
749
- # from yuancheng; here attention_mask is None
750
- if attention_mask is not None:
751
- # expand our mask's singleton query_tokens dimension:
752
- # [batch*heads, 1, key_tokens] ->
753
- # [batch*heads, query_tokens, key_tokens]
754
- # so that it can be added as a bias onto the attention scores that xformers computes:
755
- # [batch*heads, query_tokens, key_tokens]
756
- # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
757
- _, query_tokens, _ = hidden_states.shape
758
- attention_mask = attention_mask.expand(-1, query_tokens, -1)
759
-
760
- if attn.group_norm is not None:
761
- print('Warning: group norm is enabled; be careful when combining it with row-wise attention')
762
- hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
763
-
764
- query = attn.to_q(hidden_states)
765
-
766
- if encoder_hidden_states is None:
767
- encoder_hidden_states = hidden_states
768
- elif attn.norm_cross:
769
- encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
770
-
771
- key_raw = attn.to_k(encoder_hidden_states)
772
- value_raw = attn.to_v(encoder_hidden_states)
773
-
774
- # print('query', query.shape, 'key', key.shape, 'value', value.shape)
775
- # pdb.set_trace()
776
-
777
- key = rearrange(key_raw, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height)
778
- value = rearrange(value_raw, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height)
779
- query = rearrange(query, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height) # torch.Size([192, 384, 320])
780
- if mvcd_attention:
781
- # memory efficient, cross domain attention
782
- key_0, key_1 = torch.chunk(key_raw, dim=0, chunks=2) # keys shape (b t) d c
783
- value_0, value_1 = torch.chunk(value_raw, dim=0, chunks=2)
784
- key_cross = torch.concat([key_1, key_0], dim=0)
785
- value_cross = torch.concat([value_1, value_0], dim=0) # shape (b t) d c
786
- key = torch.cat([key, key_cross], dim=1)
787
- value = torch.cat([value, value_cross], dim=1) # shape (b t) (t+1 d) c
788
-
789
-
790
- query = attn.head_to_batch_dim(query) # torch.Size([960, 384, 64])
791
- key = attn.head_to_batch_dim(key)
792
- value = attn.head_to_batch_dim(value)
793
-
794
- hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
795
- hidden_states = attn.batch_to_head_dim(hidden_states)
796
-
797
- # linear proj
798
- hidden_states = attn.to_out[0](hidden_states)
799
- # dropout
800
- hidden_states = attn.to_out[1](hidden_states)
801
- # print(hidden_states.shape)
802
- hidden_states = rearrange(hidden_states, "(b h) (v w) c -> (b v) (h w) c", v=num_views, h=height)
803
- if input_ndim == 4:
804
- hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
805
-
806
- if attn.residual_connection:
807
- hidden_states = hidden_states + residual
808
-
809
- hidden_states = hidden_states / attn.rescale_output_factor
810
-
811
- return hidden_states
812
-
813
-
814
- class XFormersJointAttnProcessor:
815
- r"""
816
- Default processor for performing attention-related computations.
817
- """
818
-
819
- def __call__(
820
- self,
821
- attn: Attention,
822
- hidden_states,
823
- encoder_hidden_states=None,
824
- attention_mask=None,
825
- temb=None,
826
- num_tasks=2
827
- ):
828
-
829
- residual = hidden_states
830
-
831
- if attn.spatial_norm is not None:
832
- hidden_states = attn.spatial_norm(hidden_states, temb)
833
-
834
- input_ndim = hidden_states.ndim
835
-
836
- if input_ndim == 4:
837
- batch_size, channel, height, width = hidden_states.shape
838
- hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
839
-
840
- batch_size, sequence_length, _ = (
841
- hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
842
- )
843
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
844
-
845
- # from yuancheng; here attention_mask is None
846
- if attention_mask is not None:
847
- # expand our mask's singleton query_tokens dimension:
848
- # [batch*heads, 1, key_tokens] ->
849
- # [batch*heads, query_tokens, key_tokens]
850
- # so that it can be added as a bias onto the attention scores that xformers computes:
851
- # [batch*heads, query_tokens, key_tokens]
852
- # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
853
- _, query_tokens, _ = hidden_states.shape
854
- attention_mask = attention_mask.expand(-1, query_tokens, -1)
855
-
856
- if attn.group_norm is not None:
857
- hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
858
-
859
- query = attn.to_q(hidden_states)
860
-
861
- if encoder_hidden_states is None:
862
- encoder_hidden_states = hidden_states
863
- elif attn.norm_cross:
864
- encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
865
-
866
- key = attn.to_k(encoder_hidden_states)
867
- value = attn.to_v(encoder_hidden_states)
868
-
869
- assert num_tasks == 2 # only support two tasks now
870
-
871
- key_0, key_1 = torch.chunk(key, dim=0, chunks=2) # keys shape (b t) d c
872
- value_0, value_1 = torch.chunk(value, dim=0, chunks=2)
873
- key = torch.cat([key_0, key_1], dim=1) # (b t) 2d c
874
- value = torch.cat([value_0, value_1], dim=1) # (b t) 2d c
875
- key = torch.cat([key]*2, dim=0) # ( 2 b t) 2d c
876
- value = torch.cat([value]*2, dim=0) # (2 b t) 2d c
877
-
878
-
879
- query = attn.head_to_batch_dim(query).contiguous()
880
- key = attn.head_to_batch_dim(key).contiguous()
881
- value = attn.head_to_batch_dim(value).contiguous()
882
-
883
- hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
884
- hidden_states = attn.batch_to_head_dim(hidden_states)
885
-
886
- # linear proj
887
- hidden_states = attn.to_out[0](hidden_states)
888
- # dropout
889
- hidden_states = attn.to_out[1](hidden_states)
890
-
891
- if input_ndim == 4:
892
- hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
893
-
894
- if attn.residual_connection:
895
- hidden_states = hidden_states + residual
896
-
897
- hidden_states = hidden_states / attn.rescale_output_factor
898
-
899
- return hidden_states
900
-
901
-
902
- class JointAttnProcessor:
903
- r"""
904
- Default processor for performing attention-related computations.
905
- """
906
-
907
- def __call__(
908
- self,
909
- attn: Attention,
910
- hidden_states,
911
- encoder_hidden_states=None,
912
- attention_mask=None,
913
- temb=None,
914
- num_tasks=2
915
- ):
916
-
917
- residual = hidden_states
918
-
919
- if attn.spatial_norm is not None:
920
- hidden_states = attn.spatial_norm(hidden_states, temb)
921
-
922
- input_ndim = hidden_states.ndim
923
-
924
- if input_ndim == 4:
925
- batch_size, channel, height, width = hidden_states.shape
926
- hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
927
-
928
- batch_size, sequence_length, _ = (
929
- hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
930
- )
931
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
932
-
933
-
934
- if attn.group_norm is not None:
935
- hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
936
-
937
- query = attn.to_q(hidden_states)
938
-
939
- if encoder_hidden_states is None:
940
- encoder_hidden_states = hidden_states
941
- elif attn.norm_cross:
942
- encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
943
-
944
- key = attn.to_k(encoder_hidden_states)
945
- value = attn.to_v(encoder_hidden_states)
946
-
947
- assert num_tasks == 2 # only support two tasks now
948
-
949
- key_0, key_1 = torch.chunk(key, dim=0, chunks=2) # keys shape (b t) d c
950
- value_0, value_1 = torch.chunk(value, dim=0, chunks=2)
951
- key = torch.cat([key_0, key_1], dim=1) # (b t) 2d c
952
- value = torch.cat([value_0, value_1], dim=1) # (b t) 2d c
953
- key = torch.cat([key]*2, dim=0) # ( 2 b t) 2d c
954
- value = torch.cat([value]*2, dim=0) # (2 b t) 2d c
955
-
956
-
957
- query = attn.head_to_batch_dim(query).contiguous()
958
- key = attn.head_to_batch_dim(key).contiguous()
959
- value = attn.head_to_batch_dim(value).contiguous()
960
-
961
- attention_probs = attn.get_attention_scores(query, key, attention_mask)
962
- hidden_states = torch.bmm(attention_probs, value)
963
- hidden_states = attn.batch_to_head_dim(hidden_states)
964
-
965
- # linear proj
966
- hidden_states = attn.to_out[0](hidden_states)
967
- # dropout
968
- hidden_states = attn.to_out[1](hidden_states)
969
-
970
- if input_ndim == 4:
971
- hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
972
-
973
- if attn.residual_connection:
974
- hidden_states = hidden_states + residual
975
-
976
- hidden_states = hidden_states / attn.rescale_output_factor
977
-
978
- return hidden_states
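
A few standalone notes on the file deleted above follow; they are illustrative sketches, not code from the repository, and every tensor size in them is a made-up assumption. First, the 2-D mask to additive-bias conversion described in the forward() comments (1 = keep becomes ~0.0, 0 = discard becomes -10000.0, with a singleton query dimension added for broadcasting):

import torch

mask = torch.tensor([[1, 1, 0, 0],
                     [1, 0, 0, 0]])                 # (batch, key_tokens), 1 = keep, 0 = discard
bias = (1 - mask.to(torch.float32)) * -10000.0      # ~0.0 where kept, -10000.0 where discarded
bias = bias.unsqueeze(1)                            # (batch, 1, key_tokens): broadcasts over query tokens
print(bias.shape)                                   # torch.Size([2, 1, 4])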
 
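For patched input (is_input_patches), the output path reshapes the token sequence back into an image with an einsum; a minimal sketch of that unpatchify step, under the same assumed-sizes caveat:

import torch

n, out_ch, p = 2, 4, 2                       # batch, output channels, patch size (illustrative)
h = w = 16                                   # patches per side, so the sequence length is h * w
seq = torch.randn(n, h * w, p * p * out_ch)  # what proj_out_2 produces per token

x = seq.reshape(-1, h, w, p, p, out_ch)      # n h w p q c
x = torch.einsum("nhwpqc->nchpwq", x)        # channels first, patch pixels next to their grid cell
img = x.reshape(-1, out_ch, h * p, w * p)    # n c (h*p) (w*p)
print(img.shape)                             # torch.Size([2, 4, 32, 32])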
 
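The core multi-view trick in MVAttnProcessor and XFormersMVAttnProcessor is the einops regrouping of per-view token sequences into per-row sequences, so self-attention mixes the same image row across all views. A round-trip sketch of just that rearrangement (the view count and resolution are assumptions, not values from the code):

import torch
from einops import rearrange

b, v, c = 2, 4, 320                         # batch, views, channels (illustrative)
h = w = 32                                  # square latent grid
tokens = torch.randn(b * v, h * w, c)       # per-view sequences: (b v) (h w) c

# Group by image row so each row attends across all views of that row.
rowwise = rearrange(tokens, "(b v) (h w) c -> (b h) (v w) c", v=v, h=h)
print(rowwise.shape)                        # torch.Size([64, 128, 320]) == (b*h, v*w, c)

# After attention, the inverse rearrangement restores the per-view layout.
restored = rearrange(rowwise, "(b h) (v w) c -> (b v) (h w) c", v=v, h=h)
assert restored.shape == tokens.shape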
 
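JointAttnProcessor and XFormersJointAttnProcessor implement the cross-domain ('joint task') attention: the batch stacks two task domains along dim 0, and their keys/values are concatenated along the token axis and tiled over the batch, so queries from either domain attend to tokens of both. A sketch of just that tensor bookkeeping, again with assumed sizes:

import torch

per_domain, d, c = 4, 1024, 320             # (batch * views) per domain, tokens, channels (illustrative)
key = torch.randn(2 * per_domain, d, c)     # [domain_0; domain_1] stacked along dim 0
value = torch.randn(2 * per_domain, d, c)

key_0, key_1 = torch.chunk(key, chunks=2, dim=0)         # (per_domain, d, c) each
value_0, value_1 = torch.chunk(value, chunks=2, dim=0)

key = torch.cat([key_0, key_1], dim=1)                   # (per_domain, 2*d, c): both domains' tokens
value = torch.cat([value_0, value_1], dim=1)
key = torch.cat([key] * 2, dim=0)                        # (2*per_domain, 2*d, c): matches stacked queries
value = torch.cat([value] * 2, dim=0)
print(key.shape)                                         # torch.Size([8, 2048, 320])
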
mvdiffusion/models/transformer_mv2d_self_rowwise.py DELETED
@@ -1,1038 +0,0 @@
1
- # Copyright 2023 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- from dataclasses import dataclass
15
- from typing import Any, Dict, Optional
16
-
17
- import torch
18
- import torch.nn.functional as F
19
- from torch import nn
20
-
21
- from diffusers.configuration_utils import ConfigMixin, register_to_config
22
- from diffusers.models.embeddings import ImagePositionalEmbeddings
23
- from diffusers.utils import BaseOutput, deprecate
24
- from diffusers.utils.torch_utils import maybe_allow_in_graph
25
- from diffusers.models.attention import FeedForward, AdaLayerNorm, AdaLayerNormZero, Attention
26
- from diffusers.models.embeddings import PatchEmbed
27
- from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
28
- from diffusers.models.modeling_utils import ModelMixin
29
- from diffusers.utils.import_utils import is_xformers_available
30
-
31
- from einops import rearrange
32
- import pdb
33
- import random
34
- import math
35
-
36
-
37
- if is_xformers_available():
38
- import xformers
39
- import xformers.ops
40
- else:
41
- xformers = None
42
-
43
-
44
- @dataclass
45
- class TransformerMV2DModelOutput(BaseOutput):
46
- """
47
- The output of [`Transformer2DModel`].
48
-
49
- Args:
50
- sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
51
- The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
52
- distributions for the unnoised latent pixels.
53
- """
54
-
55
- sample: torch.FloatTensor
56
-
57
-
58
- class TransformerMV2DModel(ModelMixin, ConfigMixin):
59
- """
60
- A 2D Transformer model for image-like data.
61
-
62
- Parameters:
63
- num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
64
- attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
65
- in_channels (`int`, *optional*):
66
- The number of channels in the input and output (specify if the input is **continuous**).
67
- num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
68
- dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
69
- cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
70
- sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
71
- This is fixed during training since it is used to learn a number of position embeddings.
72
- num_vector_embeds (`int`, *optional*):
73
- The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
74
- Includes the class for the masked latent pixel.
75
- activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
76
- num_embeds_ada_norm ( `int`, *optional*):
77
- The number of diffusion steps used during training. Pass if at least one of the norm_layers is
78
- `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
79
- added to the hidden states.
80
-
81
- During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`.
82
- attention_bias (`bool`, *optional*):
83
- Configure if the `TransformerBlocks` attention should contain a bias parameter.
84
- """
85
-
86
- @register_to_config
87
- def __init__(
88
- self,
89
- num_attention_heads: int = 16,
90
- attention_head_dim: int = 88,
91
- in_channels: Optional[int] = None,
92
- out_channels: Optional[int] = None,
93
- num_layers: int = 1,
94
- dropout: float = 0.0,
95
- norm_num_groups: int = 32,
96
- cross_attention_dim: Optional[int] = None,
97
- attention_bias: bool = False,
98
- sample_size: Optional[int] = None,
99
- num_vector_embeds: Optional[int] = None,
100
- patch_size: Optional[int] = None,
101
- activation_fn: str = "geglu",
102
- num_embeds_ada_norm: Optional[int] = None,
103
- use_linear_projection: bool = False,
104
- only_cross_attention: bool = False,
105
- upcast_attention: bool = False,
106
- norm_type: str = "layer_norm",
107
- norm_elementwise_affine: bool = True,
108
- num_views: int = 1,
109
- cd_attention_mid: bool=False,
110
- cd_attention_last: bool=False,
111
- multiview_attention: bool=True,
112
- sparse_mv_attention: bool = True, # not used
113
- mvcd_attention: bool=False,
114
- use_dino: bool=False
115
- ):
116
- super().__init__()
117
- self.use_linear_projection = use_linear_projection
118
- self.num_attention_heads = num_attention_heads
119
- self.attention_head_dim = attention_head_dim
120
- inner_dim = num_attention_heads * attention_head_dim
121
-
122
- # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
123
- # Define whether input is continuous or discrete depending on configuration
124
- self.is_input_continuous = (in_channels is not None) and (patch_size is None)
125
- self.is_input_vectorized = num_vector_embeds is not None
126
- self.is_input_patches = in_channels is not None and patch_size is not None
127
-
128
- if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
129
- deprecation_message = (
130
- f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
131
- " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config."
132
- " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect"
133
- " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
134
- " would be very nice if you could open a Pull request for the `transformer/config.json` file"
135
- )
136
- deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False)
137
- norm_type = "ada_norm"
138
-
139
- if self.is_input_continuous and self.is_input_vectorized:
140
- raise ValueError(
141
- f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
142
- " sure that either `in_channels` or `num_vector_embeds` is None."
143
- )
144
- elif self.is_input_vectorized and self.is_input_patches:
145
- raise ValueError(
146
- f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
147
- " sure that either `num_vector_embeds` or `num_patches` is None."
148
- )
149
- elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches:
150
- raise ValueError(
151
- f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:"
152
- f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
153
- )
154
-
155
- # 2. Define input layers
156
- if self.is_input_continuous:
157
- self.in_channels = in_channels
158
-
159
- self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
160
- if use_linear_projection:
161
- self.proj_in = LoRACompatibleLinear(in_channels, inner_dim)
162
- else:
163
- self.proj_in = LoRACompatibleConv(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
164
- elif self.is_input_vectorized:
165
- assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
166
- assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
167
-
168
- self.height = sample_size
169
- self.width = sample_size
170
- self.num_vector_embeds = num_vector_embeds
171
- self.num_latent_pixels = self.height * self.width
172
-
173
- self.latent_image_embedding = ImagePositionalEmbeddings(
174
- num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
175
- )
176
- elif self.is_input_patches:
177
- assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size"
178
-
179
- self.height = sample_size
180
- self.width = sample_size
181
-
182
- self.patch_size = patch_size
183
- self.pos_embed = PatchEmbed(
184
- height=sample_size,
185
- width=sample_size,
186
- patch_size=patch_size,
187
- in_channels=in_channels,
188
- embed_dim=inner_dim,
189
- )
190
-
191
- # 3. Define transformers blocks
192
- self.transformer_blocks = nn.ModuleList(
193
- [
194
- BasicMVTransformerBlock(
195
- inner_dim,
196
- num_attention_heads,
197
- attention_head_dim,
198
- dropout=dropout,
199
- cross_attention_dim=cross_attention_dim,
200
- activation_fn=activation_fn,
201
- num_embeds_ada_norm=num_embeds_ada_norm,
202
- attention_bias=attention_bias,
203
- only_cross_attention=only_cross_attention,
204
- upcast_attention=upcast_attention,
205
- norm_type=norm_type,
206
- norm_elementwise_affine=norm_elementwise_affine,
207
- num_views=num_views,
208
- cd_attention_last=cd_attention_last,
209
- cd_attention_mid=cd_attention_mid,
210
- multiview_attention=multiview_attention,
211
- mvcd_attention=mvcd_attention,
212
- use_dino=use_dino
213
- )
214
- for d in range(num_layers)
215
- ]
216
- )
217
-
218
- # 4. Define output layers
219
- self.out_channels = in_channels if out_channels is None else out_channels
220
- if self.is_input_continuous:
221
- # TODO: should use out_channels for continuous projections
222
- if use_linear_projection:
223
- self.proj_out = LoRACompatibleLinear(inner_dim, in_channels)
224
- else:
225
- self.proj_out = LoRACompatibleConv(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
226
- elif self.is_input_vectorized:
227
- self.norm_out = nn.LayerNorm(inner_dim)
228
- self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
229
- elif self.is_input_patches:
230
- self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
231
- self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
232
- self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
233
-
234
- def forward(
235
- self,
236
- hidden_states: torch.Tensor,
237
- encoder_hidden_states: Optional[torch.Tensor] = None,
238
- dino_feature: Optional[torch.Tensor] = None,
239
- timestep: Optional[torch.LongTensor] = None,
240
- class_labels: Optional[torch.LongTensor] = None,
241
- cross_attention_kwargs: Dict[str, Any] = None,
242
- attention_mask: Optional[torch.Tensor] = None,
243
- encoder_attention_mask: Optional[torch.Tensor] = None,
244
- return_dict: bool = True,
245
- ):
246
- """
247
- The [`Transformer2DModel`] forward method.
248
-
249
- Args:
250
- hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
251
- Input `hidden_states`.
252
- encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
253
- Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
254
- self-attention.
255
- timestep ( `torch.LongTensor`, *optional*):
256
- Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
257
- class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
258
- Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
259
- `AdaLayerZeroNorm`.
260
- encoder_attention_mask ( `torch.Tensor`, *optional*):
261
- Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
262
-
263
- * Mask `(batch, sequence_length)` True = keep, False = discard.
264
- * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
265
-
266
- If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
267
- above. This bias will be added to the cross-attention scores.
268
- return_dict (`bool`, *optional*, defaults to `True`):
269
- Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
270
- tuple.
271
-
272
- Returns:
273
- If `return_dict` is True, a [`TransformerMV2DModelOutput`] is returned, otherwise a
274
- `tuple` where the first element is the sample tensor.
275
- """
276
- # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
277
- # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
278
- # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
279
- # expects mask of shape:
280
- # [batch, key_tokens]
281
- # adds singleton query_tokens dimension:
282
- # [batch, 1, key_tokens]
283
- # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
284
- # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
285
- # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
286
- if attention_mask is not None and attention_mask.ndim == 2:
287
- # assume that mask is expressed as:
288
- # (1 = keep, 0 = discard)
289
- # convert mask into a bias that can be added to attention scores:
290
- # (keep = +0, discard = -10000.0)
291
- attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
292
- attention_mask = attention_mask.unsqueeze(1)
293
-
294
- # convert encoder_attention_mask to a bias the same way we do for attention_mask
295
- if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
296
- encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
297
- encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
298
-
299
- # 1. Input
300
- if self.is_input_continuous:
301
- batch, _, height, width = hidden_states.shape
302
- residual = hidden_states
303
-
304
- hidden_states = self.norm(hidden_states)
305
- if not self.use_linear_projection:
306
- hidden_states = self.proj_in(hidden_states)
307
- inner_dim = hidden_states.shape[1]
308
- hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
309
- else:
310
- inner_dim = hidden_states.shape[1]
311
- hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
312
- hidden_states = self.proj_in(hidden_states)
313
- elif self.is_input_vectorized:
314
- hidden_states = self.latent_image_embedding(hidden_states)
315
- elif self.is_input_patches:
316
- hidden_states = self.pos_embed(hidden_states)
317
-
318
- # 2. Blocks
319
- for block in self.transformer_blocks:
320
- hidden_states = block(
321
- hidden_states,
322
- attention_mask=attention_mask,
323
- encoder_hidden_states=encoder_hidden_states,
324
- dino_feature=dino_feature,
325
- encoder_attention_mask=encoder_attention_mask,
326
- timestep=timestep,
327
- cross_attention_kwargs=cross_attention_kwargs,
328
- class_labels=class_labels,
329
- )
330
-
331
- # 3. Output
332
- if self.is_input_continuous:
333
- if not self.use_linear_projection:
334
- hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
335
- hidden_states = self.proj_out(hidden_states)
336
- else:
337
- hidden_states = self.proj_out(hidden_states)
338
- hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
339
-
340
- output = hidden_states + residual
341
- elif self.is_input_vectorized:
342
- hidden_states = self.norm_out(hidden_states)
343
- logits = self.out(hidden_states)
344
- # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
345
- logits = logits.permute(0, 2, 1)
346
-
347
- # log(p(x_0))
348
- output = F.log_softmax(logits.double(), dim=1).float()
349
- elif self.is_input_patches:
350
- # TODO: cleanup!
351
- conditioning = self.transformer_blocks[0].norm1.emb(
352
- timestep, class_labels, hidden_dtype=hidden_states.dtype
353
- )
354
- shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
355
- hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
356
- hidden_states = self.proj_out_2(hidden_states)
357
-
358
- # unpatchify
359
- height = width = int(hidden_states.shape[1] ** 0.5)
360
- hidden_states = hidden_states.reshape(
361
- shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
362
- )
363
- hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
364
- output = hidden_states.reshape(
365
- shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
366
- )
367
-
368
- if not return_dict:
369
- return (output,)
370
-
371
- return TransformerMV2DModelOutput(sample=output)
372
-
373
-
374
- @maybe_allow_in_graph
375
- class BasicMVTransformerBlock(nn.Module):
376
- r"""
377
- A basic Transformer block.
378
-
379
- Parameters:
380
- dim (`int`): The number of channels in the input and output.
381
- num_attention_heads (`int`): The number of heads to use for multi-head attention.
382
- attention_head_dim (`int`): The number of channels in each head.
383
- dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
384
- cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
385
- only_cross_attention (`bool`, *optional*):
386
- Whether to use only cross-attention layers. In this case two cross attention layers are used.
387
- double_self_attention (`bool`, *optional*):
388
- Whether to use two self-attention layers. In this case no cross attention layers are used.
389
- activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
390
- num_embeds_ada_norm (:
391
- obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
392
- attention_bias (:
393
- obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
394
- """
395
-
396
- def __init__(
397
- self,
398
- dim: int,
399
- num_attention_heads: int,
400
- attention_head_dim: int,
401
- dropout=0.0,
402
- cross_attention_dim: Optional[int] = None,
403
- activation_fn: str = "geglu",
404
- num_embeds_ada_norm: Optional[int] = None,
405
- attention_bias: bool = False,
406
- only_cross_attention: bool = False,
407
- double_self_attention: bool = False,
408
- upcast_attention: bool = False,
409
- norm_elementwise_affine: bool = True,
410
- norm_type: str = "layer_norm",
411
- final_dropout: bool = False,
412
- num_views: int = 1,
413
- cd_attention_last: bool = False,
414
- cd_attention_mid: bool = False,
415
- multiview_attention: bool = True,
416
- mvcd_attention: bool = False,
417
- rowwise_attention: bool = True,
418
- use_dino: bool = False
419
- ):
420
- super().__init__()
421
- self.only_cross_attention = only_cross_attention
422
-
423
- self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
424
- self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
425
-
426
- if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
427
- raise ValueError(
428
- f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
429
- f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
430
- )
431
-
432
- # Define 3 blocks. Each block has its own normalization layer.
433
- # 1. Self-Attn
434
- if self.use_ada_layer_norm:
435
- self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
436
- elif self.use_ada_layer_norm_zero:
437
- self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
438
- else:
439
- self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
440
-
441
- self.multiview_attention = multiview_attention
442
- self.mvcd_attention = mvcd_attention
443
- self.cd_attention_mid = cd_attention_mid
444
- self.rowwise_attention = multiview_attention and rowwise_attention
445
-
446
- if mvcd_attention and (not cd_attention_mid):
447
- # add cross domain attn to self attn
448
- self.attn1 = CustomJointAttention(
449
- query_dim=dim,
450
- heads=num_attention_heads,
451
- dim_head=attention_head_dim,
452
- dropout=dropout,
453
- bias=attention_bias,
454
- cross_attention_dim=cross_attention_dim if only_cross_attention else None,
455
- upcast_attention=upcast_attention,
456
- processor=JointAttnProcessor()
457
- )
458
- else:
459
- self.attn1 = Attention(
460
- query_dim=dim,
461
- heads=num_attention_heads,
462
- dim_head=attention_head_dim,
463
- dropout=dropout,
464
- bias=attention_bias,
465
- cross_attention_dim=cross_attention_dim if only_cross_attention else None,
466
- upcast_attention=upcast_attention
467
- )
468
- # 1.1 rowwise multiview attention
469
- if self.rowwise_attention:
470
- # print('INFO: using self+row_wise mv attention...')
471
- self.norm_mv = (
472
- AdaLayerNorm(dim, num_embeds_ada_norm)
473
- if self.use_ada_layer_norm
474
- else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
475
- )
476
- self.attn_mv = CustomAttention(
477
- query_dim=dim,
478
- heads=num_attention_heads,
479
- dim_head=attention_head_dim,
480
- dropout=dropout,
481
- bias=attention_bias,
482
- cross_attention_dim=cross_attention_dim if only_cross_attention else None,
483
- upcast_attention=upcast_attention,
484
- processor=MVAttnProcessor()
485
- )
486
- nn.init.zeros_(self.attn_mv.to_out[0].weight.data)
487
- else:
488
- self.norm_mv = None
489
- self.attn_mv = None
490
-
491
- # # 1.2 rowwise cross-domain attn
492
- # if mvcd_attention:
493
- # self.attn_joint = CustomJointAttention(
494
- # query_dim=dim,
495
- # heads=num_attention_heads,
496
- # dim_head=attention_head_dim,
497
- # dropout=dropout,
498
- # bias=attention_bias,
499
- # cross_attention_dim=cross_attention_dim if only_cross_attention else None,
500
- # upcast_attention=upcast_attention,
501
- # processor=JointAttnProcessor()
502
- # )
503
- # nn.init.zeros_(self.attn_joint.to_out[0].weight.data)
504
- # self.norm_joint = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
505
- # else:
506
- # self.attn_joint = None
507
- # self.norm_joint = None
508
-
509
- # 2. Cross-Attn
510
- if cross_attention_dim is not None or double_self_attention:
511
- # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
512
- # I.e. the number of returned modulation chunks from AdaLayerNormZero would not make sense if returned during
513
- # the second cross attention block.
514
- self.norm2 = (
515
- AdaLayerNorm(dim, num_embeds_ada_norm)
516
- if self.use_ada_layer_norm
517
- else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
518
- )
519
- self.attn2 = Attention(
520
- query_dim=dim,
521
- cross_attention_dim=cross_attention_dim if not double_self_attention else None,
522
- heads=num_attention_heads,
523
- dim_head=attention_head_dim,
524
- dropout=dropout,
525
- bias=attention_bias,
526
- upcast_attention=upcast_attention,
527
- ) # is self-attn if encoder_hidden_states is none
528
- else:
529
- self.norm2 = None
530
- self.attn2 = None
531
-
532
- # 3. Feed-forward
533
- self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
534
- self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
535
-
536
- # let chunk size default to None
537
- self._chunk_size = None
538
- self._chunk_dim = 0
539
-
540
- self.num_views = num_views
541
-
542
-
543
- def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
544
- # Sets chunk feed-forward
545
- self._chunk_size = chunk_size
546
- self._chunk_dim = dim
547
-
548
- def forward(
549
- self,
550
- hidden_states: torch.FloatTensor,
551
- attention_mask: Optional[torch.FloatTensor] = None,
552
- encoder_hidden_states: Optional[torch.FloatTensor] = None,
553
- encoder_attention_mask: Optional[torch.FloatTensor] = None,
554
- timestep: Optional[torch.LongTensor] = None,
555
- cross_attention_kwargs: Dict[str, Any] = None,
556
- class_labels: Optional[torch.LongTensor] = None,
557
- dino_feature: Optional[torch.FloatTensor] = None
558
- ):
559
- assert attention_mask is None # not supported yet
560
- # Notice that normalization is always applied before the real computation in the following blocks.
561
- # 1. Self-Attention
562
- if self.use_ada_layer_norm:
563
- norm_hidden_states = self.norm1(hidden_states, timestep)
564
- elif self.use_ada_layer_norm_zero:
565
- norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
566
- hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
567
- )
568
- else:
569
- norm_hidden_states = self.norm1(hidden_states)
570
-
571
- cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
572
-
573
- attn_output = self.attn1(
574
- norm_hidden_states,
575
- encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
576
- attention_mask=attention_mask,
577
- # multiview_attention=self.multiview_attention,
578
- # mvcd_attention=self.mvcd_attention,
579
- **cross_attention_kwargs,
580
- )
581
-
582
-
583
- if self.use_ada_layer_norm_zero:
584
- attn_output = gate_msa.unsqueeze(1) * attn_output
585
- hidden_states = attn_output + hidden_states
586
-
587
- # import pdb;pdb.set_trace()
588
- # 1.1 row wise multiview attention
589
- if self.rowwise_attention:
590
- norm_hidden_states = (
591
- self.norm_mv(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_mv(hidden_states)
592
- )
593
- attn_output = self.attn_mv(
594
- norm_hidden_states,
595
- encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
596
- attention_mask=attention_mask,
597
- num_views=self.num_views,
598
- multiview_attention=self.multiview_attention,
599
- cd_attention_mid=self.cd_attention_mid,
600
- **cross_attention_kwargs,
601
- )
602
- hidden_states = attn_output + hidden_states
603
-
604
-
605
- # 2. Cross-Attention
606
- if self.attn2 is not None:
607
- norm_hidden_states = (
608
- self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
609
- )
610
-
611
- attn_output = self.attn2(
612
- norm_hidden_states,
613
- encoder_hidden_states=encoder_hidden_states,
614
- attention_mask=encoder_attention_mask,
615
- **cross_attention_kwargs,
616
- )
617
- hidden_states = attn_output + hidden_states
618
-
619
- # 3. Feed-forward
620
- norm_hidden_states = self.norm3(hidden_states)
621
-
622
- if self.use_ada_layer_norm_zero:
623
- norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
624
-
625
- if self._chunk_size is not None:
626
- # "feed_forward_chunk_size" can be used to save memory
627
- if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
628
- raise ValueError(
629
- f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
630
- )
631
-
632
- num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
633
- ff_output = torch.cat(
634
- [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)],
635
- dim=self._chunk_dim,
636
- )
637
- else:
638
- ff_output = self.ff(norm_hidden_states)
639
-
640
- if self.use_ada_layer_norm_zero:
641
- ff_output = gate_mlp.unsqueeze(1) * ff_output
642
-
643
- hidden_states = ff_output + hidden_states
644
-
645
- return hidden_states
646
-
647
-
648
- class CustomAttention(Attention):
649
- def set_use_memory_efficient_attention_xformers(
650
- self, use_memory_efficient_attention_xformers: bool, *args, **kwargs
651
- ):
652
- processor = XFormersMVAttnProcessor()
653
- self.set_processor(processor)
654
- # print("using xformers attention processor")
655
-
656
-
657
- class CustomJointAttention(Attention):
658
- def set_use_memory_efficient_attention_xformers(
659
- self, use_memory_efficient_attention_xformers: bool, *args, **kwargs
660
- ):
661
- processor = XFormersJointAttnProcessor()
662
- self.set_processor(processor)
663
- # print("using xformers attention processor")
664
-
665
- class MVAttnProcessor:
666
- r"""
667
- Attention processor that performs row-wise multi-view attention with plain PyTorch ops.
668
- """
669
-
670
- def __call__(
671
- self,
672
- attn: Attention,
673
- hidden_states,
674
- encoder_hidden_states=None,
675
- attention_mask=None,
676
- temb=None,
677
- num_views=1,
678
- cd_attention_mid=False
679
- ):
680
- residual = hidden_states
681
-
682
- if attn.spatial_norm is not None:
683
- hidden_states = attn.spatial_norm(hidden_states, temb)
684
-
685
- input_ndim = hidden_states.ndim
686
-
687
- if input_ndim == 4:
688
- batch_size, channel, height, width = hidden_states.shape
689
- hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
690
-
691
- batch_size, sequence_length, _ = (
692
- hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
693
- )
694
- height = int(math.sqrt(sequence_length))
695
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
696
-
697
- if attn.group_norm is not None:
698
- hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
699
-
700
- query = attn.to_q(hidden_states)
701
-
702
- if encoder_hidden_states is None:
703
- encoder_hidden_states = hidden_states
704
- elif attn.norm_cross:
705
- encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
706
-
707
- key = attn.to_k(encoder_hidden_states)
708
- value = attn.to_v(encoder_hidden_states)
709
-
710
- # print('query', query.shape, 'key', key.shape, 'value', value.shape)
711
- #([bx4, 1024, 320]) key torch.Size([bx4, 1024, 320]) value torch.Size([bx4, 1024, 320])
712
- # pdb.set_trace()
713
- # multi-view self-attention
714
- def transpose(tensor):
715
- tensor = rearrange(tensor, "(b v) (h w) c -> b v h w c", v=num_views, h=height)
716
- tensor_0, tensor_1 = torch.chunk(tensor, dim=0, chunks=2) # b v h w c
717
- tensor = torch.cat([tensor_0, tensor_1], dim=3) # b v h 2w c
718
- tensor = rearrange(tensor, "b v h w c -> (b h) (v w) c", v=num_views, h=height)
719
- return tensor
720
-
721
- if cd_attention_mid:
722
- key = transpose(key)
723
- value = transpose(value)
724
- query = transpose(query)
725
- else:
726
- key = rearrange(key, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height)
727
- value = rearrange(value, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height)
728
- query = rearrange(query, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height) # torch.Size([192, 384, 320])
729
-
730
- query = attn.head_to_batch_dim(query).contiguous()
731
- key = attn.head_to_batch_dim(key).contiguous()
732
- value = attn.head_to_batch_dim(value).contiguous()
733
-
734
- attention_probs = attn.get_attention_scores(query, key, attention_mask)
735
- hidden_states = torch.bmm(attention_probs, value)
736
- hidden_states = attn.batch_to_head_dim(hidden_states)
737
-
738
- # linear proj
739
- hidden_states = attn.to_out[0](hidden_states)
740
- # dropout
741
- hidden_states = attn.to_out[1](hidden_states)
742
- if cd_attention_mid:
743
- hidden_states = rearrange(hidden_states, "(b h) (v w) c -> b v h w c", v=num_views, h=height)
744
- hidden_states_0, hidden_states_1 = torch.chunk(hidden_states, dim=3, chunks=2) # b v h w c
745
- hidden_states = torch.cat([hidden_states_0, hidden_states_1], dim=0) # 2b v h w c
746
- hidden_states = rearrange(hidden_states, "b v h w c -> (b v) (h w) c", v=num_views, h=height)
747
- else:
748
- hidden_states = rearrange(hidden_states, "(b h) (v w) c -> (b v) (h w) c", v=num_views, h=height)
749
- if input_ndim == 4:
750
- hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
751
-
752
- if attn.residual_connection:
753
- hidden_states = hidden_states + residual
754
-
755
- hidden_states = hidden_states / attn.rescale_output_factor
756
-
757
- return hidden_states
758
-
759
-
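The processor above implements row-wise multi-view attention: token sequences of shape (batch*views, height*width, channels) are regrouped so that each attention call covers one image row across all views, with height recovered as int(math.sqrt(sequence_length)) (square feature maps and a module-level `import math` are assumed). A minimal, self-contained sketch of just that reshape, using toy sizes that are assumptions rather than values from any model config:

import torch
from einops import rearrange

b, v, h, w, c = 2, 6, 32, 32, 320              # toy sizes (assumed)
tokens = torch.randn(b * v, h * w, c)          # layout produced by to_q / to_k / to_v

# group one image row across all views into a single attention sequence
rowwise = rearrange(tokens, "(b v) (h w) c -> (b h) (v w) c", v=v, h=h)
assert rowwise.shape == (b * h, v * w, c)

# the inverse rearrangement applied after the attention call
restored = rearrange(rowwise, "(b h) (v w) c -> (b v) (h w) c", v=v, h=h)
assert torch.equal(restored, tokens)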
760
- class XFormersMVAttnProcessor:
761
- r"""
762
- Attention processor that performs row-wise multi-view attention with xformers memory-efficient attention.
763
- """
764
-
765
- def __call__(
766
- self,
767
- attn: Attention,
768
- hidden_states,
769
- encoder_hidden_states=None,
770
- attention_mask=None,
771
- temb=None,
772
- num_views=1,
773
- multiview_attention=True,
774
- cd_attention_mid=False
775
- ):
776
- # print(num_views)
777
- residual = hidden_states
778
-
779
- if attn.spatial_norm is not None:
780
- hidden_states = attn.spatial_norm(hidden_states, temb)
781
-
782
- input_ndim = hidden_states.ndim
783
-
784
- if input_ndim == 4:
785
- batch_size, channel, height, width = hidden_states.shape
786
- hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
787
-
788
- batch_size, sequence_length, _ = (
789
- hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
790
- )
791
- height = int(math.sqrt(sequence_length))
792
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
793
- # note (from yuancheng): attention_mask is expected to be None here
794
- if attention_mask is not None:
795
- # expand our mask's singleton query_tokens dimension:
796
- # [batch*heads, 1, key_tokens] ->
797
- # [batch*heads, query_tokens, key_tokens]
798
- # so that it can be added as a bias onto the attention scores that xformers computes:
799
- # [batch*heads, query_tokens, key_tokens]
800
- # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
801
- _, query_tokens, _ = hidden_states.shape
802
- attention_mask = attention_mask.expand(-1, query_tokens, -1)
803
-
804
- if attn.group_norm is not None:
805
- print('Warning: group norm is enabled; verify that it interacts correctly with row-wise attention')
806
- hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
807
-
808
- query = attn.to_q(hidden_states)
809
-
810
- if encoder_hidden_states is None:
811
- encoder_hidden_states = hidden_states
812
- elif attn.norm_cross:
813
- encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
814
-
815
- key_raw = attn.to_k(encoder_hidden_states)
816
- value_raw = attn.to_v(encoder_hidden_states)
817
-
818
- # print('query', query.shape, 'key', key.shape, 'value', value.shape)
819
- # pdb.set_trace()
820
- def transpose(tensor):
821
- tensor = rearrange(tensor, "(b v) (h w) c -> b v h w c", v=num_views, h=height)
822
- tensor_0, tensor_1 = torch.chunk(tensor, dim=0, chunks=2) # b v h w c
823
- tensor = torch.cat([tensor_0, tensor_1], dim=3) # b v h 2w c
824
- tensor = rearrange(tensor, "b v h w c -> (b h) (v w) c", v=num_views, h=height)
825
- return tensor
826
- # print(mvcd_attention)
827
- # import pdb;pdb.set_trace()
828
- if cd_attention_mid:
829
- key = transpose(key_raw)
830
- value = transpose(value_raw)
831
- query = transpose(query)
832
- else:
833
- key = rearrange(key_raw, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height)
834
- value = rearrange(value_raw, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height)
835
- query = rearrange(query, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height) # torch.Size([192, 384, 320])
836
-
837
-
838
- query = attn.head_to_batch_dim(query) # torch.Size([960, 384, 64])
839
- key = attn.head_to_batch_dim(key)
840
- value = attn.head_to_batch_dim(value)
841
-
842
- hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
843
- hidden_states = attn.batch_to_head_dim(hidden_states)
844
-
845
- # linear proj
846
- hidden_states = attn.to_out[0](hidden_states)
847
- # dropout
848
- hidden_states = attn.to_out[1](hidden_states)
849
-
850
- if cd_attention_mid:
851
- hidden_states = rearrange(hidden_states, "(b h) (v w) c -> b v h w c", v=num_views, h=height)
852
- hidden_states_0, hidden_states_1 = torch.chunk(hidden_states, dim=3, chunks=2) # b v h w c
853
- hidden_states = torch.cat([hidden_states_0, hidden_states_1], dim=0) # 2b v h w c
854
- hidden_states = rearrange(hidden_states, "b v h w c -> (b v) (h w) c", v=num_views, h=height)
855
- else:
856
- hidden_states = rearrange(hidden_states, "(b h) (v w) c -> (b v) (h w) c", v=num_views, h=height)
857
- if input_ndim == 4:
858
- hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
859
-
860
- if attn.residual_connection:
861
- hidden_states = hidden_states + residual
862
-
863
- hidden_states = hidden_states / attn.rescale_output_factor
864
-
865
- return hidden_states
866
-
867
-
868
- class XFormersJointAttnProcessor:
869
- r"""
870
- Attention processor for cross-domain joint attention (e.g. normal + color) using xformers memory-efficient attention.
871
- """
872
-
873
- def __call__(
874
- self,
875
- attn: Attention,
876
- hidden_states,
877
- encoder_hidden_states=None,
878
- attention_mask=None,
879
- temb=None,
880
- num_tasks=2
881
- ):
882
- residual = hidden_states
883
-
884
- if attn.spatial_norm is not None:
885
- hidden_states = attn.spatial_norm(hidden_states, temb)
886
-
887
- input_ndim = hidden_states.ndim
888
-
889
- if input_ndim == 4:
890
- batch_size, channel, height, width = hidden_states.shape
891
- hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
892
-
893
- batch_size, sequence_length, _ = (
894
- hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
895
- )
896
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
897
-
898
- # note (from yuancheng): attention_mask is expected to be None here
899
- if attention_mask is not None:
900
- # expand our mask's singleton query_tokens dimension:
901
- # [batch*heads, 1, key_tokens] ->
902
- # [batch*heads, query_tokens, key_tokens]
903
- # so that it can be added as a bias onto the attention scores that xformers computes:
904
- # [batch*heads, query_tokens, key_tokens]
905
- # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
906
- _, query_tokens, _ = hidden_states.shape
907
- attention_mask = attention_mask.expand(-1, query_tokens, -1)
908
-
909
- if attn.group_norm is not None:
910
- hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
911
-
912
- query = attn.to_q(hidden_states)
913
-
914
- if encoder_hidden_states is None:
915
- encoder_hidden_states = hidden_states
916
- elif attn.norm_cross:
917
- encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
918
-
919
- key = attn.to_k(encoder_hidden_states)
920
- value = attn.to_v(encoder_hidden_states)
921
-
922
- assert num_tasks == 2 # only support two tasks now
923
-
924
- def transpose(tensor):
925
- tensor_0, tensor_1 = torch.chunk(tensor, dim=0, chunks=2) # bv hw c
926
- tensor = torch.cat([tensor_0, tensor_1], dim=1) # bv 2hw c
927
- return tensor
928
- key = transpose(key)
929
- value = transpose(value)
930
- query = transpose(query)
931
- # from icecream import ic
932
- # ic(key.shape, value.shape, query.shape)
933
- # import pdb;pdb.set_trace()
934
- query = attn.head_to_batch_dim(query).contiguous()
935
- key = attn.head_to_batch_dim(key).contiguous()
936
- value = attn.head_to_batch_dim(value).contiguous()
937
-
938
- hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
939
- hidden_states = attn.batch_to_head_dim(hidden_states)
940
-
941
- # linear proj
942
- hidden_states = attn.to_out[0](hidden_states)
943
- # dropout
944
- hidden_states = attn.to_out[1](hidden_states)
945
- hidden_states_normal, hidden_states_color = torch.chunk(hidden_states, dim=1, chunks=2)
946
- hidden_states = torch.cat([hidden_states_normal, hidden_states_color], dim=0) # 2bv hw c
947
-
948
- if input_ndim == 4:
949
- hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
950
-
951
- if attn.residual_connection:
952
- hidden_states = hidden_states + residual
953
-
954
- hidden_states = hidden_states / attn.rescale_output_factor
955
-
956
- return hidden_states
957
-
958
-
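XFormersJointAttnProcessor above (and the plain JointAttnProcessor below) realize cross-domain attention by keeping the two domains stacked along the batch dimension and concatenating their tokens along the sequence dimension before attending, then splitting again afterwards. A self-contained sketch of just that reshuffling, with assumed toy sizes:

import torch

bv, hw, c = 4, 1024, 320                      # (batch*views, tokens, channels); toy sizes
tokens = torch.randn(2 * bv, hw, c)           # first half: domain 0 (e.g. normals), second half: domain 1 (colors)

t0, t1 = torch.chunk(tokens, chunks=2, dim=0)     # each (bv, hw, c)
joint = torch.cat([t0, t1], dim=1)                # (bv, 2*hw, c); attention runs over this

# ... memory_efficient_attention / bmm attention over `joint` goes here ...

out0, out1 = torch.chunk(joint, chunks=2, dim=1)  # split the sequence again
restored = torch.cat([out0, out1], dim=0)         # back to (2*bv, hw, c)
assert torch.equal(restored, tokens)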
959
- class JointAttnProcessor:
960
- r"""
961
- Attention processor for cross-domain joint attention (e.g. normal + color) with plain PyTorch ops.
962
- """
963
-
964
- def __call__(
965
- self,
966
- attn: Attention,
967
- hidden_states,
968
- encoder_hidden_states=None,
969
- attention_mask=None,
970
- temb=None,
971
- num_tasks=2
972
- ):
973
-
974
- residual = hidden_states
975
-
976
- if attn.spatial_norm is not None:
977
- hidden_states = attn.spatial_norm(hidden_states, temb)
978
-
979
- input_ndim = hidden_states.ndim
980
-
981
- if input_ndim == 4:
982
- batch_size, channel, height, width = hidden_states.shape
983
- hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
984
-
985
- batch_size, sequence_length, _ = (
986
- hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
987
- )
988
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
989
-
990
-
991
- if attn.group_norm is not None:
992
- hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
993
-
994
- query = attn.to_q(hidden_states)
995
-
996
- if encoder_hidden_states is None:
997
- encoder_hidden_states = hidden_states
998
- elif attn.norm_cross:
999
- encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
1000
-
1001
- key = attn.to_k(encoder_hidden_states)
1002
- value = attn.to_v(encoder_hidden_states)
1003
-
1004
- assert num_tasks == 2 # only support two tasks now
1005
-
1006
- def transpose(tensor):
1007
- tensor_0, tensor_1 = torch.chunk(tensor, dim=0, chunks=2) # bv hw c
1008
- tensor = torch.cat([tensor_0, tensor_1], dim=1) # bv 2hw c
1009
- return tensor
1010
- key = transpose(key)
1011
- value = transpose(value)
1012
- query = transpose(query)
1013
-
1014
-
1015
- query = attn.head_to_batch_dim(query).contiguous()
1016
- key = attn.head_to_batch_dim(key).contiguous()
1017
- value = attn.head_to_batch_dim(value).contiguous()
1018
-
1019
- attention_probs = attn.get_attention_scores(query, key, attention_mask)
1020
- hidden_states = torch.bmm(attention_probs, value)
1021
- hidden_states = attn.batch_to_head_dim(hidden_states)
1022
-
1023
-
1024
- # linear proj
1025
- hidden_states = attn.to_out[0](hidden_states)
1026
- # dropout
1027
- hidden_states = attn.to_out[1](hidden_states)
1028
-
1029
- hidden_states_normal, hidden_states_color = torch.chunk(hidden_states, dim=1, chunks=2)
- hidden_states = torch.cat([hidden_states_normal, hidden_states_color], dim=0) # 2bv hw c
1030
- if input_ndim == 4:
1031
- hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
1032
-
1033
- if attn.residual_connection:
1034
- hidden_states = hidden_states + residual
1035
-
1036
- hidden_states = hidden_states / attn.rescale_output_factor
1037
-
1038
- return hidden_states
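Both plain processors end with attn.get_attention_scores followed by torch.bmm, i.e. ordinary scaled dot-product attention; the xformers variants compute the same quantity with memory_efficient_attention. A quick self-contained check (toy tensors, no dependence on the classes above, torch >= 2.0 assumed for the fused kernel) that the explicit softmax/bmm form matches torch's fused implementation:

import torch
import torch.nn.functional as F

bh, q_len, kv_len, d = 4, 64, 128, 40         # (batch*heads, query tokens, key tokens, head dim); toy sizes
q = torch.randn(bh, q_len, d)
k = torch.randn(bh, kv_len, d)
v = torch.randn(bh, kv_len, d)

probs = torch.softmax(q @ k.transpose(-1, -2) / d**0.5, dim=-1)   # attention scores
out_manual = torch.bmm(probs, v)                                   # what the plain processors compute
out_fused = F.scaled_dot_product_attention(q, k, v)                # fused equivalent
assert torch.allclose(out_manual, out_fused, atol=1e-5)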
 
mvdiffusion/models/unet_mv2d_blocks.py DELETED
@@ -1,971 +0,0 @@
1
- # Copyright 2023 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- from typing import Any, Dict, Optional, Tuple
15
-
16
- import numpy as np
17
- import torch
18
- import torch.nn.functional as F
19
- from torch import nn
20
-
21
- from diffusers.utils import is_torch_version, logging
22
- from diffusers.models.normalization import AdaGroupNorm
23
- from diffusers.models.attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0
24
- from diffusers.models.dual_transformer_2d import DualTransformer2DModel
25
- from diffusers.models.resnet import Downsample2D, FirDownsample2D, FirUpsample2D, KDownsample2D, KUpsample2D, ResnetBlock2D, Upsample2D
26
-
27
- from diffusers.models.unets.unet_2d_blocks import DownBlock2D, ResnetDownsampleBlock2D, AttnDownBlock2D, CrossAttnDownBlock2D, SimpleCrossAttnDownBlock2D, SkipDownBlock2D, AttnSkipDownBlock2D, DownEncoderBlock2D, AttnDownEncoderBlock2D, KDownBlock2D, KCrossAttnDownBlock2D
28
- from diffusers.models.unets.unet_2d_blocks import UpBlock2D, ResnetUpsampleBlock2D, CrossAttnUpBlock2D, SimpleCrossAttnUpBlock2D, AttnUpBlock2D, SkipUpBlock2D, AttnSkipUpBlock2D, UpDecoderBlock2D, AttnUpDecoderBlock2D, KUpBlock2D, KCrossAttnUpBlock2D
29
-
30
-
31
- logger = logging.get_logger(__name__) # pylint: disable=invalid-name
32
-
33
-
34
- def get_down_block(
35
- down_block_type,
36
- num_layers,
37
- in_channels,
38
- out_channels,
39
- temb_channels,
40
- add_downsample,
41
- resnet_eps,
42
- resnet_act_fn,
43
- transformer_layers_per_block=1,
44
- num_attention_heads=None,
45
- resnet_groups=None,
46
- cross_attention_dim=None,
47
- downsample_padding=None,
48
- dual_cross_attention=False,
49
- use_linear_projection=False,
50
- only_cross_attention=False,
51
- upcast_attention=False,
52
- resnet_time_scale_shift="default",
53
- resnet_skip_time_act=False,
54
- resnet_out_scale_factor=1.0,
55
- cross_attention_norm=None,
56
- attention_head_dim=None,
57
- downsample_type=None,
58
- num_views=1,
59
- cd_attention_last: bool = False,
60
- cd_attention_mid: bool = False,
61
- multiview_attention: bool = True,
62
- sparse_mv_attention: bool = False,
63
- selfattn_block: str = "custom",
64
- mvcd_attention: bool=False,
65
- use_dino: bool = False
66
- ):
67
- # If attn head dim is not defined, we default it to the number of heads
68
- if attention_head_dim is None:
69
- logger.warning(
70
- f"It is recommended to provide `attention_head_dim` when calling `get_down_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
71
- )
72
- attention_head_dim = num_attention_heads
73
-
74
- down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
75
- if down_block_type == "DownBlock2D":
76
- return DownBlock2D(
77
- num_layers=num_layers,
78
- in_channels=in_channels,
79
- out_channels=out_channels,
80
- temb_channels=temb_channels,
81
- add_downsample=add_downsample,
82
- resnet_eps=resnet_eps,
83
- resnet_act_fn=resnet_act_fn,
84
- resnet_groups=resnet_groups,
85
- downsample_padding=downsample_padding,
86
- resnet_time_scale_shift=resnet_time_scale_shift,
87
- )
88
- elif down_block_type == "ResnetDownsampleBlock2D":
89
- return ResnetDownsampleBlock2D(
90
- num_layers=num_layers,
91
- in_channels=in_channels,
92
- out_channels=out_channels,
93
- temb_channels=temb_channels,
94
- add_downsample=add_downsample,
95
- resnet_eps=resnet_eps,
96
- resnet_act_fn=resnet_act_fn,
97
- resnet_groups=resnet_groups,
98
- resnet_time_scale_shift=resnet_time_scale_shift,
99
- skip_time_act=resnet_skip_time_act,
100
- output_scale_factor=resnet_out_scale_factor,
101
- )
102
- elif down_block_type == "AttnDownBlock2D":
103
- if add_downsample is False:
104
- downsample_type = None
105
- else:
106
- downsample_type = downsample_type or "conv" # default to 'conv'
107
- return AttnDownBlock2D(
108
- num_layers=num_layers,
109
- in_channels=in_channels,
110
- out_channels=out_channels,
111
- temb_channels=temb_channels,
112
- resnet_eps=resnet_eps,
113
- resnet_act_fn=resnet_act_fn,
114
- resnet_groups=resnet_groups,
115
- downsample_padding=downsample_padding,
116
- attention_head_dim=attention_head_dim,
117
- resnet_time_scale_shift=resnet_time_scale_shift,
118
- downsample_type=downsample_type,
119
- )
120
- elif down_block_type == "CrossAttnDownBlock2D":
121
- if cross_attention_dim is None:
122
- raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock2D")
123
- return CrossAttnDownBlock2D(
124
- num_layers=num_layers,
125
- transformer_layers_per_block=transformer_layers_per_block,
126
- in_channels=in_channels,
127
- out_channels=out_channels,
128
- temb_channels=temb_channels,
129
- add_downsample=add_downsample,
130
- resnet_eps=resnet_eps,
131
- resnet_act_fn=resnet_act_fn,
132
- resnet_groups=resnet_groups,
133
- downsample_padding=downsample_padding,
134
- cross_attention_dim=cross_attention_dim,
135
- num_attention_heads=num_attention_heads,
136
- dual_cross_attention=dual_cross_attention,
137
- use_linear_projection=use_linear_projection,
138
- only_cross_attention=only_cross_attention,
139
- upcast_attention=upcast_attention,
140
- resnet_time_scale_shift=resnet_time_scale_shift,
141
- )
142
- # custom MV2D attention block
143
- elif down_block_type == "CrossAttnDownBlockMV2D":
144
- if cross_attention_dim is None:
145
- raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlockMV2D")
146
- return CrossAttnDownBlockMV2D(
147
- num_layers=num_layers,
148
- transformer_layers_per_block=transformer_layers_per_block,
149
- in_channels=in_channels,
150
- out_channels=out_channels,
151
- temb_channels=temb_channels,
152
- add_downsample=add_downsample,
153
- resnet_eps=resnet_eps,
154
- resnet_act_fn=resnet_act_fn,
155
- resnet_groups=resnet_groups,
156
- downsample_padding=downsample_padding,
157
- cross_attention_dim=cross_attention_dim,
158
- num_attention_heads=num_attention_heads,
159
- dual_cross_attention=dual_cross_attention,
160
- use_linear_projection=use_linear_projection,
161
- only_cross_attention=only_cross_attention,
162
- upcast_attention=upcast_attention,
163
- resnet_time_scale_shift=resnet_time_scale_shift,
164
- num_views=num_views,
165
- cd_attention_last=cd_attention_last,
166
- cd_attention_mid=cd_attention_mid,
167
- multiview_attention=multiview_attention,
168
- sparse_mv_attention=sparse_mv_attention,
169
- selfattn_block=selfattn_block,
170
- mvcd_attention=mvcd_attention,
171
- use_dino=use_dino
172
- )
173
- elif down_block_type == "SimpleCrossAttnDownBlock2D":
174
- if cross_attention_dim is None:
175
- raise ValueError("cross_attention_dim must be specified for SimpleCrossAttnDownBlock2D")
176
- return SimpleCrossAttnDownBlock2D(
177
- num_layers=num_layers,
178
- in_channels=in_channels,
179
- out_channels=out_channels,
180
- temb_channels=temb_channels,
181
- add_downsample=add_downsample,
182
- resnet_eps=resnet_eps,
183
- resnet_act_fn=resnet_act_fn,
184
- resnet_groups=resnet_groups,
185
- cross_attention_dim=cross_attention_dim,
186
- attention_head_dim=attention_head_dim,
187
- resnet_time_scale_shift=resnet_time_scale_shift,
188
- skip_time_act=resnet_skip_time_act,
189
- output_scale_factor=resnet_out_scale_factor,
190
- only_cross_attention=only_cross_attention,
191
- cross_attention_norm=cross_attention_norm,
192
- )
193
- elif down_block_type == "SkipDownBlock2D":
194
- return SkipDownBlock2D(
195
- num_layers=num_layers,
196
- in_channels=in_channels,
197
- out_channels=out_channels,
198
- temb_channels=temb_channels,
199
- add_downsample=add_downsample,
200
- resnet_eps=resnet_eps,
201
- resnet_act_fn=resnet_act_fn,
202
- downsample_padding=downsample_padding,
203
- resnet_time_scale_shift=resnet_time_scale_shift,
204
- )
205
- elif down_block_type == "AttnSkipDownBlock2D":
206
- return AttnSkipDownBlock2D(
207
- num_layers=num_layers,
208
- in_channels=in_channels,
209
- out_channels=out_channels,
210
- temb_channels=temb_channels,
211
- add_downsample=add_downsample,
212
- resnet_eps=resnet_eps,
213
- resnet_act_fn=resnet_act_fn,
214
- attention_head_dim=attention_head_dim,
215
- resnet_time_scale_shift=resnet_time_scale_shift,
216
- )
217
- elif down_block_type == "DownEncoderBlock2D":
218
- return DownEncoderBlock2D(
219
- num_layers=num_layers,
220
- in_channels=in_channels,
221
- out_channels=out_channels,
222
- add_downsample=add_downsample,
223
- resnet_eps=resnet_eps,
224
- resnet_act_fn=resnet_act_fn,
225
- resnet_groups=resnet_groups,
226
- downsample_padding=downsample_padding,
227
- resnet_time_scale_shift=resnet_time_scale_shift,
228
- )
229
- elif down_block_type == "AttnDownEncoderBlock2D":
230
- return AttnDownEncoderBlock2D(
231
- num_layers=num_layers,
232
- in_channels=in_channels,
233
- out_channels=out_channels,
234
- add_downsample=add_downsample,
235
- resnet_eps=resnet_eps,
236
- resnet_act_fn=resnet_act_fn,
237
- resnet_groups=resnet_groups,
238
- downsample_padding=downsample_padding,
239
- attention_head_dim=attention_head_dim,
240
- resnet_time_scale_shift=resnet_time_scale_shift,
241
- )
242
- elif down_block_type == "KDownBlock2D":
243
- return KDownBlock2D(
244
- num_layers=num_layers,
245
- in_channels=in_channels,
246
- out_channels=out_channels,
247
- temb_channels=temb_channels,
248
- add_downsample=add_downsample,
249
- resnet_eps=resnet_eps,
250
- resnet_act_fn=resnet_act_fn,
251
- )
252
- elif down_block_type == "KCrossAttnDownBlock2D":
253
- return KCrossAttnDownBlock2D(
254
- num_layers=num_layers,
255
- in_channels=in_channels,
256
- out_channels=out_channels,
257
- temb_channels=temb_channels,
258
- add_downsample=add_downsample,
259
- resnet_eps=resnet_eps,
260
- resnet_act_fn=resnet_act_fn,
261
- cross_attention_dim=cross_attention_dim,
262
- attention_head_dim=attention_head_dim,
263
- add_self_attention=not add_downsample,
264
- )
265
- raise ValueError(f"{down_block_type} does not exist.")
266
-
267
-
268
- def get_up_block(
269
- up_block_type,
270
- num_layers,
271
- in_channels,
272
- out_channels,
273
- prev_output_channel,
274
- temb_channels,
275
- add_upsample,
276
- resnet_eps,
277
- resnet_act_fn,
278
- transformer_layers_per_block=1,
279
- num_attention_heads=None,
280
- resnet_groups=None,
281
- cross_attention_dim=None,
282
- dual_cross_attention=False,
283
- use_linear_projection=False,
284
- only_cross_attention=False,
285
- upcast_attention=False,
286
- resnet_time_scale_shift="default",
287
- resnet_skip_time_act=False,
288
- resnet_out_scale_factor=1.0,
289
- cross_attention_norm=None,
290
- attention_head_dim=None,
291
- upsample_type=None,
292
- num_views=1,
293
- cd_attention_last: bool = False,
294
- cd_attention_mid: bool = False,
295
- multiview_attention: bool = True,
296
- sparse_mv_attention: bool = False,
297
- selfattn_block: str = "custom",
298
- mvcd_attention: bool=False,
299
- use_dino: bool = False
300
- ):
301
- # If attn head dim is not defined, we default it to the number of heads
302
- if attention_head_dim is None:
303
- logger.warning(
304
- f"It is recommended to provide `attention_head_dim` when calling `get_up_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
305
- )
306
- attention_head_dim = num_attention_heads
307
-
308
- up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
309
- if up_block_type == "UpBlock2D":
310
- return UpBlock2D(
311
- num_layers=num_layers,
312
- in_channels=in_channels,
313
- out_channels=out_channels,
314
- prev_output_channel=prev_output_channel,
315
- temb_channels=temb_channels,
316
- add_upsample=add_upsample,
317
- resnet_eps=resnet_eps,
318
- resnet_act_fn=resnet_act_fn,
319
- resnet_groups=resnet_groups,
320
- resnet_time_scale_shift=resnet_time_scale_shift,
321
- )
322
- elif up_block_type == "ResnetUpsampleBlock2D":
323
- return ResnetUpsampleBlock2D(
324
- num_layers=num_layers,
325
- in_channels=in_channels,
326
- out_channels=out_channels,
327
- prev_output_channel=prev_output_channel,
328
- temb_channels=temb_channels,
329
- add_upsample=add_upsample,
330
- resnet_eps=resnet_eps,
331
- resnet_act_fn=resnet_act_fn,
332
- resnet_groups=resnet_groups,
333
- resnet_time_scale_shift=resnet_time_scale_shift,
334
- skip_time_act=resnet_skip_time_act,
335
- output_scale_factor=resnet_out_scale_factor,
336
- )
337
- elif up_block_type == "CrossAttnUpBlock2D":
338
- if cross_attention_dim is None:
339
- raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock2D")
340
- return CrossAttnUpBlock2D(
341
- num_layers=num_layers,
342
- transformer_layers_per_block=transformer_layers_per_block,
343
- in_channels=in_channels,
344
- out_channels=out_channels,
345
- prev_output_channel=prev_output_channel,
346
- temb_channels=temb_channels,
347
- add_upsample=add_upsample,
348
- resnet_eps=resnet_eps,
349
- resnet_act_fn=resnet_act_fn,
350
- resnet_groups=resnet_groups,
351
- cross_attention_dim=cross_attention_dim,
352
- num_attention_heads=num_attention_heads,
353
- dual_cross_attention=dual_cross_attention,
354
- use_linear_projection=use_linear_projection,
355
- only_cross_attention=only_cross_attention,
356
- upcast_attention=upcast_attention,
357
- resnet_time_scale_shift=resnet_time_scale_shift,
358
- )
359
- # custom MV2D attention block
360
- elif up_block_type == "CrossAttnUpBlockMV2D":
361
- if cross_attention_dim is None:
362
- raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlockMV2D")
363
- return CrossAttnUpBlockMV2D(
364
- num_layers=num_layers,
365
- transformer_layers_per_block=transformer_layers_per_block,
366
- in_channels=in_channels,
367
- out_channels=out_channels,
368
- prev_output_channel=prev_output_channel,
369
- temb_channels=temb_channels,
370
- add_upsample=add_upsample,
371
- resnet_eps=resnet_eps,
372
- resnet_act_fn=resnet_act_fn,
373
- resnet_groups=resnet_groups,
374
- cross_attention_dim=cross_attention_dim,
375
- num_attention_heads=num_attention_heads,
376
- dual_cross_attention=dual_cross_attention,
377
- use_linear_projection=use_linear_projection,
378
- only_cross_attention=only_cross_attention,
379
- upcast_attention=upcast_attention,
380
- resnet_time_scale_shift=resnet_time_scale_shift,
381
- num_views=num_views,
382
- cd_attention_last=cd_attention_last,
383
- cd_attention_mid=cd_attention_mid,
384
- multiview_attention=multiview_attention,
385
- sparse_mv_attention=sparse_mv_attention,
386
- selfattn_block=selfattn_block,
387
- mvcd_attention=mvcd_attention,
388
- use_dino=use_dino
389
- )
390
- elif up_block_type == "SimpleCrossAttnUpBlock2D":
391
- if cross_attention_dim is None:
392
- raise ValueError("cross_attention_dim must be specified for SimpleCrossAttnUpBlock2D")
393
- return SimpleCrossAttnUpBlock2D(
394
- num_layers=num_layers,
395
- in_channels=in_channels,
396
- out_channels=out_channels,
397
- prev_output_channel=prev_output_channel,
398
- temb_channels=temb_channels,
399
- add_upsample=add_upsample,
400
- resnet_eps=resnet_eps,
401
- resnet_act_fn=resnet_act_fn,
402
- resnet_groups=resnet_groups,
403
- cross_attention_dim=cross_attention_dim,
404
- attention_head_dim=attention_head_dim,
405
- resnet_time_scale_shift=resnet_time_scale_shift,
406
- skip_time_act=resnet_skip_time_act,
407
- output_scale_factor=resnet_out_scale_factor,
408
- only_cross_attention=only_cross_attention,
409
- cross_attention_norm=cross_attention_norm,
410
- )
411
- elif up_block_type == "AttnUpBlock2D":
412
- if add_upsample is False:
413
- upsample_type = None
414
- else:
415
- upsample_type = upsample_type or "conv" # default to 'conv'
416
-
417
- return AttnUpBlock2D(
418
- num_layers=num_layers,
419
- in_channels=in_channels,
420
- out_channels=out_channels,
421
- prev_output_channel=prev_output_channel,
422
- temb_channels=temb_channels,
423
- resnet_eps=resnet_eps,
424
- resnet_act_fn=resnet_act_fn,
425
- resnet_groups=resnet_groups,
426
- attention_head_dim=attention_head_dim,
427
- resnet_time_scale_shift=resnet_time_scale_shift,
428
- upsample_type=upsample_type,
429
- )
430
- elif up_block_type == "SkipUpBlock2D":
431
- return SkipUpBlock2D(
432
- num_layers=num_layers,
433
- in_channels=in_channels,
434
- out_channels=out_channels,
435
- prev_output_channel=prev_output_channel,
436
- temb_channels=temb_channels,
437
- add_upsample=add_upsample,
438
- resnet_eps=resnet_eps,
439
- resnet_act_fn=resnet_act_fn,
440
- resnet_time_scale_shift=resnet_time_scale_shift,
441
- )
442
- elif up_block_type == "AttnSkipUpBlock2D":
443
- return AttnSkipUpBlock2D(
444
- num_layers=num_layers,
445
- in_channels=in_channels,
446
- out_channels=out_channels,
447
- prev_output_channel=prev_output_channel,
448
- temb_channels=temb_channels,
449
- add_upsample=add_upsample,
450
- resnet_eps=resnet_eps,
451
- resnet_act_fn=resnet_act_fn,
452
- attention_head_dim=attention_head_dim,
453
- resnet_time_scale_shift=resnet_time_scale_shift,
454
- )
455
- elif up_block_type == "UpDecoderBlock2D":
456
- return UpDecoderBlock2D(
457
- num_layers=num_layers,
458
- in_channels=in_channels,
459
- out_channels=out_channels,
460
- add_upsample=add_upsample,
461
- resnet_eps=resnet_eps,
462
- resnet_act_fn=resnet_act_fn,
463
- resnet_groups=resnet_groups,
464
- resnet_time_scale_shift=resnet_time_scale_shift,
465
- temb_channels=temb_channels,
466
- )
467
- elif up_block_type == "AttnUpDecoderBlock2D":
468
- return AttnUpDecoderBlock2D(
469
- num_layers=num_layers,
470
- in_channels=in_channels,
471
- out_channels=out_channels,
472
- add_upsample=add_upsample,
473
- resnet_eps=resnet_eps,
474
- resnet_act_fn=resnet_act_fn,
475
- resnet_groups=resnet_groups,
476
- attention_head_dim=attention_head_dim,
477
- resnet_time_scale_shift=resnet_time_scale_shift,
478
- temb_channels=temb_channels,
479
- )
480
- elif up_block_type == "KUpBlock2D":
481
- return KUpBlock2D(
482
- num_layers=num_layers,
483
- in_channels=in_channels,
484
- out_channels=out_channels,
485
- temb_channels=temb_channels,
486
- add_upsample=add_upsample,
487
- resnet_eps=resnet_eps,
488
- resnet_act_fn=resnet_act_fn,
489
- )
490
- elif up_block_type == "KCrossAttnUpBlock2D":
491
- return KCrossAttnUpBlock2D(
492
- num_layers=num_layers,
493
- in_channels=in_channels,
494
- out_channels=out_channels,
495
- temb_channels=temb_channels,
496
- add_upsample=add_upsample,
497
- resnet_eps=resnet_eps,
498
- resnet_act_fn=resnet_act_fn,
499
- cross_attention_dim=cross_attention_dim,
500
- attention_head_dim=attention_head_dim,
501
- )
502
-
503
- raise ValueError(f"{up_block_type} does not exist.")
504
-
505
-
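get_down_block and get_up_block above are thin factories keyed by a block-type string (an optional "UNetRes" prefix is stripped first). Illustrative only, with hypothetical values: a multiview UNet config would drive them with per-level lists that mix the custom MV2D cross-attention blocks and the stock diffusers blocks, e.g.:

down_block_types = (
    "CrossAttnDownBlockMV2D",
    "CrossAttnDownBlockMV2D",
    "CrossAttnDownBlockMV2D",
    "DownBlock2D",
)
up_block_types = (
    "UpBlock2D",
    "CrossAttnUpBlockMV2D",
    "CrossAttnUpBlockMV2D",
    "CrossAttnUpBlockMV2D",
)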
506
- class UNetMidBlockMV2DCrossAttn(nn.Module):
507
- def __init__(
508
- self,
509
- in_channels: int,
510
- temb_channels: int,
511
- dropout: float = 0.0,
512
- num_layers: int = 1,
513
- transformer_layers_per_block: int = 1,
514
- resnet_eps: float = 1e-6,
515
- resnet_time_scale_shift: str = "default",
516
- resnet_act_fn: str = "swish",
517
- resnet_groups: int = 32,
518
- resnet_pre_norm: bool = True,
519
- num_attention_heads=1,
520
- output_scale_factor=1.0,
521
- cross_attention_dim=1280,
522
- dual_cross_attention=False,
523
- use_linear_projection=False,
524
- upcast_attention=False,
525
- num_views: int = 1,
526
- cd_attention_last: bool = False,
527
- cd_attention_mid: bool = False,
528
- multiview_attention: bool = True,
529
- sparse_mv_attention: bool = False,
530
- selfattn_block: str = "custom",
531
- mvcd_attention: bool=False,
532
- use_dino: bool = False
533
- ):
534
- super().__init__()
535
-
536
- self.has_cross_attention = True
537
- self.num_attention_heads = num_attention_heads
538
- resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
539
- if selfattn_block == "custom":
540
- from .transformer_mv2d import TransformerMV2DModel
541
- elif selfattn_block == "rowwise":
542
- from .transformer_mv2d_rowwise import TransformerMV2DModel
543
- elif selfattn_block == "self_rowwise":
544
- from .transformer_mv2d_self_rowwise import TransformerMV2DModel
545
- else:
546
- raise NotImplementedError
547
-
548
- # there is always at least one resnet
549
- resnets = [
550
- ResnetBlock2D(
551
- in_channels=in_channels,
552
- out_channels=in_channels,
553
- temb_channels=temb_channels,
554
- eps=resnet_eps,
555
- groups=resnet_groups,
556
- dropout=dropout,
557
- time_embedding_norm=resnet_time_scale_shift,
558
- non_linearity=resnet_act_fn,
559
- output_scale_factor=output_scale_factor,
560
- pre_norm=resnet_pre_norm,
561
- )
562
- ]
563
- attentions = []
564
-
565
- for _ in range(num_layers):
566
- if not dual_cross_attention:
567
- attentions.append(
568
- TransformerMV2DModel(
569
- num_attention_heads,
570
- in_channels // num_attention_heads,
571
- in_channels=in_channels,
572
- num_layers=transformer_layers_per_block,
573
- cross_attention_dim=cross_attention_dim,
574
- norm_num_groups=resnet_groups,
575
- use_linear_projection=use_linear_projection,
576
- upcast_attention=upcast_attention,
577
- num_views=num_views,
578
- cd_attention_last=cd_attention_last,
579
- cd_attention_mid=cd_attention_mid,
580
- multiview_attention=multiview_attention,
581
- sparse_mv_attention=sparse_mv_attention,
582
- mvcd_attention=mvcd_attention,
583
- use_dino=use_dino
584
- )
585
- )
586
- else:
587
- raise NotImplementedError
588
- resnets.append(
589
- ResnetBlock2D(
590
- in_channels=in_channels,
591
- out_channels=in_channels,
592
- temb_channels=temb_channels,
593
- eps=resnet_eps,
594
- groups=resnet_groups,
595
- dropout=dropout,
596
- time_embedding_norm=resnet_time_scale_shift,
597
- non_linearity=resnet_act_fn,
598
- output_scale_factor=output_scale_factor,
599
- pre_norm=resnet_pre_norm,
600
- )
601
- )
602
-
603
- self.attentions = nn.ModuleList(attentions)
604
- self.resnets = nn.ModuleList(resnets)
605
-
606
- def forward(
607
- self,
608
- hidden_states: torch.FloatTensor,
609
- temb: Optional[torch.FloatTensor] = None,
610
- encoder_hidden_states: Optional[torch.FloatTensor] = None,
611
- attention_mask: Optional[torch.FloatTensor] = None,
612
- cross_attention_kwargs: Optional[Dict[str, Any]] = None,
613
- encoder_attention_mask: Optional[torch.FloatTensor] = None,
614
- dino_feature: Optional[torch.FloatTensor] = None
615
- ) -> torch.FloatTensor:
616
- hidden_states = self.resnets[0](hidden_states, temb)
617
- for attn, resnet in zip(self.attentions, self.resnets[1:]):
618
- hidden_states = attn(
619
- hidden_states,
620
- encoder_hidden_states=encoder_hidden_states,
621
- cross_attention_kwargs=cross_attention_kwargs,
622
- attention_mask=attention_mask,
623
- encoder_attention_mask=encoder_attention_mask,
624
- dino_feature=dino_feature,
625
- return_dict=False,
626
- )[0]
627
- hidden_states = resnet(hidden_states, temb)
628
-
629
- return hidden_states
630
-
631
-
632
- class CrossAttnUpBlockMV2D(nn.Module):
633
- def __init__(
634
- self,
635
- in_channels: int,
636
- out_channels: int,
637
- prev_output_channel: int,
638
- temb_channels: int,
639
- dropout: float = 0.0,
640
- num_layers: int = 1,
641
- transformer_layers_per_block: int = 1,
642
- resnet_eps: float = 1e-6,
643
- resnet_time_scale_shift: str = "default",
644
- resnet_act_fn: str = "swish",
645
- resnet_groups: int = 32,
646
- resnet_pre_norm: bool = True,
647
- num_attention_heads=1,
648
- cross_attention_dim=1280,
649
- output_scale_factor=1.0,
650
- add_upsample=True,
651
- dual_cross_attention=False,
652
- use_linear_projection=False,
653
- only_cross_attention=False,
654
- upcast_attention=False,
655
- num_views: int = 1,
656
- cd_attention_last: bool = False,
657
- cd_attention_mid: bool = False,
658
- multiview_attention: bool = True,
659
- sparse_mv_attention: bool = False,
660
- selfattn_block: str = "custom",
661
- mvcd_attention: bool=False,
662
- use_dino: bool = False
663
- ):
664
- super().__init__()
665
- resnets = []
666
- attentions = []
667
-
668
- self.has_cross_attention = True
669
- self.num_attention_heads = num_attention_heads
670
-
671
- if selfattn_block == "custom":
672
- from .transformer_mv2d import TransformerMV2DModel
673
- elif selfattn_block == "rowwise":
674
- from .transformer_mv2d_rowwise import TransformerMV2DModel
675
- elif selfattn_block == "self_rowwise":
676
- from .transformer_mv2d_self_rowwise import TransformerMV2DModel
677
- else:
678
- raise NotImplementedError
679
-
680
- for i in range(num_layers):
681
- res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
682
- resnet_in_channels = prev_output_channel if i == 0 else out_channels
683
-
684
- resnets.append(
685
- ResnetBlock2D(
686
- in_channels=resnet_in_channels + res_skip_channels,
687
- out_channels=out_channels,
688
- temb_channels=temb_channels,
689
- eps=resnet_eps,
690
- groups=resnet_groups,
691
- dropout=dropout,
692
- time_embedding_norm=resnet_time_scale_shift,
693
- non_linearity=resnet_act_fn,
694
- output_scale_factor=output_scale_factor,
695
- pre_norm=resnet_pre_norm,
696
- )
697
- )
698
- if not dual_cross_attention:
699
- attentions.append(
700
- TransformerMV2DModel(
701
- num_attention_heads,
702
- out_channels // num_attention_heads,
703
- in_channels=out_channels,
704
- num_layers=transformer_layers_per_block,
705
- cross_attention_dim=cross_attention_dim,
706
- norm_num_groups=resnet_groups,
707
- use_linear_projection=use_linear_projection,
708
- only_cross_attention=only_cross_attention,
709
- upcast_attention=upcast_attention,
710
- num_views=num_views,
711
- cd_attention_last=cd_attention_last,
712
- cd_attention_mid=cd_attention_mid,
713
- multiview_attention=multiview_attention,
714
- sparse_mv_attention=sparse_mv_attention,
715
- mvcd_attention=mvcd_attention,
716
- use_dino=use_dino
717
- )
718
- )
719
- else:
720
- raise NotImplementedError
721
- self.attentions = nn.ModuleList(attentions)
722
- self.resnets = nn.ModuleList(resnets)
723
-
724
- if add_upsample:
725
- self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
726
- else:
727
- self.upsamplers = None
728
-
729
- self.gradient_checkpointing = False
730
-
731
- def forward(
732
- self,
733
- hidden_states: torch.FloatTensor,
734
- res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
735
- temb: Optional[torch.FloatTensor] = None,
736
- encoder_hidden_states: Optional[torch.FloatTensor] = None,
737
- cross_attention_kwargs: Optional[Dict[str, Any]] = None,
738
- upsample_size: Optional[int] = None,
739
- attention_mask: Optional[torch.FloatTensor] = None,
740
- encoder_attention_mask: Optional[torch.FloatTensor] = None,
741
- dino_feature: Optional[torch.FloatTensor] = None
742
- ):
743
- for resnet, attn in zip(self.resnets, self.attentions):
744
- # pop res hidden states
745
- res_hidden_states = res_hidden_states_tuple[-1]
746
- res_hidden_states_tuple = res_hidden_states_tuple[:-1]
747
- hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
748
-
749
- if self.training and self.gradient_checkpointing:
750
-
751
- def create_custom_forward(module, return_dict=None):
752
- def custom_forward(*inputs):
753
- if return_dict is not None:
754
- return module(*inputs, return_dict=return_dict)
755
- else:
756
- return module(*inputs)
757
-
758
- return custom_forward
759
-
760
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
761
- hidden_states = torch.utils.checkpoint.checkpoint(
762
- create_custom_forward(resnet),
763
- hidden_states,
764
- temb,
765
- **ckpt_kwargs,
766
- )
767
- hidden_states = torch.utils.checkpoint.checkpoint(
768
- create_custom_forward(attn, return_dict=False),
769
- hidden_states,
770
- encoder_hidden_states,
771
- dino_feature,
772
- None, # timestep
773
- None, # class_labels
774
- cross_attention_kwargs,
775
- attention_mask,
776
- encoder_attention_mask,
777
- **ckpt_kwargs,
778
- )[0]
779
- else:
780
- hidden_states = resnet(hidden_states, temb)
781
- hidden_states = attn(
782
- hidden_states,
783
- encoder_hidden_states=encoder_hidden_states,
784
- cross_attention_kwargs=cross_attention_kwargs,
785
- attention_mask=attention_mask,
786
- encoder_attention_mask=encoder_attention_mask,
787
- dino_feature=dino_feature,
788
- return_dict=False,
789
- )[0]
790
-
791
- if self.upsamplers is not None:
792
- for upsampler in self.upsamplers:
793
- hidden_states = upsampler(hidden_states, upsample_size)
794
-
795
- return hidden_states
796
-
797
-
798
- class CrossAttnDownBlockMV2D(nn.Module):
799
- def __init__(
800
- self,
801
- in_channels: int,
802
- out_channels: int,
803
- temb_channels: int,
804
- dropout: float = 0.0,
805
- num_layers: int = 1,
806
- transformer_layers_per_block: int = 1,
807
- resnet_eps: float = 1e-6,
808
- resnet_time_scale_shift: str = "default",
809
- resnet_act_fn: str = "swish",
810
- resnet_groups: int = 32,
811
- resnet_pre_norm: bool = True,
812
- num_attention_heads=1,
813
- cross_attention_dim=1280,
814
- output_scale_factor=1.0,
815
- downsample_padding=1,
816
- add_downsample=True,
817
- dual_cross_attention=False,
818
- use_linear_projection=False,
819
- only_cross_attention=False,
820
- upcast_attention=False,
821
- num_views: int = 1,
822
- cd_attention_last: bool = False,
823
- cd_attention_mid: bool = False,
824
- multiview_attention: bool = True,
825
- sparse_mv_attention: bool = False,
826
- selfattn_block: str = "custom",
827
- mvcd_attention: bool=False,
828
- use_dino: bool = False
829
- ):
830
- super().__init__()
831
- resnets = []
832
- attentions = []
833
-
834
- self.has_cross_attention = True
835
- self.num_attention_heads = num_attention_heads
836
- if selfattn_block == "custom":
837
- from .transformer_mv2d import TransformerMV2DModel
838
- elif selfattn_block == "rowwise":
839
- from .transformer_mv2d_rowwise import TransformerMV2DModel
840
- elif selfattn_block == "self_rowwise":
841
- from .transformer_mv2d_self_rowwise import TransformerMV2DModel
842
- else:
843
- raise NotImplementedError
844
-
845
- for i in range(num_layers):
846
- in_channels = in_channels if i == 0 else out_channels
847
- resnets.append(
848
- ResnetBlock2D(
849
- in_channels=in_channels,
850
- out_channels=out_channels,
851
- temb_channels=temb_channels,
852
- eps=resnet_eps,
853
- groups=resnet_groups,
854
- dropout=dropout,
855
- time_embedding_norm=resnet_time_scale_shift,
856
- non_linearity=resnet_act_fn,
857
- output_scale_factor=output_scale_factor,
858
- pre_norm=resnet_pre_norm,
859
- )
860
- )
861
- if not dual_cross_attention:
862
- attentions.append(
863
- TransformerMV2DModel(
864
- num_attention_heads,
865
- out_channels // num_attention_heads,
866
- in_channels=out_channels,
867
- num_layers=transformer_layers_per_block,
868
- cross_attention_dim=cross_attention_dim,
869
- norm_num_groups=resnet_groups,
870
- use_linear_projection=use_linear_projection,
871
- only_cross_attention=only_cross_attention,
872
- upcast_attention=upcast_attention,
873
- num_views=num_views,
874
- cd_attention_last=cd_attention_last,
875
- cd_attention_mid=cd_attention_mid,
876
- multiview_attention=multiview_attention,
877
- sparse_mv_attention=sparse_mv_attention,
878
- mvcd_attention=mvcd_attention,
879
- use_dino=use_dino
880
- )
881
- )
882
- else:
883
- raise NotImplementedError
884
- self.attentions = nn.ModuleList(attentions)
885
- self.resnets = nn.ModuleList(resnets)
886
-
887
- if add_downsample:
888
- self.downsamplers = nn.ModuleList(
889
- [
890
- Downsample2D(
891
- out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
892
- )
893
- ]
894
- )
895
- else:
896
- self.downsamplers = None
897
-
898
- self.gradient_checkpointing = False
899
-
900
- def forward(
901
- self,
902
- hidden_states: torch.FloatTensor,
903
- temb: Optional[torch.FloatTensor] = None,
904
- encoder_hidden_states: Optional[torch.FloatTensor] = None,
905
- dino_feature: Optional[torch.FloatTensor] = None,
906
- attention_mask: Optional[torch.FloatTensor] = None,
907
- cross_attention_kwargs: Optional[Dict[str, Any]] = None,
908
- encoder_attention_mask: Optional[torch.FloatTensor] = None,
909
- additional_residuals=None,
910
- ):
911
- output_states = ()
912
-
913
- blocks = list(zip(self.resnets, self.attentions))
914
-
915
- for i, (resnet, attn) in enumerate(blocks):
916
- if self.training and self.gradient_checkpointing:
917
-
918
- def create_custom_forward(module, return_dict=None):
919
- def custom_forward(*inputs):
920
- if return_dict is not None:
921
- return module(*inputs, return_dict=return_dict)
922
- else:
923
- return module(*inputs)
924
-
925
- return custom_forward
926
-
927
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
928
- hidden_states = torch.utils.checkpoint.checkpoint(
929
- create_custom_forward(resnet),
930
- hidden_states,
931
- temb,
932
- **ckpt_kwargs,
933
- )
934
- hidden_states = torch.utils.checkpoint.checkpoint(
935
- create_custom_forward(attn, return_dict=False),
936
- hidden_states,
937
- encoder_hidden_states,
938
- dino_feature,
939
- None, # timestep
940
- None, # class_labels
941
- cross_attention_kwargs,
942
- attention_mask,
943
- encoder_attention_mask,
944
- **ckpt_kwargs,
945
- )[0]
946
- else:
947
- hidden_states = resnet(hidden_states, temb)
948
- hidden_states = attn(
949
- hidden_states,
950
- encoder_hidden_states=encoder_hidden_states,
951
- dino_feature=dino_feature,
952
- cross_attention_kwargs=cross_attention_kwargs,
953
- attention_mask=attention_mask,
954
- encoder_attention_mask=encoder_attention_mask,
955
- return_dict=False,
956
- )[0]
957
-
958
- # apply additional residuals to the output of the last pair of resnet and attention blocks
959
- if i == len(blocks) - 1 and additional_residuals is not None:
960
- hidden_states = hidden_states + additional_residuals
961
-
962
- output_states = output_states + (hidden_states,)
963
-
964
- if self.downsamplers is not None:
965
- for downsampler in self.downsamplers:
966
- hidden_states = downsampler(hidden_states)
967
-
968
- output_states = output_states + (hidden_states,)
969
-
970
- return hidden_states, output_states
971
-
 
mvdiffusion/models/unet_mv2d_condition.py DELETED
@@ -1,1686 +0,0 @@
1
- # Copyright 2023 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- from dataclasses import dataclass
15
- from typing import Any, Dict, List, Optional, Tuple, Union
16
- import os
17
-
18
- import torch
19
- import torch.nn as nn
20
- import torch.utils.checkpoint
21
-
22
- from diffusers.configuration_utils import ConfigMixin, register_to_config
23
- from diffusers.loaders import UNet2DConditionLoadersMixin
24
- from diffusers.utils import BaseOutput, logging
25
- from diffusers.models.activations import get_activation
26
- from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor
27
- from diffusers.models.embeddings import (
28
- GaussianFourierProjection,
29
- ImageHintTimeEmbedding,
30
- ImageProjection,
31
- ImageTimeEmbedding,
32
- TextImageProjection,
33
- TextImageTimeEmbedding,
34
- TextTimeEmbedding,
35
- TimestepEmbedding,
36
- Timesteps,
37
- )
38
- from diffusers.models.modeling_utils import ModelMixin, load_state_dict, _load_state_dict_into_model
39
- from diffusers.models.unet_2d_blocks import (
40
- CrossAttnDownBlock2D,
41
- CrossAttnUpBlock2D,
42
- DownBlock2D,
43
- UNetMidBlock2DCrossAttn,
44
- UNetMidBlock2DSimpleCrossAttn,
45
- UpBlock2D,
46
- )
47
- from diffusers.utils import (
48
- CONFIG_NAME,
49
- FLAX_WEIGHTS_NAME,
50
- SAFETENSORS_WEIGHTS_NAME,
51
- WEIGHTS_NAME,
52
- _add_variant,
53
- _get_model_file,
54
- deprecate,
55
- is_torch_version,
56
- logging,
57
- )
58
- from diffusers.utils.import_utils import is_accelerate_available
59
- from diffusers.utils.hub_utils import HF_HUB_OFFLINE
60
- from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
61
- DIFFUSERS_CACHE = HUGGINGFACE_HUB_CACHE
62
-
63
- from diffusers import __version__
64
- from .unet_mv2d_blocks import (
65
- CrossAttnDownBlockMV2D,
66
- CrossAttnUpBlockMV2D,
67
- UNetMidBlockMV2DCrossAttn,
68
- get_down_block,
69
- get_up_block,
70
- )
71
- from einops import rearrange, repeat
72
-
73
- from diffusers import __version__
74
- from mvdiffusion.models.unet_mv2d_blocks import (
75
- CrossAttnDownBlockMV2D,
76
- CrossAttnUpBlockMV2D,
77
- UNetMidBlockMV2DCrossAttn,
78
- get_down_block,
79
- get_up_block,
80
- )
81
-
82
-
83
- logger = logging.get_logger(__name__) # pylint: disable=invalid-name
84
-
85
-
86
- @dataclass
87
- class UNetMV2DConditionOutput(BaseOutput):
88
- """
89
- The output of [`UNet2DConditionModel`].
90
-
91
- Args:
92
- sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
93
- The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
94
- """
95
-
96
- sample: torch.FloatTensor = None
97
-
98
-
99
- class ResidualBlock(nn.Module):
100
- def __init__(self, dim):
101
- super(ResidualBlock, self).__init__()
102
- self.linear1 = nn.Linear(dim, dim)
103
- self.activation = nn.SiLU()
104
- self.linear2 = nn.Linear(dim, dim)
105
-
106
- def forward(self, x):
107
- identity = x
108
- out = self.linear1(x)
109
- out = self.activation(out)
110
- out = self.linear2(out)
111
- out += identity
112
- out = self.activation(out)
113
- return out
114
-
115
- class ResidualLiner(nn.Module):
116
- def __init__(self, in_features, out_features, dim, act=None, num_block=1):
117
- super(ResidualLiner, self).__init__()
118
- self.linear_in = nn.Sequential(nn.Linear(in_features, dim), nn.SiLU())
119
-
120
- blocks = nn.ModuleList()
121
- for _ in range(num_block):
122
- blocks.append(ResidualBlock(dim))
123
- self.blocks = blocks
124
-
125
- self.linear_out = nn.Linear(dim, out_features)
126
- self.act = act
127
-
128
- def forward(self, x):
129
- out = self.linear_in(x)
130
- for block in self.blocks:
131
- out = block(out)
132
- out = self.linear_out(out)
133
- if self.act is not None:
134
- out = self.act(out)
135
- return out
136
-
137
- class BasicConvBlock(nn.Module):
138
- def __init__(self, in_channels, out_channels, stride=1):
139
- super(BasicConvBlock, self).__init__()
140
- self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
141
- self.norm1 = nn.GroupNorm(num_groups=8, num_channels=out_channels, affine=True)  # must match conv1 output channels
142
- self.act = nn.SiLU()
143
- self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
144
- self.norm2 = nn.GroupNorm(num_groups=8, num_channels=out_channels, affine=True)
145
- self.downsample = nn.Sequential()
146
- if stride != 1 or in_channels != out_channels:
147
- self.downsample = nn.Sequential(
148
- nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
149
- nn.GroupNorm(num_groups=8, num_channels=out_channels, affine=True)
150
- )
151
-
152
- def forward(self, x):
153
- identity = x
154
- out = self.conv1(x)
155
- out = self.norm1(out)
156
- out = self.act(out)
157
- out = self.conv2(out)
158
- out = self.norm2(out)
159
- out += self.downsample(identity)
160
- out = self.act(out)
161
- return out
162
-
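A quick, hedged shape check for the helper modules defined above (ResidualBlock, ResidualLiner, BasicConvBlock); the import path refers to this now-deleted module and the sizes are illustrative only:

import torch
from mvdiffusion.models.unet_mv2d_condition import ResidualLiner, BasicConvBlock

mlp = ResidualLiner(in_features=2560, out_features=1, dim=1280, num_block=4)
print(mlp(torch.randn(8, 2560)).shape)          # torch.Size([8, 1])

conv = BasicConvBlock(in_channels=1280, out_channels=1280, stride=1)
print(conv(torch.randn(2, 1280, 8, 8)).shape)   # torch.Size([2, 1280, 8, 8])
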
163
- class UNetMV2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
164
- r"""
165
- A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
166
- shaped output.
167
-
168
- This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
169
- for all models (such as downloading or saving).
170
-
171
- Parameters:
172
- sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
173
- Height and width of input/output sample.
174
- in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
175
- out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
176
- center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
177
- flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
178
- Whether to flip the sin to cos in the time embedding.
179
- freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
180
- down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
181
- The tuple of downsample blocks to use.
182
- mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
183
- Block type for middle of UNet, it can be either `UNetMidBlock2DCrossAttn` or
184
- `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
185
- up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
186
- The tuple of upsample blocks to use.
187
- only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`):
188
- Whether to include self-attention in the basic transformer blocks, see
189
- [`~models.attention.BasicTransformerBlock`].
190
- block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
191
- The tuple of output channels for each block.
192
- layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
193
- downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
194
- mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
195
- act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
196
- norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
197
- If `None`, normalization and activation layers is skipped in post-processing.
198
- norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
199
- cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
200
- The dimension of the cross attention features.
201
- transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
202
- The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
203
- [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
204
- [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
205
- encoder_hid_dim (`int`, *optional*, defaults to None):
206
- If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
207
- dimension to `cross_attention_dim`.
208
- encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
209
- If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
210
- embeddings of dimension `cross_attention_dim` according to `encoder_hid_dim_type`.
211
- attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
212
- num_attention_heads (`int`, *optional*):
213
- The number of attention heads. If not defined, defaults to `attention_head_dim`
214
- resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
215
- for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
216
- class_embed_type (`str`, *optional*, defaults to `None`):
217
- The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
218
- `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
219
- addition_embed_type (`str`, *optional*, defaults to `None`):
220
- Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
221
- "text". "text" will use the `TextTimeEmbedding` layer.
222
- addition_time_embed_dim: (`int`, *optional*, defaults to `None`):
223
- Dimension for the timestep embeddings.
224
- num_class_embeds (`int`, *optional*, defaults to `None`):
225
- Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
226
- class conditioning with `class_embed_type` equal to `None`.
227
- time_embedding_type (`str`, *optional*, defaults to `positional`):
228
- The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
229
- time_embedding_dim (`int`, *optional*, defaults to `None`):
230
- An optional override for the dimension of the projected time embedding.
231
- time_embedding_act_fn (`str`, *optional*, defaults to `None`):
232
- Optional activation function to use only once on the time embeddings before they are passed to the rest of
233
- the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
234
- timestep_post_act (`str`, *optional*, defaults to `None`):
235
- The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
236
- time_cond_proj_dim (`int`, *optional*, defaults to `None`):
237
- The dimension of `cond_proj` layer in the timestep embedding.
238
- conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer.
239
- conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer.
240
- projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
241
- `class_embed_type="projection"`. Required when `class_embed_type="projection"`.
242
- class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
243
- embeddings with the class embeddings.
244
- mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
245
- Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
246
- `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
247
- `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False`
248
- otherwise.
249
- """
250
-
251
- _supports_gradient_checkpointing = True
252
-
253
- @register_to_config
254
- def __init__(
255
- self,
256
- sample_size: Optional[int] = None,
257
- in_channels: int = 4,
258
- out_channels: int = 4,
259
- center_input_sample: bool = False,
260
- flip_sin_to_cos: bool = True,
261
- freq_shift: int = 0,
262
- down_block_types: Tuple[str] = (
263
- "CrossAttnDownBlockMV2D",
264
- "CrossAttnDownBlockMV2D",
265
- "CrossAttnDownBlockMV2D",
266
- "DownBlock2D",
267
- ),
268
- mid_block_type: Optional[str] = "UNetMidBlockMV2DCrossAttn",
269
- up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlockMV2D", "CrossAttnUpBlockMV2D", "CrossAttnUpBlockMV2D"),
270
- only_cross_attention: Union[bool, Tuple[bool]] = False,
271
- block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
272
- layers_per_block: Union[int, Tuple[int]] = 2,
273
- downsample_padding: int = 1,
274
- mid_block_scale_factor: float = 1,
275
- act_fn: str = "silu",
276
- norm_num_groups: Optional[int] = 32,
277
- norm_eps: float = 1e-5,
278
- cross_attention_dim: Union[int, Tuple[int]] = 1280,
279
- transformer_layers_per_block: Union[int, Tuple[int]] = 1,
280
- encoder_hid_dim: Optional[int] = None,
281
- encoder_hid_dim_type: Optional[str] = None,
282
- attention_head_dim: Union[int, Tuple[int]] = 8,
283
- num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
284
- dual_cross_attention: bool = False,
285
- use_linear_projection: bool = False,
286
- class_embed_type: Optional[str] = None,
287
- addition_embed_type: Optional[str] = None,
288
- addition_time_embed_dim: Optional[int] = None,
289
- num_class_embeds: Optional[int] = None,
290
- upcast_attention: bool = False,
291
- resnet_time_scale_shift: str = "default",
292
- resnet_skip_time_act: bool = False,
293
- resnet_out_scale_factor: int = 1.0,
294
- time_embedding_type: str = "positional",
295
- time_embedding_dim: Optional[int] = None,
296
- time_embedding_act_fn: Optional[str] = None,
297
- timestep_post_act: Optional[str] = None,
298
- time_cond_proj_dim: Optional[int] = None,
299
- conv_in_kernel: int = 3,
300
- conv_out_kernel: int = 3,
301
- projection_class_embeddings_input_dim: Optional[int] = None,
302
- projection_camera_embeddings_input_dim: Optional[int] = None,
303
- class_embeddings_concat: bool = False,
304
- mid_block_only_cross_attention: Optional[bool] = None,
305
- cross_attention_norm: Optional[str] = None,
306
- addition_embed_type_num_heads=64,
307
- num_views: int = 1,
308
- cd_attention_last: bool = False,
309
- cd_attention_mid: bool = False,
310
- multiview_attention: bool = True,
311
- sparse_mv_attention: bool = False,
312
- selfattn_block: str = "custom",
313
- mvcd_attention: bool = False,
314
- regress_elevation: bool = False,
315
- regress_focal_length: bool = False,
316
- num_regress_blocks: int = 4,
317
- use_dino: bool = False,
318
- addition_downsample: bool = False,
319
- addition_channels: Optional[Tuple[int]] = (1280, 1280, 1280),
320
- ):
321
- super().__init__()
322
-
323
- self.sample_size = sample_size
324
- self.num_views = num_views
325
- self.mvcd_attention = mvcd_attention
326
- if num_attention_heads is not None:
327
- raise ValueError(
328
- "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
329
- )
330
-
331
- # If `num_attention_heads` is not defined (which is the case for most models)
332
- # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
333
- # The reason for this behavior is to correct for incorrectly named variables that were introduced
334
- # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
335
- # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
336
- # which is why we correct for the naming here.
337
- num_attention_heads = num_attention_heads or attention_head_dim
338
-
339
- # Check inputs
340
- if len(down_block_types) != len(up_block_types):
341
- raise ValueError(
342
- f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
343
- )
344
-
345
- if len(block_out_channels) != len(down_block_types):
346
- raise ValueError(
347
- f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
348
- )
349
-
350
- if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
351
- raise ValueError(
352
- f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
353
- )
354
-
355
- if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
356
- raise ValueError(
357
- f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
358
- )
359
-
360
- if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
361
- raise ValueError(
362
- f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
363
- )
364
-
365
- if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
366
- raise ValueError(
367
- f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
368
- )
369
-
370
- if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
371
- raise ValueError(
372
- f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
373
- )
374
-
375
- # input
376
- conv_in_padding = (conv_in_kernel - 1) // 2
377
- self.conv_in = nn.Conv2d(
378
- in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
379
- )
380
-
381
- # time
382
- if time_embedding_type == "fourier":
383
- time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
384
- if time_embed_dim % 2 != 0:
385
- raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
386
- self.time_proj = GaussianFourierProjection(
387
- time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
388
- )
389
- timestep_input_dim = time_embed_dim
390
- elif time_embedding_type == "positional":
391
- time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
392
-
393
- self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
394
- timestep_input_dim = block_out_channels[0]
395
- else:
396
- raise ValueError(
397
- f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
398
- )
399
-
400
- self.time_embedding = TimestepEmbedding(
401
- timestep_input_dim,
402
- time_embed_dim,
403
- act_fn=act_fn,
404
- post_act_fn=timestep_post_act,
405
- cond_proj_dim=time_cond_proj_dim,
406
- )
407
-
408
- if encoder_hid_dim_type is None and encoder_hid_dim is not None:
409
- encoder_hid_dim_type = "text_proj"
410
- self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
411
- logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
412
-
413
- if encoder_hid_dim is None and encoder_hid_dim_type is not None:
414
- raise ValueError(
415
- f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
416
- )
417
-
418
- if encoder_hid_dim_type == "text_proj":
419
- self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
420
- elif encoder_hid_dim_type == "text_image_proj":
421
- # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
422
- # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
423
- # case when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1)
424
- self.encoder_hid_proj = TextImageProjection(
425
- text_embed_dim=encoder_hid_dim,
426
- image_embed_dim=cross_attention_dim,
427
- cross_attention_dim=cross_attention_dim,
428
- )
429
- elif encoder_hid_dim_type == "image_proj":
430
- # Kandinsky 2.2
431
- self.encoder_hid_proj = ImageProjection(
432
- image_embed_dim=encoder_hid_dim,
433
- cross_attention_dim=cross_attention_dim,
434
- )
435
- elif encoder_hid_dim_type is not None:
436
- raise ValueError(
437
- f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
438
- )
439
- else:
440
- self.encoder_hid_proj = None
441
-
442
- # class embedding
443
- if class_embed_type is None and num_class_embeds is not None:
444
- self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
445
- elif class_embed_type == "timestep":
446
- self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
447
- elif class_embed_type == "identity":
448
- self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
449
- elif class_embed_type == "projection":
450
- if projection_class_embeddings_input_dim is None:
451
- raise ValueError(
452
- "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
453
- )
454
- # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
455
- # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
456
- # 2. it projects from an arbitrary input dimension.
457
- #
458
- # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
459
- # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
460
- # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
461
- self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
462
- elif class_embed_type == "simple_projection":
463
- if projection_class_embeddings_input_dim is None:
464
- raise ValueError(
465
- "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
466
- )
467
- self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
468
- else:
469
- self.class_embedding = None
470
-
471
- if addition_embed_type == "text":
472
- if encoder_hid_dim is not None:
473
- text_time_embedding_from_dim = encoder_hid_dim
474
- else:
475
- text_time_embedding_from_dim = cross_attention_dim
476
-
477
- self.add_embedding = TextTimeEmbedding(
478
- text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
479
- )
480
- elif addition_embed_type == "text_image":
481
- # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
482
- # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
483
- # case when `addition_embed_type == "text_image"` (Kandinsky 2.1)
484
- self.add_embedding = TextImageTimeEmbedding(
485
- text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
486
- )
487
- elif addition_embed_type == "text_time":
488
- self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
489
- self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
490
- elif addition_embed_type == "image":
491
- # Kandinsky 2.2
492
- self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
493
- elif addition_embed_type == "image_hint":
494
- # Kandinsky 2.2 ControlNet
495
- self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
496
- elif addition_embed_type is not None:
497
- raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
498
-
499
- if time_embedding_act_fn is None:
500
- self.time_embed_act = None
501
- else:
502
- self.time_embed_act = get_activation(time_embedding_act_fn)
503
-
504
- self.down_blocks = nn.ModuleList([])
505
- self.up_blocks = nn.ModuleList([])
506
-
507
- if isinstance(only_cross_attention, bool):
508
- if mid_block_only_cross_attention is None:
509
- mid_block_only_cross_attention = only_cross_attention
510
-
511
- only_cross_attention = [only_cross_attention] * len(down_block_types)
512
-
513
- if mid_block_only_cross_attention is None:
514
- mid_block_only_cross_attention = False
515
-
516
- if isinstance(num_attention_heads, int):
517
- num_attention_heads = (num_attention_heads,) * len(down_block_types)
518
-
519
- if isinstance(attention_head_dim, int):
520
- attention_head_dim = (attention_head_dim,) * len(down_block_types)
521
-
522
- if isinstance(cross_attention_dim, int):
523
- cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
524
-
525
- if isinstance(layers_per_block, int):
526
- layers_per_block = [layers_per_block] * len(down_block_types)
527
-
528
- if isinstance(transformer_layers_per_block, int):
529
- transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
530
-
531
- if class_embeddings_concat:
532
- # The time embeddings are concatenated with the class embeddings. The dimension of the
533
- # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
534
- # regular time embeddings
535
- blocks_time_embed_dim = time_embed_dim * 2
536
- else:
537
- blocks_time_embed_dim = time_embed_dim
538
-
539
- # down
540
- output_channel = block_out_channels[0]
541
- for i, down_block_type in enumerate(down_block_types):
542
- input_channel = output_channel
543
- output_channel = block_out_channels[i]
544
- is_final_block = i == len(block_out_channels) - 1
545
-
546
- down_block = get_down_block(
547
- down_block_type,
548
- num_layers=layers_per_block[i],
549
- transformer_layers_per_block=transformer_layers_per_block[i],
550
- in_channels=input_channel,
551
- out_channels=output_channel,
552
- temb_channels=blocks_time_embed_dim,
553
- add_downsample=not is_final_block,
554
- resnet_eps=norm_eps,
555
- resnet_act_fn=act_fn,
556
- resnet_groups=norm_num_groups,
557
- cross_attention_dim=cross_attention_dim[i],
558
- num_attention_heads=num_attention_heads[i],
559
- downsample_padding=downsample_padding,
560
- dual_cross_attention=dual_cross_attention,
561
- use_linear_projection=use_linear_projection,
562
- only_cross_attention=only_cross_attention[i],
563
- upcast_attention=upcast_attention,
564
- resnet_time_scale_shift=resnet_time_scale_shift,
565
- resnet_skip_time_act=resnet_skip_time_act,
566
- resnet_out_scale_factor=resnet_out_scale_factor,
567
- cross_attention_norm=cross_attention_norm,
568
- attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
569
- num_views=num_views,
570
- cd_attention_last=cd_attention_last,
571
- cd_attention_mid=cd_attention_mid,
572
- multiview_attention=multiview_attention,
573
- sparse_mv_attention=sparse_mv_attention,
574
- selfattn_block=selfattn_block,
575
- mvcd_attention=mvcd_attention,
576
- use_dino=use_dino
577
- )
578
- self.down_blocks.append(down_block)
579
-
580
- # mid
581
- if mid_block_type == "UNetMidBlock2DCrossAttn":
582
- self.mid_block = UNetMidBlock2DCrossAttn(
583
- transformer_layers_per_block=transformer_layers_per_block[-1],
584
- in_channels=block_out_channels[-1],
585
- temb_channels=blocks_time_embed_dim,
586
- resnet_eps=norm_eps,
587
- resnet_act_fn=act_fn,
588
- output_scale_factor=mid_block_scale_factor,
589
- resnet_time_scale_shift=resnet_time_scale_shift,
590
- cross_attention_dim=cross_attention_dim[-1],
591
- num_attention_heads=num_attention_heads[-1],
592
- resnet_groups=norm_num_groups,
593
- dual_cross_attention=dual_cross_attention,
594
- use_linear_projection=use_linear_projection,
595
- upcast_attention=upcast_attention,
596
- )
597
- # custom MV2D attention block
598
- elif mid_block_type == "UNetMidBlockMV2DCrossAttn":
599
- self.mid_block = UNetMidBlockMV2DCrossAttn(
600
- transformer_layers_per_block=transformer_layers_per_block[-1],
601
- in_channels=block_out_channels[-1],
602
- temb_channels=blocks_time_embed_dim,
603
- resnet_eps=norm_eps,
604
- resnet_act_fn=act_fn,
605
- output_scale_factor=mid_block_scale_factor,
606
- resnet_time_scale_shift=resnet_time_scale_shift,
607
- cross_attention_dim=cross_attention_dim[-1],
608
- num_attention_heads=num_attention_heads[-1],
609
- resnet_groups=norm_num_groups,
610
- dual_cross_attention=dual_cross_attention,
611
- use_linear_projection=use_linear_projection,
612
- upcast_attention=upcast_attention,
613
- num_views=num_views,
614
- cd_attention_last=cd_attention_last,
615
- cd_attention_mid=cd_attention_mid,
616
- multiview_attention=multiview_attention,
617
- sparse_mv_attention=sparse_mv_attention,
618
- selfattn_block=selfattn_block,
619
- mvcd_attention=mvcd_attention,
620
- use_dino=use_dino
621
- )
622
- elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
623
- self.mid_block = UNetMidBlock2DSimpleCrossAttn(
624
- in_channels=block_out_channels[-1],
625
- temb_channels=blocks_time_embed_dim,
626
- resnet_eps=norm_eps,
627
- resnet_act_fn=act_fn,
628
- output_scale_factor=mid_block_scale_factor,
629
- cross_attention_dim=cross_attention_dim[-1],
630
- attention_head_dim=attention_head_dim[-1],
631
- resnet_groups=norm_num_groups,
632
- resnet_time_scale_shift=resnet_time_scale_shift,
633
- skip_time_act=resnet_skip_time_act,
634
- only_cross_attention=mid_block_only_cross_attention,
635
- cross_attention_norm=cross_attention_norm,
636
- )
637
- elif mid_block_type is None:
638
- self.mid_block = None
639
- else:
640
- raise ValueError(f"unknown mid_block_type : {mid_block_type}")
641
-
642
- self.addition_downsample = addition_downsample
643
- if self.addition_downsample:
644
- inc = block_out_channels[-1]
645
- self.downsample = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
646
- self.conv_block = nn.ModuleList()
647
- self.conv_block.append(BasicConvBlock(inc, addition_channels[0], stride=1))
648
- for dim_ in addition_channels[1:-1]:
649
- self.conv_block.append(BasicConvBlock(dim_, dim_, stride=1))
650
- self.conv_block.append(BasicConvBlock(dim_, inc))
651
- self.addition_conv_out = nn.Conv2d(inc, inc, kernel_size=1, bias=False)
652
- nn.init.zeros_(self.addition_conv_out.weight.data)
653
- self.addition_act_out = nn.SiLU()
654
- self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
655
-
656
- self.regress_elevation = regress_elevation
657
- self.regress_focal_length = regress_focal_length
658
- if regress_elevation or regress_focal_length:
659
- self.pool = nn.AdaptiveAvgPool2d((1, 1))
660
- self.camera_embedding = TimestepEmbedding(projection_camera_embeddings_input_dim, time_embed_dim=time_embed_dim)
661
-
662
- regress_in_dim = block_out_channels[-1]*2 if mvcd_attention else block_out_channels[-1]  # single channel count, not the whole tuple
663
-
664
- if regress_elevation:
665
- self.elevation_regressor = ResidualLiner(regress_in_dim, 1, 1280, act=None, num_block=num_regress_blocks)
666
- if regress_focal_length:
667
- self.focal_regressor = ResidualLiner(regress_in_dim, 1, 1280, act=None, num_block=num_regress_blocks)
668
- '''
669
- self.regress_elevation = regress_elevation
670
- self.regress_focal_length = regress_focal_length
671
- if regress_elevation and (not regress_focal_length):
672
- print("Regressing elevation")
673
- cam_dim = 1
674
- elif regress_focal_length and (not regress_elevation):
675
- print("Regressing focal length")
676
- cam_dim = 6
677
- elif regress_elevation and regress_focal_length:
678
- print("Regressing both elevation and focal length")
679
- cam_dim = 7
680
- else:
681
- cam_dim = 0
682
- assert projection_camera_embeddings_input_dim == 2*cam_dim, "projection_camera_embeddings_input_dim should be 2*cam_dim"
683
- if regress_elevation or regress_focal_length:
684
- self.elevation_regressor = nn.ModuleList([
685
- nn.Linear(block_out_channels[-1], 1280),
686
- nn.SiLU(),
687
- nn.Linear(1280, 1280),
688
- nn.SiLU(),
689
- nn.Linear(1280, cam_dim)
690
- ])
691
- self.pool = nn.AdaptiveAvgPool2d((1, 1))
692
- self.focal_act = nn.Softmax(dim=-1)
693
- self.camera_embedding = TimestepEmbedding(projection_camera_embeddings_input_dim, time_embed_dim=time_embed_dim)
694
- '''
695
-
696
- # count how many layers upsample the images
697
- self.num_upsamplers = 0
698
-
699
- # up
700
- reversed_block_out_channels = list(reversed(block_out_channels))
701
- reversed_num_attention_heads = list(reversed(num_attention_heads))
702
- reversed_layers_per_block = list(reversed(layers_per_block))
703
- reversed_cross_attention_dim = list(reversed(cross_attention_dim))
704
- reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
705
- only_cross_attention = list(reversed(only_cross_attention))
706
-
707
- output_channel = reversed_block_out_channels[0]
708
- for i, up_block_type in enumerate(up_block_types):
709
- is_final_block = i == len(block_out_channels) - 1
710
-
711
- prev_output_channel = output_channel
712
- output_channel = reversed_block_out_channels[i]
713
- input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
714
-
715
- # add upsample block for all BUT final layer
716
- if not is_final_block:
717
- add_upsample = True
718
- self.num_upsamplers += 1
719
- else:
720
- add_upsample = False
721
-
722
- up_block = get_up_block(
723
- up_block_type,
724
- num_layers=reversed_layers_per_block[i] + 1,
725
- transformer_layers_per_block=reversed_transformer_layers_per_block[i],
726
- in_channels=input_channel,
727
- out_channels=output_channel,
728
- prev_output_channel=prev_output_channel,
729
- temb_channels=blocks_time_embed_dim,
730
- add_upsample=add_upsample,
731
- resnet_eps=norm_eps,
732
- resnet_act_fn=act_fn,
733
- resnet_groups=norm_num_groups,
734
- cross_attention_dim=reversed_cross_attention_dim[i],
735
- num_attention_heads=reversed_num_attention_heads[i],
736
- dual_cross_attention=dual_cross_attention,
737
- use_linear_projection=use_linear_projection,
738
- only_cross_attention=only_cross_attention[i],
739
- upcast_attention=upcast_attention,
740
- resnet_time_scale_shift=resnet_time_scale_shift,
741
- resnet_skip_time_act=resnet_skip_time_act,
742
- resnet_out_scale_factor=resnet_out_scale_factor,
743
- cross_attention_norm=cross_attention_norm,
744
- attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
745
- num_views=num_views,
746
- cd_attention_last=cd_attention_last,
747
- cd_attention_mid=cd_attention_mid,
748
- multiview_attention=multiview_attention,
749
- sparse_mv_attention=sparse_mv_attention,
750
- selfattn_block=selfattn_block,
751
- mvcd_attention=mvcd_attention,
752
- use_dino=use_dino
753
- )
754
- self.up_blocks.append(up_block)
755
- prev_output_channel = output_channel
756
-
757
- # out
758
- if norm_num_groups is not None:
759
- self.conv_norm_out = nn.GroupNorm(
760
- num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
761
- )
762
-
763
- self.conv_act = get_activation(act_fn)
764
-
765
- else:
766
- self.conv_norm_out = None
767
- self.conv_act = None
768
-
769
- conv_out_padding = (conv_out_kernel - 1) // 2
770
- self.conv_out = nn.Conv2d(
771
- block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
772
- )
773
-
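A hedged construction sketch for the model assembled by the __init__ above; the keyword values are illustrative and not a configuration shipped with this repo:

from mvdiffusion.models.unet_mv2d_condition import UNetMV2DConditionModel  # removed in this commit

# With regression heads enabled, projection_camera_embeddings_input_dim must
# equal 2 * (number of regressed camera parameters): sin and cos of each.
unet = UNetMV2DConditionModel(
    sample_size=32,
    in_channels=8,
    out_channels=4,
    cross_attention_dim=1024,
    num_views=6,
    regress_elevation=True,
    projection_camera_embeddings_input_dim=2,
)
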
774
- @property
775
- def attn_processors(self) -> Dict[str, AttentionProcessor]:
776
- r"""
777
- Returns:
778
- `dict` of attention processors: A dictionary containing all attention processors used in the model with
779
- indexed by its weight name.
780
- """
781
- # set recursively
782
- processors = {}
783
-
784
- def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
785
- if hasattr(module, "set_processor"):
786
- processors[f"{name}.processor"] = module.processor
787
-
788
- for sub_name, child in module.named_children():
789
- fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
790
-
791
- return processors
792
-
793
- for name, module in self.named_children():
794
- fn_recursive_add_processors(name, module, processors)
795
-
796
- return processors
797
-
798
- def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
799
- r"""
800
- Sets the attention processor to use to compute attention.
801
-
802
- Parameters:
803
- processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
804
- The instantiated processor class or a dictionary of processor classes that will be set as the processor
805
- for **all** `Attention` layers.
806
-
807
- If `processor` is a dict, the key needs to define the path to the corresponding cross attention
808
- processor. This is strongly recommended when setting trainable attention processors.
809
-
810
- """
811
- count = len(self.attn_processors.keys())
812
-
813
- if isinstance(processor, dict) and len(processor) != count:
814
- raise ValueError(
815
- f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
816
- f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
817
- )
818
-
819
- def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
820
- if hasattr(module, "set_processor"):
821
- if not isinstance(processor, dict):
822
- module.set_processor(processor)
823
- else:
824
- module.set_processor(processor.pop(f"{name}.processor"))
825
-
826
- for sub_name, child in module.named_children():
827
- fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
828
-
829
- for name, module in self.named_children():
830
- fn_recursive_attn_processor(name, module, processor)
831
-
832
- def set_default_attn_processor(self):
833
- """
834
- Disables custom attention processors and sets the default attention implementation.
835
- """
836
- self.set_attn_processor(AttnProcessor())
837
-
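Typical usage of the processor API above, written against the diffusers attention-processor classes (a sketch; `unet` is the instance from the construction example earlier):

from diffusers.models.attention_processor import AttnProcessor, AttnProcessor2_0

# one processor instance shared by every attention layer
unet.set_attn_processor(AttnProcessor2_0())

# or a per-layer dict; keys must match unet.attn_processors exactly
unet.set_attn_processor({name: AttnProcessor() for name in unet.attn_processors})

# restore the default implementation
unet.set_default_attn_processor()
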
838
- def set_attention_slice(self, slice_size):
839
- r"""
840
- Enable sliced attention computation.
841
-
842
- When this option is enabled, the attention module splits the input tensor in slices to compute attention in
843
- several steps. This is useful for saving some memory in exchange for a small decrease in speed.
844
-
845
- Args:
846
- slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
847
- When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
848
- `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
849
- provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
850
- must be a multiple of `slice_size`.
851
- """
852
- sliceable_head_dims = []
853
-
854
- def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
855
- if hasattr(module, "set_attention_slice"):
856
- sliceable_head_dims.append(module.sliceable_head_dim)
857
-
858
- for child in module.children():
859
- fn_recursive_retrieve_sliceable_dims(child)
860
-
861
- # retrieve number of attention layers
862
- for module in self.children():
863
- fn_recursive_retrieve_sliceable_dims(module)
864
-
865
- num_sliceable_layers = len(sliceable_head_dims)
866
-
867
- if slice_size == "auto":
868
- # half the attention head size is usually a good trade-off between
869
- # speed and memory
870
- slice_size = [dim // 2 for dim in sliceable_head_dims]
871
- elif slice_size == "max":
872
- # make smallest slice possible
873
- slice_size = num_sliceable_layers * [1]
874
-
875
- slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
876
-
877
- if len(slice_size) != len(sliceable_head_dims):
878
- raise ValueError(
879
- f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
880
- f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
881
- )
882
-
883
- for i in range(len(slice_size)):
884
- size = slice_size[i]
885
- dim = sliceable_head_dims[i]
886
- if size is not None and size > dim:
887
- raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
888
-
889
- # Recursively walk through all the children.
890
- # Any children which exposes the set_attention_slice method
891
- # gets the message
892
- def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
893
- if hasattr(module, "set_attention_slice"):
894
- module.set_attention_slice(slice_size.pop())
895
-
896
- for child in module.children():
897
- fn_recursive_set_attention_slice(child, slice_size)
898
-
899
- reversed_slice_size = list(reversed(slice_size))
900
- for module in self.children():
901
- fn_recursive_set_attention_slice(module, reversed_slice_size)
902
-
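Sliced attention is then enabled with a single call; "auto" halves each head dimension, "max" runs one slice at a time, and an integer sets an explicit slice size (illustrative):

unet.set_attention_slice("auto")   # modest memory savings, small slowdown
# unet.set_attention_slice("max")  # most aggressive: one slice at a time
# unet.set_attention_slice(4)      # explicit slice size per sliceable layer
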
903
- def _set_gradient_checkpointing(self, module, value=False):
904
- if isinstance(module, (CrossAttnDownBlock2D, CrossAttnDownBlockMV2D, DownBlock2D, CrossAttnUpBlock2D, CrossAttnUpBlockMV2D, UpBlock2D)):
905
- module.gradient_checkpointing = value
906
-
907
- def forward(
908
- self,
909
- sample: torch.FloatTensor,
910
- timestep: Union[torch.Tensor, float, int],
911
- encoder_hidden_states: torch.Tensor,
912
- class_labels: Optional[torch.Tensor] = None,
913
- timestep_cond: Optional[torch.Tensor] = None,
914
- attention_mask: Optional[torch.Tensor] = None,
915
- cross_attention_kwargs: Optional[Dict[str, Any]] = None,
916
- added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
917
- down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
918
- mid_block_additional_residual: Optional[torch.Tensor] = None,
919
- encoder_attention_mask: Optional[torch.Tensor] = None,
920
- dino_feature: Optional[torch.Tensor] = None,
921
- return_dict: bool = True,
922
- vis_max_min: bool = False,
923
- ) -> Union[UNetMV2DConditionOutput, Tuple]:
924
- r"""
925
- The [`UNet2DConditionModel`] forward method.
926
-
927
- Args:
928
- sample (`torch.FloatTensor`):
929
- The noisy input tensor with the following shape `(batch, channel, height, width)`.
930
- timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
931
- encoder_hidden_states (`torch.FloatTensor`):
932
- The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
933
- encoder_attention_mask (`torch.Tensor`):
934
- A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
935
- `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
936
- which adds large negative values to the attention scores corresponding to "discard" tokens.
937
- return_dict (`bool`, *optional*, defaults to `True`):
938
- Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
939
- tuple.
940
- cross_attention_kwargs (`dict`, *optional*):
941
- A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
942
- added_cond_kwargs: (`dict`, *optional*):
943
- A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
944
- are passed along to the UNet blocks.
945
-
946
- Returns:
947
- [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
948
- If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
949
- a `tuple` is returned where the first element is the sample tensor.
950
- """
951
- record_max_min = {}
952
- # By default samples have to be AT least a multiple of the overall upsampling factor.
953
- # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
954
- # However, the upsampling interpolation output size can be forced to fit any upsampling size
955
- # on the fly if necessary.
956
- default_overall_up_factor = 2**self.num_upsamplers
957
-
958
- # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
959
- forward_upsample_size = False
960
- upsample_size = None
961
-
962
- if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
963
- logger.info("Forward upsample size to force interpolation output size.")
964
- forward_upsample_size = True
965
-
966
- # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
967
- # expects mask of shape:
968
- # [batch, key_tokens]
969
- # adds singleton query_tokens dimension:
970
- # [batch, 1, key_tokens]
971
- # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
972
- # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
973
- # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
974
- if attention_mask is not None:
975
- # assume that mask is expressed as:
976
- # (1 = keep, 0 = discard)
977
- # convert mask into a bias that can be added to attention scores:
978
- # (keep = +0, discard = -10000.0)
979
- attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
980
- attention_mask = attention_mask.unsqueeze(1)
981
-
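# Worked example of the bias conversion above (values illustrative):
#   attention_mask      = [1, 1, 0]              # 1 = keep, 0 = discard
#   (1 - mask) * -10000 -> [0.0, 0.0, -10000.0]
#   unsqueeze(1)        -> shape (batch, 1, key_tokens), which broadcasts as an
#                          additive bias over the query dimension of the scores.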
982
- # convert encoder_attention_mask to a bias the same way we do for attention_mask
983
- if encoder_attention_mask is not None:
984
- encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
985
- encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
986
-
987
- # 0. center input if necessary
988
- if self.config.center_input_sample:
989
- sample = 2 * sample - 1.0
990
- # 1. time
991
- timesteps = timestep
992
- if not torch.is_tensor(timesteps):
993
- # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
994
- # This would be a good case for the `match` statement (Python 3.10+)
995
- is_mps = sample.device.type == "mps"
996
- if isinstance(timestep, float):
997
- dtype = torch.float32 if is_mps else torch.float64
998
- else:
999
- dtype = torch.int32 if is_mps else torch.int64
1000
- timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
1001
- elif len(timesteps.shape) == 0:
1002
- timesteps = timesteps[None].to(sample.device)
1003
-
1004
- # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
1005
- timesteps = timesteps.expand(sample.shape[0])
1006
-
1007
- t_emb = self.time_proj(timesteps)
1008
-
1009
- # `Timesteps` does not contain any weights and will always return f32 tensors
1010
- # but time_embedding might actually be running in fp16, so we need to cast here.
1011
- # there might be better ways to encapsulate this.
1012
- t_emb = t_emb.to(dtype=sample.dtype)
1013
-
1014
- emb = self.time_embedding(t_emb, timestep_cond)
1015
- aug_emb = None
1016
- if self.class_embedding is not None:
1017
- if class_labels is None:
1018
- raise ValueError("class_labels should be provided when num_class_embeds > 0")
1019
-
1020
- if self.config.class_embed_type == "timestep":
1021
- class_labels = self.time_proj(class_labels)
1022
-
1023
- # `Timesteps` does not contain any weights and will always return f32 tensors
1024
- # there might be better ways to encapsulate this.
1025
- class_labels = class_labels.to(dtype=sample.dtype)
1026
-
1027
- class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
1028
- if self.config.class_embeddings_concat:
1029
- emb = torch.cat([emb, class_emb], dim=-1)
1030
- else:
1031
- emb = emb + class_emb
1032
-
1033
- if self.config.addition_embed_type == "text":
1034
- aug_emb = self.add_embedding(encoder_hidden_states)
1035
- elif self.config.addition_embed_type == "text_image":
1036
- # Kandinsky 2.1 - style
1037
- if "image_embeds" not in added_cond_kwargs:
1038
- raise ValueError(
1039
- f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
1040
- )
1041
-
1042
- image_embs = added_cond_kwargs.get("image_embeds")
1043
- text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
1044
- aug_emb = self.add_embedding(text_embs, image_embs)
1045
- elif self.config.addition_embed_type == "text_time":
1046
- # SDXL - style
1047
- if "text_embeds" not in added_cond_kwargs:
1048
- raise ValueError(
1049
- f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
1050
- )
1051
- text_embeds = added_cond_kwargs.get("text_embeds")
1052
- if "time_ids" not in added_cond_kwargs:
1053
- raise ValueError(
1054
- f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
1055
- )
1056
- time_ids = added_cond_kwargs.get("time_ids")
1057
- time_embeds = self.add_time_proj(time_ids.flatten())
1058
- time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
1059
-
1060
- add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
1061
- add_embeds = add_embeds.to(emb.dtype)
1062
- aug_emb = self.add_embedding(add_embeds)
1063
- elif self.config.addition_embed_type == "image":
1064
- # Kandinsky 2.2 - style
1065
- if "image_embeds" not in added_cond_kwargs:
1066
- raise ValueError(
1067
- f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
1068
- )
1069
- image_embs = added_cond_kwargs.get("image_embeds")
1070
- aug_emb = self.add_embedding(image_embs)
1071
- elif self.config.addition_embed_type == "image_hint":
1072
- # Kandinsky 2.2 - style
1073
- if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
1074
- raise ValueError(
1075
- f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
1076
- )
1077
- image_embs = added_cond_kwargs.get("image_embeds")
1078
- hint = added_cond_kwargs.get("hint")
1079
- aug_emb, hint = self.add_embedding(image_embs, hint)
1080
- sample = torch.cat([sample, hint], dim=1)
1081
-
1082
- emb = emb + aug_emb if aug_emb is not None else emb
1083
- emb_pre_act = emb
1084
- if self.time_embed_act is not None:
1085
- emb = self.time_embed_act(emb)
1086
-
1087
- if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
1088
- encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
1089
- elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
1090
- # Kandinsky 2.1 - style
1091
- if "image_embeds" not in added_cond_kwargs:
1092
- raise ValueError(
1093
- f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
1094
- )
1095
-
1096
- image_embeds = added_cond_kwargs.get("image_embeds")
1097
- encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
1098
- elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
1099
- # Kandinsky 2.2 - style
1100
- if "image_embeds" not in added_cond_kwargs:
1101
- raise ValueError(
1102
- f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
1103
- )
1104
- image_embeds = added_cond_kwargs.get("image_embeds")
1105
- encoder_hidden_states = self.encoder_hid_proj(image_embeds)
1106
- # 2. pre-process
1107
- sample = self.conv_in(sample)
1108
- # 3. down
1109
-
1110
- is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
1111
- is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None
1112
-
1113
- down_block_res_samples = (sample,)
1114
- for i, downsample_block in enumerate(self.down_blocks):
1115
- if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
1116
- # For t2i-adapter CrossAttnDownBlock2D
1117
- additional_residuals = {}
1118
- if is_adapter and len(down_block_additional_residuals) > 0:
1119
- additional_residuals["additional_residuals"] = down_block_additional_residuals.pop(0)
1120
-
1121
- sample, res_samples = downsample_block(
1122
- hidden_states=sample,
1123
- temb=emb,
1124
- encoder_hidden_states=encoder_hidden_states,
1125
- dino_feature=dino_feature,
1126
- attention_mask=attention_mask,
1127
- cross_attention_kwargs=cross_attention_kwargs,
1128
- encoder_attention_mask=encoder_attention_mask,
1129
- **additional_residuals,
1130
- )
1131
- else:
1132
- sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
1133
-
1134
- if is_adapter and len(down_block_additional_residuals) > 0:
1135
- sample += down_block_additional_residuals.pop(0)
1136
-
1137
- down_block_res_samples += res_samples
1138
-
1139
- if is_controlnet:
1140
- new_down_block_res_samples = ()
1141
-
1142
- for down_block_res_sample, down_block_additional_residual in zip(
1143
- down_block_res_samples, down_block_additional_residuals
1144
- ):
1145
- down_block_res_sample = down_block_res_sample + down_block_additional_residual
1146
- new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
1147
-
1148
- down_block_res_samples = new_down_block_res_samples
1149
-
1150
- if self.addition_downsample:
1151
- global_sample = sample
1152
- global_sample = self.downsample(global_sample)
1153
- for layer in self.conv_block:
1154
- global_sample = layer(global_sample)
1155
- global_sample = self.addition_act_out(self.addition_conv_out(global_sample))
1156
- global_sample = self.upsample(global_sample)
1157
- # 4. mid
1158
- if self.mid_block is not None:
1159
- sample = self.mid_block(
1160
- sample,
1161
- emb,
1162
- encoder_hidden_states=encoder_hidden_states,
1163
- dino_feature=dino_feature,
1164
- attention_mask=attention_mask,
1165
- cross_attention_kwargs=cross_attention_kwargs,
1166
- encoder_attention_mask=encoder_attention_mask,
1167
- )
1168
- # 4.1 regress elevation and focal length
1169
- # # predict elevation -> embed -> projection -> add to time emb
1170
- if self.regress_elevation or self.regress_focal_length:
1171
- pool_embeds = self.pool(sample.detach()).squeeze(-1).squeeze(-1) # (2B, C)
1172
- if self.mvcd_attention:
1173
- pool_embeds_normal, pool_embeds_color = torch.chunk(pool_embeds, 2, dim=0)
1174
- pool_embeds = torch.cat([pool_embeds_normal, pool_embeds_color], dim=-1) # (B, 2C)
1175
- pose_pred = []
1176
- if self.regress_elevation:
1177
- ele_pred = self.elevation_regressor(pool_embeds)
1178
- ele_pred = rearrange(ele_pred, '(b v) c -> b v c', v=self.num_views)
1179
- ele_pred = torch.mean(ele_pred, dim=1)
1180
- pose_pred.append(ele_pred) # b, c
1181
-
1182
- if self.regress_focal_length:
1183
- focal_pred = self.focal_regressor(pool_embeds)
1184
- focal_pred = rearrange(focal_pred, '(b v) c -> b v c', v=self.num_views)
1185
- focal_pred = torch.mean(focal_pred, dim=1)
1186
- pose_pred.append(focal_pred)
1187
- pose_pred = torch.cat(pose_pred, dim=-1)
1188
- # 'e_de_da_sincos', (B, 2)
1189
- pose_embeds = torch.cat([
1190
- torch.sin(pose_pred),
1191
- torch.cos(pose_pred)
1192
- ], dim=-1)
1193
- pose_embeds = self.camera_embedding(pose_embeds)
1194
- pose_embeds = torch.repeat_interleave(pose_embeds, self.num_views, 0)
1195
- if self.mvcd_attention:
1196
- pose_embeds = torch.cat([pose_embeds,] * 2, dim=0)
1197
-
1198
- emb = pose_embeds + emb_pre_act
1199
- if self.time_embed_act is not None:
1200
- emb = self.time_embed_act(emb)
1201
-
1202
- if is_controlnet:
1203
- sample = sample + mid_block_additional_residual
1204
-
1205
- if self.addition_downsample:
1206
- sample = sample + global_sample
1207
-
1208
- # 5. up
1209
- for i, upsample_block in enumerate(self.up_blocks):
1210
- is_final_block = i == len(self.up_blocks) - 1
1211
-
1212
- res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
1213
- down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
1214
-
1215
- # if we have not reached the final block and need to forward the
1216
- # upsample size, we do it here
1217
- if not is_final_block and forward_upsample_size:
1218
- upsample_size = down_block_res_samples[-1].shape[2:]
1219
-
1220
- if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
1221
- sample = upsample_block(
1222
- hidden_states=sample,
1223
- temb=emb,
1224
- res_hidden_states_tuple=res_samples,
1225
- encoder_hidden_states=encoder_hidden_states,
1226
- dino_feature=dino_feature,
1227
- cross_attention_kwargs=cross_attention_kwargs,
1228
- upsample_size=upsample_size,
1229
- attention_mask=attention_mask,
1230
- encoder_attention_mask=encoder_attention_mask,
1231
- )
1232
- else:
1233
- sample = upsample_block(
1234
- hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
1235
- )
1236
- if torch.isnan(sample).any() or torch.isinf(sample).any():
1237
- print("NAN in sample, stop training.")
1238
- exit()
1239
- # 6. post-process
1240
- if self.conv_norm_out:
1241
- sample = self.conv_norm_out(sample)
1242
- sample = self.conv_act(sample)
1243
- sample = self.conv_out(sample)
1244
- if not return_dict:
1245
- return (sample, pose_pred)
1246
- if self.regress_elevation or self.regress_focal_length:
1247
- return UNetMV2DConditionOutput(sample=sample), pose_pred
1248
- else:
1249
- return UNetMV2DConditionOutput(sample=sample)
1250
-
1251
-
1252
- @classmethod
1253
- def from_pretrained_2d(
1254
- cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
1255
- camera_embedding_type: str, num_views: int, sample_size: int,
1256
- zero_init_conv_in: bool = True, zero_init_camera_projection: bool = False,
1257
- projection_camera_embeddings_input_dim: int=2,
1258
- cd_attention_last: bool = False, num_regress_blocks: int = 4,
1259
- cd_attention_mid: bool = False, multiview_attention: bool = True,
1260
- sparse_mv_attention: bool = False, selfattn_block: str = 'custom', mvcd_attention: bool = False,
1261
- in_channels: int = 8, out_channels: int = 4, unclip: bool = False, regress_elevation: bool = False, regress_focal_length: bool = False,
1262
- init_mvattn_with_selfattn: bool= False, use_dino: bool = False, addition_downsample: bool = False,
1263
- **kwargs
1264
- ):
1265
- r"""
1266
- Instantiate a pretrained PyTorch model from a pretrained model configuration.
1267
-
1268
- The model is set in evaluation mode - `model.eval()` - by default, and dropout modules are deactivated. To
1269
- train the model, set it back in training mode with `model.train()`.
1270
-
1271
- Parameters:
1272
- pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
1273
- Can be either:
1274
-
1275
- - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
1276
- the Hub.
1277
- - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
1278
- with [`~ModelMixin.save_pretrained`].
1279
-
1280
- cache_dir (`Union[str, os.PathLike]`, *optional*):
1281
- Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
1282
- is not used.
1283
- torch_dtype (`str` or `torch.dtype`, *optional*):
1284
- Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
1285
- dtype is automatically derived from the model's weights.
1286
- force_download (`bool`, *optional*, defaults to `False`):
1287
- Whether or not to force the (re-)download of the model weights and configuration files, overriding the
1288
- cached versions if they exist.
1289
- resume_download (`bool`, *optional*, defaults to `False`):
1290
- Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
1291
- incompletely downloaded files are deleted.
1292
- proxies (`Dict[str, str]`, *optional*):
1293
- A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
1294
- 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
1295
- output_loading_info (`bool`, *optional*, defaults to `False`):
1296
- Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
1297
- local_files_only(`bool`, *optional*, defaults to `False`):
1298
- Whether to only load local model weights and configuration files or not. If set to `True`, the model
1299
- won't be downloaded from the Hub.
1300
- use_auth_token (`str` or *bool*, *optional*):
1301
- The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
1302
- `diffusers-cli login` (stored in `~/.huggingface`) is used.
1303
- revision (`str`, *optional*, defaults to `"main"`):
1304
- The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
1305
- allowed by Git.
1306
- from_flax (`bool`, *optional*, defaults to `False`):
1307
- Load the model weights from a Flax checkpoint save file.
1308
- subfolder (`str`, *optional*, defaults to `""`):
1309
- The subfolder location of a model file within a larger model repository on the Hub or locally.
1310
- mirror (`str`, *optional*):
1311
- Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
1312
- guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
1313
- information.
1314
- device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
1315
- A map that specifies where each submodule should go. It doesn't need to be defined for each
1316
- parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the
1317
- same device.
1318
-
1319
- Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For
1320
- more information about each option see [designing a device
1321
- map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
1322
- max_memory (`Dict`, *optional*):
1323
- A dictionary mapping each device identifier to its maximum memory. Will default to the maximum memory available for
1324
- each GPU and the available CPU RAM if unset.
1325
- offload_folder (`str` or `os.PathLike`, *optional*):
1326
- The path to offload weights if `device_map` contains the value `"disk"`.
1327
- offload_state_dict (`bool`, *optional*):
1328
- If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU RAM if
1329
- the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to `True`
1330
- when there is some disk offload.
1331
- low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
1332
- Speed up model loading by only loading the pretrained weights and not initializing the weights. This also
1333
- tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
1334
- Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
1335
- argument to `True` will raise an error.
1336
- variant (`str`, *optional*):
1337
- Load weights from a specified `variant` filename such as `"fp16"` or `"ema"`. This is ignored when
1338
- loading `from_flax`.
1339
- use_safetensors (`bool`, *optional*, defaults to `None`):
1340
- If set to `None`, the `safetensors` weights are downloaded if they're available **and** if the
1341
- `safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors`
1342
- weights. If set to `False`, `safetensors` weights are not loaded.
1343
-
1344
- <Tip>
1345
-
1346
- To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with
1347
- `huggingface-cli login`. You can also activate the special
1348
- ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a
1349
- firewalled environment.
1350
-
1351
- </Tip>
1352
-
1353
- Example:
1354
-
1355
- ```py
1356
- from diffusers import UNet2DConditionModel
1357
-
1358
- unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
1359
- ```
1360
-
1361
- If you get the error message below, you need to finetune the weights for your downstream task:
1362
-
1363
- ```bash
1364
- Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:
1365
- - conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated
1366
- You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
1367
- ```
1368
- """
1369
- cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
1370
- ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
1371
- force_download = kwargs.pop("force_download", False)
1372
- from_flax = kwargs.pop("from_flax", False)
1373
- resume_download = kwargs.pop("resume_download", False)
1374
- proxies = kwargs.pop("proxies", None)
1375
- output_loading_info = kwargs.pop("output_loading_info", False)
1376
- local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
1377
- use_auth_token = kwargs.pop("use_auth_token", None)
1378
- revision = kwargs.pop("revision", None)
1379
- torch_dtype = kwargs.pop("torch_dtype", None)
1380
- subfolder = kwargs.pop("subfolder", None)
1381
- device_map = kwargs.pop("device_map", None)
1382
- max_memory = kwargs.pop("max_memory", None)
1383
- offload_folder = kwargs.pop("offload_folder", None)
1384
- offload_state_dict = kwargs.pop("offload_state_dict", False)
1385
- variant = kwargs.pop("variant", None)
1386
- use_safetensors = kwargs.pop("use_safetensors", None)
1387
-
1388
- if use_safetensors:
1389
- raise ValueError(
1390
- "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors"
1391
- )
1392
-
1393
- allow_pickle = False
1394
- if use_safetensors is None:
1395
- use_safetensors = True
1396
- allow_pickle = True
1397
-
1398
- if device_map is not None and not is_accelerate_available():
1399
- raise NotImplementedError(
1400
- "Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set"
1401
- " `device_map=None`. You can install accelerate with `pip install accelerate`."
1402
- )
1403
-
1404
- # Check if we can handle device_map and dispatching the weights
1405
- if device_map is not None and not is_torch_version(">=", "1.9.0"):
1406
- raise NotImplementedError(
1407
- "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
1408
- " `device_map=None`."
1409
- )
1410
-
1411
- # Load config if we don't provide a configuration
1412
- config_path = pretrained_model_name_or_path
1413
-
1414
- user_agent = {
1415
- "diffusers": __version__,
1416
- "file_type": "model",
1417
- "framework": "pytorch",
1418
- }
1419
-
1420
- # load config
1421
- config, unused_kwargs, commit_hash = cls.load_config(
1422
- config_path,
1423
- cache_dir=cache_dir,
1424
- return_unused_kwargs=True,
1425
- return_commit_hash=True,
1426
- force_download=force_download,
1427
- resume_download=resume_download,
1428
- proxies=proxies,
1429
- local_files_only=local_files_only,
1430
- use_auth_token=use_auth_token,
1431
- revision=revision,
1432
- subfolder=subfolder,
1433
- device_map=device_map,
1434
- max_memory=max_memory,
1435
- offload_folder=offload_folder,
1436
- offload_state_dict=offload_state_dict,
1437
- user_agent=user_agent,
1438
- **kwargs,
1439
- )
1440
-
1441
- # modify config
1442
- config["_class_name"] = cls.__name__
1443
- config['in_channels'] = in_channels
1444
- config['out_channels'] = out_channels
1445
- config['sample_size'] = sample_size # training resolution
1446
- config['num_views'] = num_views
1447
- config['cd_attention_last'] = cd_attention_last
1448
- config['cd_attention_mid'] = cd_attention_mid
1449
- config['multiview_attention'] = multiview_attention
1450
- config['sparse_mv_attention'] = sparse_mv_attention
1451
- config['selfattn_block'] = selfattn_block
1452
- config['mvcd_attention'] = mvcd_attention
1453
- config["down_block_types"] = [
1454
- "CrossAttnDownBlockMV2D",
1455
- "CrossAttnDownBlockMV2D",
1456
- "CrossAttnDownBlockMV2D",
1457
- "DownBlock2D"
1458
- ]
1459
- config['mid_block_type'] = "UNetMidBlockMV2DCrossAttn"
1460
- config["up_block_types"] = [
1461
- "UpBlock2D",
1462
- "CrossAttnUpBlockMV2D",
1463
- "CrossAttnUpBlockMV2D",
1464
- "CrossAttnUpBlockMV2D"
1465
- ]
1466
-
1467
-
1468
- config['regress_elevation'] = regress_elevation # true
1469
- config['regress_focal_length'] = regress_focal_length # true
1470
- config['projection_camera_embeddings_input_dim'] = projection_camera_embeddings_input_dim # 2 for elevation and 10 for focal_length
1471
- config['use_dino'] = use_dino
1472
- config['num_regress_blocks'] = num_regress_blocks
1473
- config['addition_downsample'] = addition_downsample
1474
- # load model
1475
- model_file = None
1476
- if from_flax:
1477
- raise NotImplementedError
1478
- else:
1479
- if use_safetensors:
1480
- try:
1481
- model_file = _get_model_file(
1482
- pretrained_model_name_or_path,
1483
- weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
1484
- cache_dir=cache_dir,
1485
- force_download=force_download,
1486
- resume_download=resume_download,
1487
- proxies=proxies,
1488
- local_files_only=local_files_only,
1489
- use_auth_token=use_auth_token,
1490
- revision=revision,
1491
- subfolder=subfolder,
1492
- user_agent=user_agent,
1493
- commit_hash=commit_hash,
1494
- )
1495
- except IOError as e:
1496
- if not allow_pickle:
1497
- raise e
1498
- pass
1499
- if model_file is None:
1500
- model_file = _get_model_file(
1501
- pretrained_model_name_or_path,
1502
- weights_name=_add_variant(WEIGHTS_NAME, variant),
1503
- cache_dir=cache_dir,
1504
- force_download=force_download,
1505
- resume_download=resume_download,
1506
- proxies=proxies,
1507
- local_files_only=local_files_only,
1508
- use_auth_token=use_auth_token,
1509
- revision=revision,
1510
- subfolder=subfolder,
1511
- user_agent=user_agent,
1512
- commit_hash=commit_hash,
1513
- )
1514
-
1515
- model = cls.from_config(config, **unused_kwargs)
1516
- import copy
1517
- state_dict_pretrain = load_state_dict(model_file, variant=variant)
1518
- state_dict = copy.deepcopy(state_dict_pretrain)
1519
-
1520
- if init_mvattn_with_selfattn:
1521
- for key in state_dict_pretrain:
1522
- if 'attn1' in key:
1523
- key_mv = key.replace('attn1', 'attn_mv')
1524
- state_dict[key_mv] = state_dict_pretrain[key]
1525
- if 'to_out.0.weight' in key:
1526
- nn.init.zeros_(state_dict[key_mv].data)
1527
- if 'transformer_blocks' in key and 'norm1' in key: # restrict to transformer blocks so the norm layers inside resnet blocks are not copied
1528
- key_mv = key.replace('norm1', 'norm_mv')
1529
- state_dict[key_mv] = state_dict_pretrain[key]
1530
- # del state_dict_pretrain
1531
-
1532
- model._convert_deprecated_attention_blocks(state_dict)
1533
-
1534
- conv_in_weight = state_dict['conv_in.weight']
1535
- conv_out_weight = state_dict['conv_out.weight']
1536
- model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model_2d(
1537
- model,
1538
- state_dict,
1539
- model_file,
1540
- pretrained_model_name_or_path,
1541
- ignore_mismatched_sizes=True,
1542
- )
1543
- if any([key == 'conv_in.weight' for key, _, _ in mismatched_keys]):
1544
- # initialize from the original SD structure
1545
- model.conv_in.weight.data[:,:4] = conv_in_weight
1546
-
1547
- # optionally zero-initialize the weights for the newly added input channels
1548
- if zero_init_conv_in:
1549
- model.conv_in.weight.data[:,4:] = 0.
1550
-
1551
- if any([key == 'conv_out.weight' for key, _, _ in mismatched_keys]):
1552
- # initialize from the original SD structure
1553
- model.conv_out.weight.data[:,:4] = conv_out_weight
1554
- if out_channels == 8: # copy for the last 4 channels
1555
- model.conv_out.weight.data[:, 4:] = conv_out_weight
1556
-
1557
- if zero_init_camera_projection: # true
1558
- params = [p for p in model.camera_embedding.parameters()]
1559
- torch.nn.init.zeros_(params[-1].data)
1560
-
1561
- loading_info = {
1562
- "missing_keys": missing_keys,
1563
- "unexpected_keys": unexpected_keys,
1564
- "mismatched_keys": mismatched_keys,
1565
- "error_msgs": error_msgs,
1566
- }
1567
-
1568
- if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
1569
- raise ValueError(
1570
- f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
1571
- )
1572
- elif torch_dtype is not None:
1573
- model = model.to(torch_dtype)
1574
-
1575
- model.register_to_config(_name_or_path=pretrained_model_name_or_path)
1576
-
1577
- # Set model in evaluation mode to deactivate DropOut modules by default
1578
- model.eval()
1579
- if output_loading_info:
1580
- return model, loading_info
1581
- return model
1582
-
1583
- @classmethod
1584
- def _load_pretrained_model_2d(
1585
- cls,
1586
- model,
1587
- state_dict,
1588
- resolved_archive_file,
1589
- pretrained_model_name_or_path,
1590
- ignore_mismatched_sizes=False,
1591
- ):
1592
- # Retrieve missing & unexpected_keys
1593
- model_state_dict = model.state_dict()
1594
- loaded_keys = list(state_dict.keys())
1595
-
1596
- expected_keys = list(model_state_dict.keys())
1597
-
1598
- original_loaded_keys = loaded_keys
1599
-
1600
- missing_keys = list(set(expected_keys) - set(loaded_keys))
1601
- unexpected_keys = list(set(loaded_keys) - set(expected_keys))
1602
-
1603
- # Make sure we are able to load base models as well as derived models (with heads)
1604
- model_to_load = model
1605
-
1606
- def _find_mismatched_keys(
1607
- state_dict,
1608
- model_state_dict,
1609
- loaded_keys,
1610
- ignore_mismatched_sizes,
1611
- ):
1612
- mismatched_keys = []
1613
- if ignore_mismatched_sizes:
1614
- for checkpoint_key in loaded_keys:
1615
- model_key = checkpoint_key
1616
-
1617
- if (
1618
- model_key in model_state_dict
1619
- and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
1620
- ):
1621
- mismatched_keys.append(
1622
- (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
1623
- )
1624
- del state_dict[checkpoint_key]
1625
- return mismatched_keys
1626
-
1627
- if state_dict is not None:
1628
- # Whole checkpoint
1629
- mismatched_keys = _find_mismatched_keys(
1630
- state_dict,
1631
- model_state_dict,
1632
- original_loaded_keys,
1633
- ignore_mismatched_sizes,
1634
- )
1635
- error_msgs = _load_state_dict_into_model(model_to_load, state_dict)
1636
-
1637
- if len(error_msgs) > 0:
1638
- error_msg = "\n\t".join(error_msgs)
1639
- if "size mismatch" in error_msg:
1640
- error_msg += (
1641
- "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method."
1642
- )
1643
- raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
1644
-
1645
- if len(unexpected_keys) > 0:
1646
- logger.warning(
1647
- f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
1648
- f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
1649
- f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task"
1650
- " or with another architecture (e.g. initializing a BertForSequenceClassification model from a"
1651
- " BertForPreTraining model).\n- This IS NOT expected if you are initializing"
1652
- f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly"
1653
- " identical (initializing a BertForSequenceClassification model from a"
1654
- " BertForSequenceClassification model)."
1655
- )
1656
- else:
1657
- logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
1658
- if len(missing_keys) > 0:
1659
- logger.warning(
1660
- f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
1661
- f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
1662
- " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
1663
- )
1664
- elif len(mismatched_keys) == 0:
1665
- logger.info(
1666
- f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
1667
- f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the"
1668
- f" checkpoint was trained on, you can already use {model.__class__.__name__} for predictions"
1669
- " without further training."
1670
- )
1671
- if len(mismatched_keys) > 0:
1672
- mismatched_warning = "\n".join(
1673
- [
1674
- f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
1675
- for key, shape1, shape2 in mismatched_keys
1676
- ]
1677
- )
1678
- logger.warning(
1679
- f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
1680
- f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
1681
- f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be"
1682
- " able to use it for predictions and inference."
1683
- )
1684
-
1685
- return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs
1686
-
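For orientation, here is a minimal, hypothetical usage sketch of the `from_pretrained_2d` loader deleted above. It assumes the enclosing class is the multiview UNet (called `UNetMV2DConditionModel` below; neither the class name nor the checkpoint id appears in this excerpt, so both are illustrative), and it only uses keyword arguments that occur in the signature shown above.

```python
# Hypothetical sketch, not part of the repository: adapt a standard Stable Diffusion
# UNet checkpoint into the multiview UNet whose loader is defined above.
# `UNetMV2DConditionModel` and the checkpoint id are assumptions for illustration.
unet = UNetMV2DConditionModel.from_pretrained_2d(
    "stabilityai/stable-diffusion-2-1-unclip",   # any SD-style repo with a "unet" subfolder
    subfolder="unet",                            # forwarded through **kwargs to load_config / _get_model_file
    camera_embedding_type="e_de_da_sincos",      # matches the 'e_de_da_sincos' layout noted in the forward pass
    num_views=6,                                 # illustrative values
    sample_size=32,
    regress_elevation=True,
    regress_focal_length=True,
)
```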
 
mvdiffusion/pipelines/pipeline_mvdiffusion_unclip.py DELETED
@@ -1,633 +0,0 @@
1
- import inspect
2
- import warnings
3
- from typing import Callable, List, Optional, Union, Dict, Any
4
- import PIL
5
- import torch
6
- from packaging import version
7
- from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, CLIPFeatureExtractor, CLIPTokenizer, CLIPTextModel
8
- from diffusers.utils.import_utils import is_accelerate_available
9
- from diffusers.configuration_utils import FrozenDict
10
- from diffusers.image_processor import VaeImageProcessor
11
- from diffusers.models import AutoencoderKL, UNet2DConditionModel
12
- from diffusers.models.embeddings import get_timestep_embedding
13
- from diffusers.schedulers import KarrasDiffusionSchedulers
14
- from diffusers.utils import deprecate, logging
15
- from diffusers.utils.torch_utils import randn_tensor
16
- from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
17
- from diffusers.pipelines.stable_diffusion.stable_unclip_image_normalizer import StableUnCLIPImageNormalizer
18
- import os
19
- import torchvision.transforms.functional as TF
20
- from einops import rearrange
21
- logger = logging.get_logger(__name__)
22
-
23
- class StableUnCLIPImg2ImgPipeline(DiffusionPipeline):
24
- """
25
- Pipeline for text-guided image to image generation using stable unCLIP.
26
-
27
- This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
28
- library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
29
-
30
- Args:
31
- feature_extractor ([`CLIPFeatureExtractor`]):
32
- Feature extractor for image pre-processing before being encoded.
33
- image_encoder ([`CLIPVisionModelWithProjection`]):
34
- CLIP vision model for encoding images.
35
- image_normalizer ([`StableUnCLIPImageNormalizer`]):
36
- Used to normalize the predicted image embeddings before the noise is applied and un-normalize the image
37
- embeddings after the noise has been applied.
38
- image_noising_scheduler ([`KarrasDiffusionSchedulers`]):
39
- Noise schedule for adding noise to the predicted image embeddings. The amount of noise to add is determined
40
- by `noise_level` in `StableUnCLIPPipeline.__call__`.
41
- tokenizer (`CLIPTokenizer`):
42
- Tokenizer of class
43
- [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
44
- text_encoder ([`CLIPTextModel`]):
45
- Frozen text-encoder.
46
- unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
47
- scheduler ([`KarrasDiffusionSchedulers`]):
48
- A scheduler to be used in combination with `unet` to denoise the encoded image latents.
49
- vae ([`AutoencoderKL`]):
50
- Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
51
- """
52
- # image encoding components
53
- feature_extractor: CLIPFeatureExtractor
54
- image_encoder: CLIPVisionModelWithProjection
55
- # image noising components
56
- image_normalizer: StableUnCLIPImageNormalizer
57
- image_noising_scheduler: KarrasDiffusionSchedulers
58
- # regular denoising components
59
- tokenizer: CLIPTokenizer
60
- text_encoder: CLIPTextModel
61
- unet: UNet2DConditionModel
62
- scheduler: KarrasDiffusionSchedulers
63
- vae: AutoencoderKL
64
-
65
- def __init__(
66
- self,
67
- # image encoding components
68
- feature_extractor: CLIPFeatureExtractor,
69
- image_encoder: CLIPVisionModelWithProjection,
70
- # image noising components
71
- image_normalizer: StableUnCLIPImageNormalizer,
72
- image_noising_scheduler: KarrasDiffusionSchedulers,
73
- # regular denoising components
74
- tokenizer: CLIPTokenizer,
75
- text_encoder: CLIPTextModel,
76
- unet: UNet2DConditionModel,
77
- scheduler: KarrasDiffusionSchedulers,
78
- # vae
79
- vae: AutoencoderKL,
80
- num_views: int = 4,
81
- ):
82
- super().__init__()
83
-
84
- self.register_modules(
85
- feature_extractor=feature_extractor,
86
- image_encoder=image_encoder,
87
- image_normalizer=image_normalizer,
88
- image_noising_scheduler=image_noising_scheduler,
89
- tokenizer=tokenizer,
90
- text_encoder=text_encoder,
91
- unet=unet,
92
- scheduler=scheduler,
93
- vae=vae,
94
- )
95
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
96
- self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
97
- self.num_views: int = num_views
98
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
99
- def enable_vae_slicing(self):
100
- r"""
101
- Enable sliced VAE decoding.
102
-
103
- When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
104
- steps. This is useful to save some memory and allow larger batch sizes.
105
- """
106
- self.vae.enable_slicing()
107
-
108
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
109
- def disable_vae_slicing(self):
110
- r"""
111
- Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
112
- computing decoding in one step.
113
- """
114
- self.vae.disable_slicing()
115
-
116
- def enable_sequential_cpu_offload(self, gpu_id=0):
117
- r"""
118
- Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's
119
- models have their state dicts saved to CPU and then are moved to a `torch.device('meta')` and loaded to GPU only
120
- when their specific submodule has its `forward` method called.
121
- """
122
- if is_accelerate_available():
123
- from accelerate import cpu_offload
124
- else:
125
- raise ImportError("Please install accelerate via `pip install accelerate`")
126
-
127
- device = torch.device(f"cuda:{gpu_id}")
128
-
129
- # TODO: self.image_normalizer.{scale,unscale} are not covered by the offload hooks, so they fail if added to the list
130
- models = [
131
- self.image_encoder,
132
- self.text_encoder,
133
- self.unet,
134
- self.vae,
135
- ]
136
- for cpu_offloaded_model in models:
137
- if cpu_offloaded_model is not None:
138
- cpu_offload(cpu_offloaded_model, device)
139
-
140
- @property
141
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
142
- def _execution_device(self):
143
- r"""
144
- Returns the device on which the pipeline's models will be executed. After calling
145
- `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
146
- hooks.
147
- """
148
- if not hasattr(self.unet, "_hf_hook"):
149
- return self.device
150
- for module in self.unet.modules():
151
- if (
152
- hasattr(module, "_hf_hook")
153
- and hasattr(module._hf_hook, "execution_device")
154
- and module._hf_hook.execution_device is not None
155
- ):
156
- return torch.device(module._hf_hook.execution_device)
157
- return self.device
158
-
159
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
160
- def _encode_prompt(
161
- self,
162
- prompt,
163
- device,
164
- num_images_per_prompt,
165
- do_classifier_free_guidance,
166
- negative_prompt=None,
167
- prompt_embeds: Optional[torch.FloatTensor] = None,
168
- negative_prompt_embeds: Optional[torch.FloatTensor] = None,
169
- lora_scale: Optional[float] = None,
170
- ):
171
- r"""
172
- Encodes the prompt into text encoder hidden states.
173
-
174
- Args:
175
- prompt (`str` or `List[str]`, *optional*):
176
- prompt to be encoded
177
- device: (`torch.device`):
178
- torch device
179
- num_images_per_prompt (`int`):
180
- number of images that should be generated per prompt
181
- do_classifier_free_guidance (`bool`):
182
- whether to use classifier free guidance or not
183
- negative_prompt (`str` or `List[str]`, *optional*):
184
- The prompt or prompts not to guide the image generation. If not defined, one has to pass
185
- `negative_prompt_embeds` instead.
186
- Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
187
- prompt_embeds (`torch.FloatTensor`, *optional*):
188
- Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
189
- provided, text embeddings will be generated from `prompt` input argument.
190
- negative_prompt_embeds (`torch.FloatTensor`, *optional*):
191
- Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
192
- weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
193
- argument.
194
- """
195
- prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
196
-
197
- if do_classifier_free_guidance:
198
- # For classifier free guidance, we need to do two forward passes.
199
- # Here we concatenate the unconditional and text embeddings into a single batch
200
- # to avoid doing two forward passes
201
- normal_prompt_embeds, color_prompt_embeds = torch.chunk(prompt_embeds, 2, dim=0)
202
-
203
- prompt_embeds = torch.cat([normal_prompt_embeds, normal_prompt_embeds, color_prompt_embeds, color_prompt_embeds], 0)
204
-
205
- return prompt_embeds
206
-
207
- def _encode_image(
208
- self,
209
- image_pil,
210
- device,
211
- num_images_per_prompt,
212
- do_classifier_free_guidance,
213
- noise_level: int=0,
214
- generator: Optional[torch.Generator] = None
215
- ):
216
- dtype = next(self.image_encoder.parameters()).dtype
217
- # ______________________________clip image embedding______________________________
218
- image = self.feature_extractor(images=image_pil, return_tensors="pt").pixel_values
219
- image = image.to(device=device, dtype=dtype)
220
- image_embeds = self.image_encoder(image).image_embeds
221
-
222
- image_embeds = self.noise_image_embeddings(
223
- image_embeds=image_embeds,
224
- noise_level=noise_level,
225
- generator=generator,
226
- )
227
- # duplicate image embeddings for each generation per prompt, using mps friendly method
228
- # image_embeds = image_embeds.unsqueeze(1)
229
- # note: the condition input is same
230
- image_embeds = image_embeds.repeat(num_images_per_prompt, 1)
231
-
232
- if do_classifier_free_guidance:
233
- normal_image_embeds, color_image_embeds = torch.chunk(image_embeds, 2, dim=0)
234
- negative_prompt_embeds = torch.zeros_like(normal_image_embeds)
235
-
236
- # For classifier free guidance, we need to do two forward passes.
237
- # Here we concatenate the unconditional and text embeddings into a single batch
238
- # to avoid doing two forward passes
239
- image_embeds = torch.cat([negative_prompt_embeds, normal_image_embeds, negative_prompt_embeds, color_image_embeds], 0)
240
-
241
- # _____________________________vae input latents__________________________________________________
242
- image_pt = torch.stack([TF.to_tensor(img) for img in image_pil], dim=0).to(device)
243
- image_pt = image_pt * 2.0 - 1.0
244
- image_latents = self.vae.encode(image_pt).latent_dist.mode() * self.vae.config.scaling_factor
245
- # Note: repeat differently from official pipelines
246
- image_latents = image_latents.repeat(num_images_per_prompt, 1, 1, 1)
247
-
248
- if do_classifier_free_guidance:
249
- normal_image_latents, color_image_latents = torch.chunk(image_latents, 2, dim=0)
250
- image_latents = torch.cat([torch.zeros_like(normal_image_latents), normal_image_latents,
251
- torch.zeros_like(color_image_latents), color_image_latents], 0)
252
-
253
- return image_embeds, image_latents
254
-
255
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
256
- def decode_latents(self, latents):
257
- latents = 1 / self.vae.config.scaling_factor * latents
258
- image = self.vae.decode(latents).sample
259
- image = (image / 2 + 0.5).clamp(0, 1)
260
- # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
261
- image = image.cpu().permute(0, 2, 3, 1).float().numpy()
262
- return image
263
-
264
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
265
- def prepare_extra_step_kwargs(self, generator, eta):
266
- # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
267
- # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
268
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
269
- # and should be between [0, 1]
270
-
271
- accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
272
- extra_step_kwargs = {}
273
- if accepts_eta:
274
- extra_step_kwargs["eta"] = eta
275
-
276
- # check if the scheduler accepts generator
277
- accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
278
- if accepts_generator:
279
- extra_step_kwargs["generator"] = generator
280
- return extra_step_kwargs
281
-
282
- def check_inputs(
283
- self,
284
- prompt,
285
- image,
286
- height,
287
- width,
288
- callback_steps,
289
- noise_level,
290
- ):
291
- if height % 8 != 0 or width % 8 != 0:
292
- raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
293
-
294
- if (callback_steps is None) or (
295
- callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
296
- ):
297
- raise ValueError(
298
- f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
299
- f" {type(callback_steps)}."
300
- )
301
-
302
- if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
303
- raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
304
-
305
-
306
- if noise_level < 0 or noise_level >= self.image_noising_scheduler.config.num_train_timesteps:
307
- raise ValueError(
308
- f"`noise_level` must be between 0 and {self.image_noising_scheduler.config.num_train_timesteps - 1}, inclusive."
309
- )
310
-
311
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
312
- def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
313
- shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
314
- if isinstance(generator, list) and len(generator) != batch_size:
315
- raise ValueError(
316
- f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
317
- f" size of {batch_size}. Make sure the batch size matches the length of the generators."
318
- )
319
-
320
- if latents is None:
321
- latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
322
- else:
323
- latents = latents.to(device)
324
-
325
- # scale the initial noise by the standard deviation required by the scheduler
326
- latents = latents * self.scheduler.init_noise_sigma
327
- return latents
328
-
329
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_unclip.StableUnCLIPPipeline.noise_image_embeddings
330
- def noise_image_embeddings(
331
- self,
332
- image_embeds: torch.Tensor,
333
- noise_level: int,
334
- noise: Optional[torch.FloatTensor] = None,
335
- generator: Optional[torch.Generator] = None,
336
- ):
337
- """
338
- Add noise to the image embeddings. The amount of noise is controlled by a `noise_level` input. A higher
339
- `noise_level` increases the variance in the final un-noised images.
340
-
341
- The noise is applied in two ways
342
- 1. A noise schedule is applied directly to the embeddings
343
- 2. A vector of sinusoidal time embeddings is appended to the output.
344
-
345
- In both cases, the amount of noise is controlled by the same `noise_level`.
346
-
347
- The embeddings are normalized before the noise is applied and un-normalized after the noise is applied.
348
- """
349
- if noise is None:
350
- noise = randn_tensor(
351
- image_embeds.shape, generator=generator, device=image_embeds.device, dtype=image_embeds.dtype
352
- )
353
-
354
- noise_level = torch.tensor([noise_level] * image_embeds.shape[0], device=image_embeds.device)
355
-
356
- image_embeds = self.image_normalizer.scale(image_embeds)
357
-
358
- image_embeds = self.image_noising_scheduler.add_noise(image_embeds, timesteps=noise_level, noise=noise)
359
-
360
- image_embeds = self.image_normalizer.unscale(image_embeds)
361
-
362
- noise_level = get_timestep_embedding(
363
- timesteps=noise_level, embedding_dim=image_embeds.shape[-1], flip_sin_to_cos=True, downscale_freq_shift=0
364
- )
365
-
366
- # `get_timestep_embeddings` does not contain any weights and will always return f32 tensors,
367
- # but we might actually be running in fp16. so we need to cast here.
368
- # there might be better ways to encapsulate this.
369
- noise_level = noise_level.to(image_embeds.dtype)
370
-
371
- image_embeds = torch.cat((image_embeds, noise_level), 1)
372
-
373
- return image_embeds
374
-
375
- @torch.no_grad()
376
- # @replace_example_docstring(EXAMPLE_DOC_STRING)
377
- def __call__(
378
- self,
379
- image: Union[torch.FloatTensor, PIL.Image.Image],
380
- prompt: Union[str, List[str]],
381
- prompt_embeds: torch.FloatTensor = None,
382
- dino_feature: torch.FloatTensor = None,
383
- height: Optional[int] = None,
384
- width: Optional[int] = None,
385
- num_inference_steps: int = 20,
386
- guidance_scale: float = 10,
387
- negative_prompt: Optional[Union[str, List[str]]] = None,
388
- num_images_per_prompt: Optional[int] = 1,
389
- eta: float = 0.0,
390
- generator: Optional[torch.Generator] = None,
391
- latents: Optional[torch.FloatTensor] = None,
392
- negative_prompt_embeds: Optional[torch.FloatTensor] = None,
393
- output_type: Optional[str] = "pil",
394
- return_dict: bool = True,
395
- callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
396
- callback_steps: int = 1,
397
- cross_attention_kwargs: Optional[Dict[str, Any]] = None,
398
- noise_level: int = 0,
399
- image_embeds: Optional[torch.FloatTensor] = None,
400
- return_elevation_focal: Optional[bool] = False,
401
- gt_img_in: Optional[torch.FloatTensor] = None,
402
- ):
403
- r"""
404
- Function invoked when calling the pipeline for generation.
405
-
406
- Args:
407
- prompt (`str` or `List[str]`, *optional*):
408
- The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
409
- instead.
410
- image (`torch.FloatTensor` or `PIL.Image.Image`):
411
- `Image`, or tensor representing an image batch. The image will be encoded to its CLIP embedding which
412
- the unet will be conditioned on. Note that the image is _not_ encoded by the vae and then used as the
413
- latents in the denoising process such as in the standard stable diffusion text guided image variation
414
- process.
415
- height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
416
- The height in pixels of the generated image.
417
- width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
418
- The width in pixels of the generated image.
419
- num_inference_steps (`int`, *optional*, defaults to 20):
420
- The number of denoising steps. More denoising steps usually lead to a higher quality image at the
421
- expense of slower inference.
422
- guidance_scale (`float`, *optional*, defaults to 10.0):
423
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
424
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
425
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
426
- 1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
427
- usually at the expense of lower image quality.
428
- negative_prompt (`str` or `List[str]`, *optional*):
429
- The prompt or prompts not to guide the image generation. If not defined, one has to pass
430
- `negative_prompt_embeds` instead.
431
- Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
432
- num_images_per_prompt (`int`, *optional*, defaults to 1):
433
- The number of images to generate per prompt.
434
- eta (`float`, *optional*, defaults to 0.0):
435
- Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
436
- [`schedulers.DDIMScheduler`], will be ignored for others.
437
- generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
438
- One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
439
- to make generation deterministic.
440
- latents (`torch.FloatTensor`, *optional*):
441
- Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
442
- generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
443
- tensor will be generated by sampling using the supplied random `generator`.
444
- prompt_embeds (`torch.FloatTensor`, *optional*):
445
- Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
446
- provided, text embeddings will be generated from `prompt` input argument.
447
- negative_prompt_embeds (`torch.FloatTensor`, *optional*):
448
- Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
449
- weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
450
- argument.
451
- output_type (`str`, *optional*, defaults to `"pil"`):
452
- The output format of the generated image. Choose between
453
- [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
454
- return_dict (`bool`, *optional*, defaults to `True`):
455
- Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
456
- plain tuple.
457
- callback (`Callable`, *optional*):
458
- A function that will be called every `callback_steps` steps during inference. The function will be
459
- called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
460
- callback_steps (`int`, *optional*, defaults to 1):
461
- The frequency at which the `callback` function will be called. If not specified, the callback will be
462
- called at every step.
463
- cross_attention_kwargs (`dict`, *optional*):
464
- A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
465
- `self.processor` in
466
- [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
467
- noise_level (`int`, *optional*, defaults to `0`):
468
- The amount of noise to add to the image embeddings. A higher `noise_level` increases the variance in
469
- the final un-noised images. See `StableUnCLIPPipeline.noise_image_embeddings` for details.
470
- image_embeds (`torch.FloatTensor`, *optional*):
471
- Pre-generated CLIP embeddings to condition the unet on. Note that these are not latents to be used in
472
- the denoising process. If you want to provide pre-generated latents, pass them to `__call__` as
473
- `latents`.
474
-
475
- Examples:
476
-
477
- Returns:
478
- [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.ImagePipelineOutput`] if `return_dict` is
479
- True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images.
480
- """
481
- # 0. Default height and width to unet
482
- height = height or self.unet.config.sample_size * self.vae_scale_factor
483
- width = width or self.unet.config.sample_size * self.vae_scale_factor
484
-
485
- # 1. Check inputs. Raise error if not correct
486
- self.check_inputs(
487
- prompt=prompt,
488
- image=image,
489
- height=height,
490
- width=width,
491
- callback_steps=callback_steps,
492
- noise_level=noise_level
493
- )
494
-
495
- # 2. Define call parameters
496
- if isinstance(image, list):
497
- batch_size = len(image)
498
- elif isinstance(image, torch.Tensor):
499
- batch_size = image.shape[0]
500
- assert batch_size >= self.num_views and batch_size % self.num_views == 0
501
- elif isinstance(image, PIL.Image.Image):
502
- image = [image]*self.num_views*2
503
- batch_size = self.num_views*2
504
-
505
- if isinstance(prompt, str):
506
- prompt = [prompt] * self.num_views * 2
507
-
508
- device = self._execution_device
509
-
510
- # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
511
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
512
- # corresponds to doing no classifier free guidance.
513
- do_classifier_free_guidance = guidance_scale != 1.0
514
-
515
- # 3. Encode input prompt
516
- text_encoder_lora_scale = (
517
- cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
518
- )
519
- prompt_embeds = self._encode_prompt(
520
- prompt=prompt,
521
- device=device,
522
- num_images_per_prompt=num_images_per_prompt,
523
- do_classifier_free_guidance=do_classifier_free_guidance,
524
- negative_prompt=negative_prompt,
525
- prompt_embeds=prompt_embeds,
526
- negative_prompt_embeds=negative_prompt_embeds,
527
- lora_scale=text_encoder_lora_scale,
528
- )
529
-
530
-
531
- # 4. Encode the input image
532
- if isinstance(image, list):
533
- image_pil = image
534
- elif isinstance(image, torch.Tensor):
535
- image_pil = [TF.to_pil_image(image[i]) for i in range(image.shape[0])]
536
- noise_level = torch.tensor([noise_level], device=device)
537
- image_embeds, image_latents = self._encode_image(
538
- image_pil=image_pil,
539
- device=device,
540
- num_images_per_prompt=num_images_per_prompt,
541
- do_classifier_free_guidance=do_classifier_free_guidance,
542
- noise_level=noise_level,
543
- generator=generator,
544
- )
545
-
546
- # 5. Prepare timesteps
547
- self.scheduler.set_timesteps(num_inference_steps, device=device)
548
- timesteps = self.scheduler.timesteps
549
-
550
- # 6. Prepare latent variables
551
- num_channels_latents = self.unet.config.out_channels
552
- if gt_img_in is not None:
553
- latents = gt_img_in * self.scheduler.init_noise_sigma
554
- else:
555
- latents = self.prepare_latents(
556
- batch_size=batch_size,
557
- num_channels_latents=num_channels_latents,
558
- height=height,
559
- width=width,
560
- dtype=prompt_embeds.dtype,
561
- device=device,
562
- generator=generator,
563
- latents=latents,
564
- )
565
-
566
- # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
567
- extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
568
-
569
- eles, focals = [], []
570
- # 8. Denoising loop
571
- for i, t in enumerate(self.progress_bar(timesteps)):
572
- if do_classifier_free_guidance:
573
- normal_latents, color_latents = torch.chunk(latents, 2, dim=0)
574
- latent_model_input = torch.cat([normal_latents, normal_latents, color_latents, color_latents], 0)
575
- else:
576
- latent_model_input = latents
577
- latent_model_input = torch.cat([
578
- latent_model_input, image_latents
579
- ], dim=1)
580
- latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
581
-
582
- # predict the noise residual
583
- unet_out = self.unet(
584
- latent_model_input,
585
- t,
586
- encoder_hidden_states=prompt_embeds,
587
- dino_feature=dino_feature,
588
- class_labels=image_embeds,
589
- cross_attention_kwargs=cross_attention_kwargs,
590
- return_dict=False)
591
-
592
- noise_pred = unet_out[0]
593
- if return_elevation_focal:
594
- uncond_pose, pose = torch.chunk(unet_out[1], 2, 0)
595
- pose = uncond_pose + guidance_scale * (pose - uncond_pose)
596
- ele = pose[:, 0].detach().cpu().numpy() # b
597
- eles.append(ele)
598
- focal = pose[:, 1].detach().cpu().numpy()
599
- focals.append(focal)
600
-
601
- # perform guidance
602
- if do_classifier_free_guidance:
603
- normal_noise_pred_uncond, normal_noise_pred_text, color_noise_pred_uncond, color_noise_pred_text = torch.chunk(noise_pred, 4, dim=0)
604
-
605
- noise_pred_uncond, noise_pred_text = torch.cat([normal_noise_pred_uncond, color_noise_pred_uncond], 0), torch.cat([normal_noise_pred_text, color_noise_pred_text], 0)
606
- noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
607
-
608
- # compute the previous noisy sample x_t -> x_t-1
609
- latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
610
-
611
- if callback is not None and i % callback_steps == 0:
612
- callback(i, t, latents)
613
-
614
- # 9. Post-processing
615
- if not output_type == "latent":
616
- if num_channels_latents == 8:
617
- latents = torch.cat([latents[:, :4], latents[:, 4:]], dim=0)
618
- with torch.no_grad():
619
- image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
620
- else:
621
- image = latents
622
-
623
- image = self.image_processor.postprocess(image, output_type=output_type)
624
-
625
- # Offload last model to CPU
626
- # if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
627
- # self.final_offload_hook.offload()
628
- if not return_dict:
629
- return (image, )
630
- if return_elevation_focal:
631
- return ImagePipelineOutput(images=image), eles, focals
632
- else:
633
- return ImagePipelineOutput(images=image)
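Similarly, a hedged sketch of how the deleted pipeline's `__call__` would be driven. Because `_encode_prompt` above consumes `prompt_embeds` directly, precomputed text embeddings have to be supplied; `pipeline` (an assembled `StableUnCLIPImg2ImgPipeline`) and `text_embeds` are assumed to exist, and all argument values are illustrative.

```python
# Hypothetical sketch, not part of the repository: run the multiview unCLIP pipeline
# defined above on a single condition image. `pipeline` and `text_embeds` are assumed
# to have been constructed elsewhere.
from PIL import Image

cond_image = Image.open("input.png").convert("RGB")  # replicated to num_views * 2 inside __call__

out, elevations, focals = pipeline(
    image=cond_image,
    prompt="",                          # unused here; _encode_prompt reads prompt_embeds directly
    prompt_embeds=text_embeds,          # shape (2 * num_views, seq_len, dim): normal + color branches
    num_inference_steps=20,
    guidance_scale=3.0,
    noise_level=0,
    return_elevation_focal=True,        # also return per-step elevation / focal-length predictions
)
images = out.images                     # decoded views for both branches
```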