wangfuyun committed on
Commit
c0fdaf5
·
verified ·
1 Parent(s): b29d45f

Upload 42 files

Browse files
Files changed (42) hide show
  1. README.md +1 -1
  2. animatelcm/models/attention.py +296 -0
  3. animatelcm/models/embeddings.py +213 -0
  4. animatelcm/models/motion_module.py +337 -0
  5. animatelcm/models/resnet.py +313 -0
  6. animatelcm/models/unet.py +568 -0
  7. animatelcm/models/unet_blocks.py +904 -0
  8. animatelcm/pipelines/pipeline_animation.py +456 -0
  9. animatelcm/scheduler/lcm_scheduler.py +722 -0
  10. animatelcm/utils/convert_from_ckpt.py +951 -0
  11. animatelcm/utils/convert_lora_safetensor_to_diffusers.py +152 -0
  12. animatelcm/utils/lcm_utils.py +237 -0
  13. animatelcm/utils/util.py +153 -0
  14. app.py +392 -0
  15. models/.DS_Store +0 -0
  16. models/DreamBooth_LoRA/cartoon2d.safetensors +3 -0
  17. models/DreamBooth_LoRA/cartoon3d.safetensors +3 -0
  18. models/DreamBooth_LoRA/realistic1.safetensors +3 -0
  19. models/DreamBooth_LoRA/realistic2.safetensors +3 -0
  20. models/LCM_LoRA/Put LCMLoRA checkpoints here.txt +0 -0
  21. models/LCM_LoRA/sd15_t2v_beta_lora.safetensors +3 -0
  22. models/Motion_Module/Put motion module checkpoints here.txt +0 -0
  23. models/Motion_Module/sd15_t2v_beta_motion.ckpt +3 -0
  24. models/StableDiffusion/Put diffusers stable-diffusion-v1-5 repo here.txt +0 -0
  25. models/StableDiffusion/stable-diffusion-v1-5/.gitattributes +35 -0
  26. models/StableDiffusion/stable-diffusion-v1-5/README.md +207 -0
  27. models/StableDiffusion/stable-diffusion-v1-5/feature_extractor/preprocessor_config.json +20 -0
  28. models/StableDiffusion/stable-diffusion-v1-5/model_index.json +32 -0
  29. models/StableDiffusion/stable-diffusion-v1-5/safety_checker/config.json +175 -0
  30. models/StableDiffusion/stable-diffusion-v1-5/scheduler/scheduler_config.json +13 -0
  31. models/StableDiffusion/stable-diffusion-v1-5/text_encoder/config.json +25 -0
  32. models/StableDiffusion/stable-diffusion-v1-5/text_encoder/model.safetensors +3 -0
  33. models/StableDiffusion/stable-diffusion-v1-5/tokenizer/merges.txt +0 -0
  34. models/StableDiffusion/stable-diffusion-v1-5/tokenizer/special_tokens_map.json +24 -0
  35. models/StableDiffusion/stable-diffusion-v1-5/tokenizer/tokenizer_config.json +34 -0
  36. models/StableDiffusion/stable-diffusion-v1-5/tokenizer/vocab.json +0 -0
  37. models/StableDiffusion/stable-diffusion-v1-5/unet/config.json +36 -0
  38. models/StableDiffusion/stable-diffusion-v1-5/unet/diffusion_pytorch_model.bin +3 -0
  39. models/StableDiffusion/stable-diffusion-v1-5/v1-inference.yaml +70 -0
  40. models/StableDiffusion/stable-diffusion-v1-5/vae/config.json +29 -0
  41. models/StableDiffusion/stable-diffusion-v1-5/vae/diffusion_pytorch_model.bin +3 -0
  42. requirements.txt +15 -0
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: 🦀
4
  colorFrom: red
5
  colorTo: blue
6
  sdk: gradio
7
- sdk_version: 4.16.0
8
  app_file: app.py
9
  pinned: false
10
  ---
 
4
  colorFrom: red
5
  colorTo: blue
6
  sdk: gradio
7
+ sdk_version: 3.48.0
8
  app_file: app.py
9
  pinned: false
10
  ---
animatelcm/models/attention.py ADDED
@@ -0,0 +1,296 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import Optional
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from torch import nn
7
+
8
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
9
+ from diffusers.modeling_utils import ModelMixin
10
+ from diffusers.utils import BaseOutput
11
+ from diffusers.utils.import_utils import is_xformers_available
12
+ from diffusers.models.attention import CrossAttention, FeedForward, AdaLayerNorm
13
+
14
+ from einops import rearrange, repeat
15
+
16
@dataclass
class Transformer3DModelOutput(BaseOutput):
    """Return container used by ``Transformer3DModel.forward`` when ``return_dict`` is true."""

    # Tensor produced by the transformer; ``Transformer3DModel.forward`` stores
    # its ``b c f h w`` result here.
    sample: torch.FloatTensor
19
+
20
+
21
+ if is_xformers_available():
22
+ import xformers
23
+ import xformers.ops
24
+ else:
25
+ xformers = None
26
+
27
+
28
class Transformer3DModel(ModelMixin, ConfigMixin):
    """Transformer that processes a 5-D video tensor one frame at a time.

    The input ``b c f h w`` tensor is folded to ``(b f) c h w``, group-normalised,
    projected to ``num_attention_heads * attention_head_dim`` features, passed
    through ``num_layers`` ``BasicTransformerBlock`` instances, projected back to
    ``in_channels`` and added to the (folded) input as a residual.

    Args:
        num_attention_heads: Number of attention heads in every block.
        attention_head_dim: Feature size of each head.
        in_channels: Channel count of the input (and output) tensor.
        num_layers: Number of ``BasicTransformerBlock`` instances.
        dropout: Dropout probability forwarded to the blocks.
        norm_num_groups: Number of groups of the input ``GroupNorm``.
        cross_attention_dim: Feature size of ``encoder_hidden_states``; forwarded
            to the blocks.
        attention_bias: Forwarded to the blocks' attention layers.
        activation_fn: Activation name forwarded to the blocks' feed-forward.
        num_embeds_ada_norm: Forwarded to the blocks (enables ``AdaLayerNorm``).
        use_linear_projection: Use ``nn.Linear`` instead of 1x1 ``nn.Conv2d``
            for the input/output projections.
        only_cross_attention: Forwarded to the blocks.
        upcast_attention: Forwarded to the blocks' attention layers.
        unet_use_cross_frame_attention: Forwarded to the blocks.
        unet_use_temporal_attention: Forwarded to the blocks.
    """

    @register_to_config
    def __init__(
        self,
        num_attention_heads: int = 16,
        attention_head_dim: int = 88,
        in_channels: Optional[int] = None,
        num_layers: int = 1,
        dropout: float = 0.0,
        norm_num_groups: int = 32,
        cross_attention_dim: Optional[int] = None,
        attention_bias: bool = False,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        use_linear_projection: bool = False,
        only_cross_attention: bool = False,
        upcast_attention: bool = False,

        unet_use_cross_frame_attention=None,
        unet_use_temporal_attention=None,
    ):
        super().__init__()
        self.use_linear_projection = use_linear_projection
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        inner_dim = num_attention_heads * attention_head_dim

        # Define input layers
        self.in_channels = in_channels

        self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
        if use_linear_projection:
            self.proj_in = nn.Linear(in_channels, inner_dim)
        else:
            self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)

        # Define transformer blocks
        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    inner_dim,
                    num_attention_heads,
                    attention_head_dim,
                    dropout=dropout,
                    cross_attention_dim=cross_attention_dim,
                    activation_fn=activation_fn,
                    num_embeds_ada_norm=num_embeds_ada_norm,
                    attention_bias=attention_bias,
                    only_cross_attention=only_cross_attention,
                    upcast_attention=upcast_attention,

                    unet_use_cross_frame_attention=unet_use_cross_frame_attention,
                    unet_use_temporal_attention=unet_use_temporal_attention,
                )
                for _ in range(num_layers)
            ]
        )

        # Define output layers. The output projection maps the transformer
        # width (inner_dim) back to in_channels so the residual add in
        # forward() is shape-correct. The previous Linear(in_channels, inner_dim)
        # only worked when both sizes were equal, in which case the parameter
        # shapes are identical and existing checkpoints still load.
        if use_linear_projection:
            self.proj_out = nn.Linear(inner_dim, in_channels)
        else:
            self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True):
        """Run the transformer over every frame of ``hidden_states``.

        Args:
            hidden_states: 5-D tensor laid out as ``b c f h w``.
            encoder_hidden_states: Optional ``b n c`` conditioning; it is
                repeated once per frame before being handed to the blocks.
            timestep: Forwarded unchanged to every block.
            return_dict: When false, return a 1-tuple instead of
                ``Transformer3DModelOutput``.

        Returns:
            ``Transformer3DModelOutput`` (or ``(output,)``) holding a
            ``b c f h w`` tensor.
        """
        # Input
        assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
        video_length = hidden_states.shape[2]
        hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
        # The argument defaults to None, so only expand it when it was given.
        if encoder_hidden_states is not None:
            encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b f) n c', f=video_length)

        batch, channel, height, width = hidden_states.shape
        residual = hidden_states

        hidden_states = self.norm(hidden_states)
        if not self.use_linear_projection:
            hidden_states = self.proj_in(hidden_states)
            inner_dim = hidden_states.shape[1]
            hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
        else:
            # Flatten first (still `channel` features), then project linearly.
            hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, channel)
            hidden_states = self.proj_in(hidden_states)

        # Blocks
        for block in self.transformer_blocks:
            hidden_states = block(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                timestep=timestep,
                video_length=video_length
            )

        # Output
        if not self.use_linear_projection:
            hidden_states = (
                hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
            )
            hidden_states = self.proj_out(hidden_states)
        else:
            # proj_out already restored `channel` features per position.
            hidden_states = self.proj_out(hidden_states)
            hidden_states = (
                hidden_states.reshape(batch, height, width, channel).permute(0, 3, 1, 2).contiguous()
            )

        output = hidden_states + residual

        output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
        if not return_dict:
            return (output,)

        return Transformer3DModelOutput(sample=output)
140
+
141
+
142
class BasicTransformerBlock(nn.Module):
    """One transformer layer: attention, optional cross-attention, feed-forward
    and optional temporal attention, each wrapped in a residual connection.

    Args:
        dim: Feature size of the hidden states.
        num_attention_heads: Number of attention heads.
        attention_head_dim: Feature size of each head.
        dropout: Dropout probability for attention and feed-forward layers.
        cross_attention_dim: Feature size of ``encoder_hidden_states``; when
            ``None`` the cross-attention layer (``attn2``) is not created.
        activation_fn: Activation name passed to ``FeedForward``.
        num_embeds_ada_norm: When given, ``AdaLayerNorm`` replaces ``LayerNorm``
            for ``norm1``/``norm2``/``norm_temp``.
        attention_bias: ``bias`` flag of the attention layers.
        only_cross_attention: Stored on the instance; not otherwise used by the
            code paths implemented here.
        upcast_attention: Forwarded to the attention layers.
        unet_use_cross_frame_attention: Must not be ``None``. ``True`` is not
            supported because this module ships no sparse-causal attention.
        unet_use_temporal_attention: Must not be ``None``. When true an extra
            attention layer over the frame axis is added.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        dropout=0.0,
        cross_attention_dim: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        attention_bias: bool = False,
        only_cross_attention: bool = False,
        upcast_attention: bool = False,

        unet_use_cross_frame_attention = None,
        unet_use_temporal_attention = None,
    ):
        super().__init__()
        self.only_cross_attention = only_cross_attention
        self.use_ada_layer_norm = num_embeds_ada_norm is not None
        self.unet_use_cross_frame_attention = unet_use_cross_frame_attention
        self.unet_use_temporal_attention = unet_use_temporal_attention

        # SC-Attn
        assert unet_use_cross_frame_attention is not None
        if unet_use_cross_frame_attention:
            # The sparse-causal attention class this branch used to reference is
            # neither defined nor imported in this module, so the old code died
            # with a NameError. Fail with an explicit message instead.
            raise NotImplementedError(
                "unet_use_cross_frame_attention=True requires SparseCausalAttention2D, "
                "which is not available in this module."
            )
        self.attn1 = CrossAttention(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            dropout=dropout,
            bias=attention_bias,
            upcast_attention=upcast_attention,
        )
        self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)

        # Cross-Attn (layer and its norm exist only when a conditioning size is given)
        if cross_attention_dim is not None:
            self.attn2 = CrossAttention(
                query_dim=dim,
                cross_attention_dim=cross_attention_dim,
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                upcast_attention=upcast_attention,
            )
            self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
        else:
            self.attn2 = None
            self.norm2 = None

        # Feed-forward
        self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
        self.norm3 = nn.LayerNorm(dim)

        # Temp-Attn
        assert unet_use_temporal_attention is not None
        if unet_use_temporal_attention:
            self.attn_temp = CrossAttention(
                query_dim=dim,
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                upcast_attention=upcast_attention,
            )
            # Zero the output projection weight of the temporal attention.
            nn.init.zeros_(self.attn_temp.to_out[0].weight.data)
            self.norm_temp = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)

    def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
        """Toggle the xformers flag on ``attn1`` and, if present, ``attn2``.

        Raises:
            ModuleNotFoundError: If xformers is not installed.
            ValueError: If CUDA is not available.
            Exception: Whatever the xformers smoke test raises, unchanged.
        """
        if not is_xformers_available():
            raise ModuleNotFoundError(
                "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
                " xformers",
                name="xformers",
            )
        if not torch.cuda.is_available():
            raise ValueError(
                "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
                " available for GPU "
            )

        # Make sure we can run the memory efficient attention; any failure
        # propagates to the caller as-is.
        _ = xformers.ops.memory_efficient_attention(
            torch.randn((1, 2, 40), device="cuda"),
            torch.randn((1, 2, 40), device="cuda"),
            torch.randn((1, 2, 40), device="cuda"),
        )
        self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
        if self.attn2 is not None:
            self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
        # NOTE: attn_temp is intentionally left untouched, as in the original code.

    def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None, video_length=None):
        """Apply the block to ``hidden_states`` and return the updated tensor.

        Args:
            hidden_states: Tensor handled as ``(b f) d c`` by the temporal branch.
            encoder_hidden_states: Conditioning passed to ``attn2``.
            timestep: Passed to the norms only when ``AdaLayerNorm`` is used.
            attention_mask: Passed to ``attn1`` and ``attn2``.
            video_length: Frame count used to regroup frames for temporal attention.
        """
        # First attention layer
        norm_hidden_states = (
            self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
        )

        if self.unet_use_cross_frame_attention:
            # Kept for an attn1 that accepts video_length; __init__ as written
            # never builds one.
            hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states
        else:
            hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask) + hidden_states

        if self.attn2 is not None:
            # Cross-Attention
            norm_hidden_states = (
                self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
            )
            hidden_states = (
                self.attn2(
                    norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
                )
                + hidden_states
            )

        # Feed-forward
        hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states

        # Temporal-Attention
        if self.unet_use_temporal_attention:
            d = hidden_states.shape[1]
            # Move the frame axis into the sequence position so attention runs across frames.
            hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)
            norm_hidden_states = (
                self.norm_temp(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_temp(hidden_states)
            )
            hidden_states = self.attn_temp(norm_hidden_states) + hidden_states
            hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)

        return hidden_states
animatelcm/models/embeddings.py ADDED
@@ -0,0 +1,213 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import math
15
+
16
+ import numpy as np
17
+ import torch
18
+ from torch import nn
19
+
20
+
21
def get_timestep_embedding(
    timesteps: torch.Tensor,
    embedding_dim: int,
    flip_sin_to_cos: bool = False,
    downscale_freq_shift: float = 1,
    scale: float = 1,
    max_period: int = 10000,
):
    """Build sinusoidal embeddings for a 1-D tensor of timesteps.

    Args:
        timesteps: 1-D tensor with one (possibly fractional) value per batch element.
        embedding_dim: Width of the returned embedding.
        flip_sin_to_cos: Put the cosine half before the sine half.
        downscale_freq_shift: Subtracted from ``embedding_dim // 2`` in the
            denominator of the frequency exponent.
        scale: Multiplier applied to the angles before sin/cos.
        max_period: Controls the lowest frequency of the embedding.

    Returns:
        Tensor of shape ``[len(timesteps), embedding_dim]``; for an odd
        ``embedding_dim`` the last column is zero padding.
    """
    assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"

    half_dim = embedding_dim // 2
    freq_index = torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device)
    log_freqs = (-math.log(max_period) * freq_index) / (half_dim - downscale_freq_shift)

    # Outer product of timesteps and frequencies, then the global scale.
    angles = scale * (timesteps[:, None].float() * torch.exp(log_freqs)[None, :])

    sin_half = torch.sin(angles)
    cos_half = torch.cos(angles)
    halves = [cos_half, sin_half] if flip_sin_to_cos else [sin_half, cos_half]
    emb = torch.cat(halves, dim=-1)

    # An odd width cannot be split evenly; pad one zero column on the right.
    if embedding_dim % 2 == 1:
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb
62
+
63
def zero_module(module):
    """Set every parameter of ``module`` to zero in place and return the same module."""
    for param in module.parameters():
        # detach() shares storage with the parameter, so zeroing the detached
        # view zeroes the parameter itself without involving autograd.
        param.detach().zero_()
    return module
68
+
69
class TimestepEmbedding(nn.Module):
    """Two linear layers with an optional activation in between, plus an
    optional zero-initialised projection that adds a condition to the input.

    Args:
        in_channels: Feature size of the input ``sample``.
        time_embed_dim: Output size of the first linear layer.
        act_fn: ``"silu"`` or ``"mish"``; any other value disables the activation.
        out_dim: Output size of the second linear layer; defaults to
            ``time_embed_dim`` when ``None``.
        time_cond_proj_dim: Feature size of the optional ``condition``. When
            ``None`` no condition projection is created.
    """

    def __init__(self, in_channels: int, time_embed_dim: int, act_fn: str = "silu", out_dim: int = None, time_cond_proj_dim=None):
        super().__init__()

        self.linear_1 = nn.Linear(in_channels, time_embed_dim)
        self.act = None
        if act_fn == "silu":
            self.act = nn.SiLU()
        elif act_fn == "mish":
            self.act = nn.Mish()

        if time_cond_proj_dim is not None:
            # Bias-free and zeroed, so the condition contributes nothing until
            # the weights are changed.
            self.cond_proj = zero_module(nn.Linear(time_cond_proj_dim, in_channels, bias=False))
        else:
            self.cond_proj = None

        time_embed_dim_out = out_dim if out_dim is not None else time_embed_dim
        self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out)

    def forward(self, sample, condition=None):
        """Embed ``sample``; if ``condition`` is given, add its projection first.

        Raises:
            ValueError: If ``condition`` is passed but the module was built
                without ``time_cond_proj_dim``.
        """
        if condition is not None:
            if self.cond_proj is None:
                # Previously surfaced as "'NoneType' object is not callable".
                raise ValueError(
                    "`condition` was provided but this TimestepEmbedding was created "
                    "without `time_cond_proj_dim`."
                )
            sample = sample + self.cond_proj(condition)
        sample = self.linear_1(sample)

        if self.act is not None:
            sample = self.act(sample)

        sample = self.linear_2(sample)
        return sample
102
+
103
+
104
class Timesteps(nn.Module):
    """Parameter-free module wrapping ``get_timestep_embedding`` with fixed settings."""

    def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float):
        super().__init__()
        # Settings are stored and forwarded unchanged on every call.
        self.num_channels = num_channels
        self.flip_sin_to_cos = flip_sin_to_cos
        self.downscale_freq_shift = downscale_freq_shift

    def forward(self, timesteps):
        """Return the sinusoidal embedding of ``timesteps`` with ``num_channels`` columns."""
        return get_timestep_embedding(
            timesteps,
            self.num_channels,
            flip_sin_to_cos=self.flip_sin_to_cos,
            downscale_freq_shift=self.downscale_freq_shift,
        )
119
+
120
+
121
class GaussianFourierProjection(nn.Module):
    """Random Fourier features of a 1-D input using fixed Gaussian frequencies.

    Args:
        embedding_size: Number of random frequencies; the output has twice as
            many columns (sine and cosine parts).
        scale: Multiplier applied to the sampled frequencies.
        set_W_to_weight: Also register the frequencies under the name ``W`` and
            make ``weight`` refer to that same parameter.
        log: Take ``torch.log`` of the input before projecting.
        flip_sin_to_cos: Emit the cosine part before the sine part.
    """

    def __init__(
        self, embedding_size: int = 256, scale: float = 1.0, set_W_to_weight=True, log=True, flip_sin_to_cos=False
    ):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
        self.log = log
        self.flip_sin_to_cos = flip_sin_to_cos

        if set_W_to_weight:
            # to delete later
            # A second draw is made on purpose so seeded initial values stay
            # exactly what they were; both names then share one parameter.
            self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
            self.weight = self.W

    def forward(self, x):
        """Project ``x`` (indexed here as a 1-D tensor) onto the frequencies and return sin/cos features."""
        values = torch.log(x) if self.log else x
        angles = values[:, None] * self.weight[None, :] * 2 * np.pi

        sin_part = torch.sin(angles)
        cos_part = torch.cos(angles)
        ordered = (cos_part, sin_part) if self.flip_sin_to_cos else (sin_part, cos_part)
        return torch.cat(ordered, dim=-1)
149
+
150
+
151
class ImagePositionalEmbeddings(nn.Module):
    """
    Embeds latent image class indices and adds a learned 2-D positional embedding
    formed by summing a per-row (height) and a per-column (width) embedding.

    For more details, see figure 10 of the dall-e paper: https://arxiv.org/abs/2102.12092

    For VQ-diffusion:

    Output vector embeddings are used as input for the transformer.

    Note that the vector embeddings for the transformer are different than the vector embeddings from the VQVAE.

    Args:
        num_embed (`int`):
            Number of embeddings for the latent pixels embeddings.
        height (`int`):
            Height of the latent image i.e. the number of height embeddings.
        width (`int`):
            Width of the latent image i.e. the number of width embeddings.
        embed_dim (`int`):
            Dimension of the produced vector embeddings. Used for the latent pixel, height, and width embeddings.
    """

    def __init__(
        self,
        num_embed: int,
        height: int,
        width: int,
        embed_dim: int,
    ):
        super().__init__()

        self.height = height
        self.width = width
        self.num_embed = num_embed
        self.embed_dim = embed_dim

        # Creation order is kept (emb, height_emb, width_emb) so seeded
        # initialisation yields the same weights.
        self.emb = nn.Embedding(self.num_embed, embed_dim)
        self.height_emb = nn.Embedding(self.height, embed_dim)
        self.width_emb = nn.Embedding(self.width, embed_dim)

    def forward(self, index):
        """Return ``emb(index)`` plus the positional embedding of the leading positions."""
        token_emb = self.emb(index)
        device = index.device

        # 1 x H x D -> 1 x H x 1 x D
        row_ids = torch.arange(self.height, device=device).view(1, self.height)
        row_emb = self.height_emb(row_ids).unsqueeze(2)

        # 1 x W x D -> 1 x 1 x W x D
        col_ids = torch.arange(self.width, device=device).view(1, self.width)
        col_emb = self.width_emb(col_ids).unsqueeze(1)

        # Broadcast sum gives 1 x H x W x D, flattened to 1 x L x D.
        grid = (row_emb + col_emb).view(1, self.height * self.width, -1)

        # Only as many positions as the input sequence has are used.
        return token_emb + grid[:, : token_emb.shape[1], :]
animatelcm/models/motion_module.py ADDED
@@ -0,0 +1,337 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import List, Optional, Tuple, Union
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from torch import nn
7
+
8
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
9
+ from diffusers.modeling_utils import ModelMixin
10
+ from diffusers.utils import BaseOutput
11
+ from diffusers.utils.import_utils import is_xformers_available
12
+ from diffusers.models.attention import CrossAttention, FeedForward
13
+
14
+ from einops import rearrange, repeat
15
+ import math
16
+
17
+
18
def zero_module(module):
    """Zero all parameters of ``module`` in place and hand the module back."""
    for param in module.parameters():
        # Writing through a detached view keeps autograd out of the update.
        param.detach().zero_()
    return module
22
+
23
+
24
@dataclass
class TemporalTransformer3DModelOutput(BaseOutput):
    """Output container holding a single ``sample`` tensor."""

    # Tensor payload of the output.
    sample: torch.FloatTensor
27
+
28
+
29
+ if is_xformers_available():
30
+ import xformers
31
+ import xformers.ops
32
+ else:
33
+ xformers = None
34
+
35
+
36
def get_motion_module(
    in_channels,
    motion_module_type: str,
    motion_module_kwargs: dict
):
    """Create the motion module selected by ``motion_module_type``.

    Args:
        in_channels: Channel count handed to the module.
        motion_module_type: Only ``"Vanilla"`` is supported.
        motion_module_kwargs: Extra keyword arguments for the module constructor.

    Returns:
        A ``VanillaTemporalModule`` instance.

    Raises:
        ValueError: If ``motion_module_type`` is not a supported type.
    """
    if motion_module_type == "Vanilla":
        return VanillaTemporalModule(in_channels=in_channels, **motion_module_kwargs,)
    # Name the offending value so a misconfiguration is easy to trace.
    raise ValueError(
        f"Unknown motion_module_type {motion_module_type!r}; expected 'Vanilla'."
    )
45
+
46
+
47
class VanillaTemporalModule(nn.Module):
    """Thin wrapper that owns a ``TemporalTransformer3DModel``.

    Args:
        in_channels: Channel count of the input tensor.
        num_attention_heads: Number of attention heads.
        num_transformer_block: Number of transformer layers.
        attention_block_types: Names of the attention blocks in every layer.
        cross_frame_attention_mode: Forwarded to the transformer.
        temporal_position_encoding: Forwarded to the transformer.
        temporal_attention_dim_div: Extra integer divisor applied to the per-head size.
        zero_initialize: Zero the transformer's output projection after construction.
    """

    def __init__(
        self,
        in_channels,
        num_attention_heads=8,
        num_transformer_block=2,
        attention_block_types=("Temporal_Self", "Temporal_Self"),
        cross_frame_attention_mode=None,
        temporal_position_encoding=False,
        temporal_attention_dim_div=1,
        zero_initialize=True,
    ):
        super().__init__()

        head_dim = in_channels // num_attention_heads // temporal_attention_dim_div
        transformer = TemporalTransformer3DModel(
            in_channels=in_channels,
            num_attention_heads=num_attention_heads,
            attention_head_dim=head_dim,
            num_layers=num_transformer_block,
            attention_block_types=attention_block_types,
            cross_frame_attention_mode=cross_frame_attention_mode,
            temporal_position_encoding=temporal_position_encoding,
        )

        if zero_initialize:
            # With a zeroed output projection the transformer's residual
            # branch initially adds nothing to its input.
            transformer.proj_out = zero_module(transformer.proj_out)

        self.temporal_transformer = transformer

    def forward(self, input_tensor, temb, encoder_hidden_states, attention_mask=None, anchor_frame_idx=None):
        """Run the temporal transformer; ``temb`` and ``anchor_frame_idx`` are accepted but unused."""
        return self.temporal_transformer(input_tensor, encoder_hidden_states, attention_mask)
82
+
83
+
84
class TemporalTransformer3DModel(nn.Module):
    """Group norm, linear in-projection, a stack of ``TemporalTransformerBlock``
    layers and a linear out-projection, wrapped in a residual connection.

    Args:
        in_channels: Channel count of the input tensor.
        num_attention_heads: Number of attention heads.
        attention_head_dim: Feature size of each head.
        num_layers: Number of transformer blocks.
        attention_block_types: Attention block names used inside every block.
        dropout, norm_num_groups, cross_attention_dim, activation_fn,
        attention_bias, upcast_attention, cross_frame_attention_mode,
        temporal_position_encoding: Forwarded to the blocks
            (``norm_num_groups`` also configures the input ``GroupNorm``).
    """

    def __init__(
        self,
        in_channels,
        num_attention_heads,
        attention_head_dim,

        num_layers,
        attention_block_types=("Temporal_Self", "Temporal_Self", ),
        dropout=0.0,
        norm_num_groups=32,
        cross_attention_dim=768,
        activation_fn="geglu",
        attention_bias=False,
        upcast_attention=False,

        cross_frame_attention_mode=None,
        temporal_position_encoding=False,
    ):
        super().__init__()

        inner_dim = num_attention_heads * attention_head_dim

        self.norm = torch.nn.GroupNorm(
            num_groups=norm_num_groups,
            num_channels=in_channels,
            eps=1e-6,
            affine=True,
        )
        self.proj_in = nn.Linear(in_channels, inner_dim)

        # Every layer gets the same configuration.
        block_kwargs = dict(
            dim=inner_dim,
            num_attention_heads=num_attention_heads,
            attention_head_dim=attention_head_dim,
            attention_block_types=attention_block_types,
            dropout=dropout,
            norm_num_groups=norm_num_groups,
            cross_attention_dim=cross_attention_dim,
            activation_fn=activation_fn,
            attention_bias=attention_bias,
            upcast_attention=upcast_attention,
            cross_frame_attention_mode=cross_frame_attention_mode,
            temporal_position_encoding=temporal_position_encoding,
        )
        self.transformer_blocks = nn.ModuleList(
            [TemporalTransformerBlock(**block_kwargs) for _ in range(num_layers)]
        )
        self.proj_out = nn.Linear(inner_dim, in_channels)

    def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
        """Process a ``b c f h w`` tensor and return one of the same layout.

        ``attention_mask`` is accepted but not passed on to the blocks.
        """
        assert hidden_states.dim(
        ) == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
        video_length = hidden_states.shape[2]
        frames = rearrange(hidden_states, "b c f h w -> (b f) c h w")

        batch, channels, height, width = frames.shape
        residual = frames

        # Normalise, then flatten the spatial grid into a token axis.
        tokens = self.norm(frames)
        tokens = tokens.permute(0, 2, 3, 1).reshape(batch, height * width, channels)
        tokens = self.proj_in(tokens)

        # Transformer Blocks
        for block in self.transformer_blocks:
            tokens = block(
                tokens,
                encoder_hidden_states=encoder_hidden_states,
                video_length=video_length,
            )

        # output: back to channel-first frames, then add the residual
        tokens = self.proj_out(tokens)
        frames = tokens.reshape(batch, height, width, channels).permute(0, 3, 1, 2).contiguous()

        return rearrange(frames + residual, "(b f) c h w -> b c f h w", f=video_length)
161
+
162
+
163
class TemporalTransformerBlock(nn.Module):
    """A sequence of pre-normalised ``VersatileAttention`` layers followed by a
    pre-normalised feed-forward layer, each with a residual connection.

    Args:
        dim: Feature size of the hidden states.
        num_attention_heads: Number of attention heads.
        attention_head_dim: Feature size of each head.
        attention_block_types: One name per attention layer. The part before
            the first ``_`` is the attention mode; a name ending in ``_Cross``
            makes the layer use ``cross_attention_dim``.
        dropout: Dropout for attention and feed-forward layers.
        norm_num_groups: Accepted for signature compatibility; unused here.
        cross_attention_dim: Conditioning feature size for ``_Cross`` layers.
        activation_fn: Activation name passed to ``FeedForward``.
        attention_bias, upcast_attention, cross_frame_attention_mode,
        temporal_position_encoding: Forwarded to ``VersatileAttention``.
    """

    def __init__(
        self,
        dim,
        num_attention_heads,
        attention_head_dim,
        attention_block_types=("Temporal_Self", "Temporal_Self", ),
        dropout=0.0,
        norm_num_groups=32,
        cross_attention_dim=768,
        activation_fn="geglu",
        attention_bias=False,
        upcast_attention=False,
        cross_frame_attention_mode=None,
        temporal_position_encoding=False,
    ):
        super().__init__()

        self.attention_blocks = nn.ModuleList()
        self.norms = nn.ModuleList()

        for block_name in attention_block_types:
            is_cross = block_name.endswith("_Cross")
            attention = VersatileAttention(
                attention_mode=block_name.split("_")[0],
                cross_attention_dim=cross_attention_dim if is_cross else None,

                query_dim=dim,
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                upcast_attention=upcast_attention,

                cross_frame_attention_mode=cross_frame_attention_mode,
                temporal_position_encoding=temporal_position_encoding,
            )
            self.attention_blocks.append(attention)
            self.norms.append(nn.LayerNorm(dim))

        self.ff = FeedForward(dim, dropout=dropout,
                              activation_fn=activation_fn)
        self.ff_norm = nn.LayerNorm(dim)

    def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):
        """Apply every attention layer and then the feed-forward layer.

        ``attention_mask`` is accepted but not forwarded to the attention layers.
        """
        for attention, norm in zip(self.attention_blocks, self.norms):
            # Only cross-attention layers receive the conditioning tensor.
            context = encoder_hidden_states if attention.is_cross_attention else None
            attended = attention(
                norm(hidden_states),
                encoder_hidden_states=context,
                video_length=video_length,
            )
            hidden_states = attended + hidden_states

        return self.ff(self.ff_norm(hidden_states)) + hidden_states
224
+
225
+
226
class PositionalEncoding(nn.Module):
    """Adds a fixed sinusoidal positional encoding along dimension 1.

    Args:
        d_model: Feature size of the inputs (last dimension).
        dropout: Dropout probability applied after adding the encoding.
        max_length: Number of positions precomputed. Defaults to 64, the value
            that used to be hard-coded, so existing buffers keep their shape.
    """

    def __init__(
        self,
        d_model,
        dropout=0.,
        max_length=64,
    ):
        super().__init__()

        self.dropout = nn.Dropout(p=dropout)
        position = torch.arange(max_length).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2)
                             * (-math.log(10000.0) / d_model))
        angles = position * div_term
        pe = torch.zeros(1, max_length, d_model)
        pe[0, :, 0::2] = torch.sin(angles)
        # For an odd d_model there is one fewer odd column than even columns;
        # slicing keeps the assignment valid (a no-op slice for even d_model).
        pe[0, :, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer('pos_encoding', pe)

    def forward(self, x):
        """Return ``dropout(x + encoding)`` for the first ``x.size(1)`` positions.

        Raises:
            ValueError: If ``x.size(1)`` exceeds the precomputed length.
        """
        seq_len = x.size(1)
        max_length = self.pos_encoding.size(1)
        if seq_len > max_length:
            # Previously this failed later with an opaque broadcasting error.
            raise ValueError(
                f"Sequence length {seq_len} exceeds the positional encoding length {max_length}."
            )
        x = x + self.pos_encoding[:, :seq_len]
        return self.dropout(x)
247
+
248
+
249
class VersatileAttention(CrossAttention):
    """``CrossAttention`` variant that attends across the frame axis.

    Inputs arrive as ``(b f) d c``; they are regrouped to ``(b d) f c`` so the
    attention sequence is the frame axis, and regrouped back on return.

    Args:
        attention_mode: Must be ``"Temporal"``.
        cross_frame_attention_mode: Accepted for compatibility; unused here.
        temporal_position_encoding: Add a ``PositionalEncoding`` over frames.
        *args, **kwargs: Forwarded to ``CrossAttention``. ``query_dim`` and
            ``cross_attention_dim`` are read from ``kwargs``, so pass them by
            keyword.
    """

    def __init__(
        self,
        attention_mode=None,
        cross_frame_attention_mode=None,
        temporal_position_encoding=False,
        *args, **kwargs
    ):
        super().__init__(*args, **kwargs)
        assert attention_mode == "Temporal"

        self.attention_mode = attention_mode
        # .get(): an omitted cross_attention_dim means "not cross attention"
        # rather than a KeyError.
        self.is_cross_attention = kwargs.get("cross_attention_dim") is not None

        self.pos_encoder = PositionalEncoding(
            kwargs["query_dim"],
            dropout=0.,
        ) if (temporal_position_encoding and attention_mode == "Temporal") else None

    def extra_repr(self):
        return f"(Module Info) Attention_Mode: {self.attention_mode}, Is_Cross_Attention: {self.is_cross_attention}"

    def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):
        """Attend over frames and return a tensor shaped like the input.

        Args:
            hidden_states: Tensor of shape ``(b f) d c``.
            encoder_hidden_states: Optional ``b n c`` conditioning; repeated
                ``d`` times to match the regrouped batch.
            attention_mask: Optional mask, padded and repeated per head.
            video_length: Number of frames ``f``.

        Raises:
            NotImplementedError: For a non-temporal mode or when
                ``added_kv_proj_dim`` is set.
        """
        if self.attention_mode != "Temporal":
            raise NotImplementedError

        d = hidden_states.shape[1]
        hidden_states = rearrange(
            hidden_states, "(b f) d c -> (b d) f c", f=video_length)

        if self.pos_encoder is not None:
            hidden_states = self.pos_encoder(hidden_states)

        if encoder_hidden_states is not None:
            encoder_hidden_states = repeat(
                encoder_hidden_states, "b n c -> (b d) n c", d=d)

        # Measured after the regrouping: the query sequence is the frame axis.
        # Reading it from the incoming (b f) d c tensor handed the spatial size
        # to _sliced_attention, whose output buffer then had the wrong length.
        sequence_length = hidden_states.shape[1]

        if self.group_norm is not None:
            hidden_states = self.group_norm(
                hidden_states.transpose(1, 2)).transpose(1, 2)

        query = self.to_q(hidden_states)
        dim = query.shape[-1]
        query = self.reshape_heads_to_batch_dim(query)

        if self.added_kv_proj_dim is not None:
            raise NotImplementedError

        # Without conditioning, keys and values come from the hidden states.
        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
        key = self.to_k(encoder_hidden_states)
        value = self.to_v(encoder_hidden_states)

        key = self.reshape_heads_to_batch_dim(key)
        value = self.reshape_heads_to_batch_dim(value)

        if attention_mask is not None:
            if attention_mask.shape[-1] != query.shape[1]:
                target_length = query.shape[1]
                attention_mask = F.pad(
                    attention_mask, (0, target_length), value=0.0)
                attention_mask = attention_mask.repeat_interleave(
                    self.heads, dim=0)

        if self._use_memory_efficient_attention_xformers:
            hidden_states = self._memory_efficient_attention_xformers(
                query, key, value, attention_mask)
            # Cast back to the query dtype after the xformers call.
            hidden_states = hidden_states.to(query.dtype)
        else:
            if self._slice_size is None or query.shape[0] // self._slice_size == 1:
                hidden_states = self._attention(
                    query, key, value, attention_mask)
            else:
                hidden_states = self._sliced_attention(
                    query, key, value, sequence_length, dim, attention_mask)

        # linear proj
        hidden_states = self.to_out[0](hidden_states)

        # dropout
        hidden_states = self.to_out[1](hidden_states)

        # Undo the temporal regrouping.
        hidden_states = rearrange(
            hidden_states, "(b d) f c -> (b f) d c", d=d)

        return hidden_states
animatelcm/models/resnet.py ADDED
@@ -0,0 +1,313 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from typing import Optional
7
+
8
+ from einops import rearrange
9
+
10
+
11
class InflatedConv3d(nn.Conv2d):
    """2D convolution applied frame by frame to a 5D video tensor.

    The input is laid out as ``(b, c, f, h, w)``. The frame axis is folded
    into the batch axis, the plain ``nn.Conv2d`` forward pass is run, and the
    frame axis is unfolded again.
    """

    def forward(self, x):
        # Remember the frame count so the batch axis can be split again.
        num_frames = x.shape[2]

        frames_as_batch = rearrange(x, "b c f h w -> (b f) c h w")
        convolved = super().forward(frames_as_batch)
        return rearrange(convolved, "(b f) c h w -> b c f h w", f=num_frames)
20
+
21
+
22
class InflatedGroupNorm(nn.GroupNorm):
    """Group normalisation computed independently for every video frame.

    The ``(b, c, f, h, w)`` input is reshaped to ``((b f), c, h, w)`` so that
    the statistics of ``nn.GroupNorm`` are computed per frame rather than
    across the whole clip, then reshaped back.
    """

    def forward(self, x):
        # Remember the frame count so the batch axis can be split again.
        num_frames = x.shape[2]

        frames_as_batch = rearrange(x, "b c f h w -> (b f) c h w")
        normalised = super().forward(frames_as_batch)
        return rearrange(normalised, "(b f) c h w -> b c f h w", f=num_frames)
31
+
32
+
33
class Upsample3D(nn.Module):
    """Nearest-neighbour upsampling for ``(b, c, f, h, w)`` video tensors.

    By default the two trailing (spatial) axes are doubled while the frame
    axis is left untouched. When ``use_conv`` is true, a 3x3 per-frame
    convolution is applied after the interpolation.

    Args:
        channels: Number of channels expected in the input.
        use_conv: Whether to apply a convolution after upsampling.
        use_conv_transpose: Not supported; raises ``NotImplementedError``.
        out_channels: Output channels of the convolution (defaults to
            ``channels``). Only relevant when ``use_conv`` is true.
        name: Stored on the module; not otherwise used here.
    """

    def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_conv_transpose = use_conv_transpose
        self.name = name

        if use_conv_transpose:
            raise NotImplementedError

        # Always define the attribute: forward() must be able to tell
        # "no convolution requested" apart from "attribute missing".
        self.conv = None
        if use_conv:
            self.conv = InflatedConv3d(
                self.channels, self.out_channels, 3, padding=1)

    def forward(self, hidden_states, output_size=None):
        """Upsample ``hidden_states``.

        Args:
            hidden_states: Tensor whose axis 1 has ``self.channels`` entries.
            output_size: Optional explicit target size passed to
                ``F.interpolate``; when given, ``scale_factor`` is not used.

        Returns:
            The upsampled (and, if configured, convolved) tensor.
        """
        assert hidden_states.shape[1] == self.channels

        if self.use_conv_transpose:
            raise NotImplementedError

        # Cast to float32 because the 'upsample_nearest2d_out_frame' op does
        # not support bfloat16.
        dtype = hidden_states.dtype
        if dtype == torch.bfloat16:
            hidden_states = hidden_states.to(torch.float32)

        # upsample_nearest_nhwc fails with large batch sizes.
        # See https://github.com/huggingface/diffusers/issues/984
        if hidden_states.shape[0] >= 64:
            hidden_states = hidden_states.contiguous()

        # If `output_size` is passed we force the interpolation output size
        # and do not make use of the default scale factor.
        if output_size is None:
            hidden_states = F.interpolate(
                hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest")
        else:
            hidden_states = F.interpolate(
                hidden_states, size=output_size, mode="nearest")

        # If the input was bfloat16, cast back to bfloat16.
        if dtype == torch.bfloat16:
            hidden_states = hidden_states.to(dtype)

        # Previously the convolution was called unconditionally, which raised
        # AttributeError for modules built with use_conv=False.
        if self.conv is not None:
            hidden_states = self.conv(hidden_states)

        return hidden_states
85
+
86
+
87
class Downsample3D(nn.Module):
    """Stride-2 per-frame convolutional downsampling for video tensors.

    Only the convolutional variant exists: constructing the module with
    ``use_conv=False`` raises ``NotImplementedError``, and so does running a
    module that was configured with ``padding == 0``.
    """

    def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.padding = padding
        self.name = name

        if not use_conv:
            raise NotImplementedError

        self.conv = InflatedConv3d(
            self.channels, self.out_channels, 3, stride=2, padding=padding)

    def forward(self, hidden_states):
        # The channel axis must match what the convolution was built for.
        assert hidden_states.shape[1] == self.channels

        needs_manual_padding = self.use_conv and self.padding == 0
        if needs_manual_padding:
            raise NotImplementedError

        return self.conv(hidden_states)
112
+
113
+
114
class ResnetBlock3D(nn.Module):
    """Residual block (norm -> act -> conv, twice) for 5D video tensors.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels of the output; defaults to ``in_channels``.
        conv_shortcut: Stored as ``use_conv_shortcut``; not otherwise used here.
        dropout: Dropout probability applied before the second convolution.
        temb_channels: Size of the time embedding, or ``None`` to build the
            block without a time-embedding projection.
        groups: Number of groups of the first normalisation layer.
        groups_out: Number of groups of the second normalisation layer
            (defaults to ``groups``).
        pre_norm: Accepted for signature compatibility; the block always
            stores ``pre_norm = True``.
        eps: Epsilon of both normalisation layers.
        non_linearity: One of ``"swish"``, ``"mish"`` or ``"silu"``.
        time_embedding_norm: ``"default"`` (additive) or ``"scale_shift"``.
        output_scale_factor: Divisor applied to the residual sum.
        use_in_shortcut: Force (or suppress) the 1x1 shortcut convolution;
            by default it is used when the channel counts differ.
        use_inflated_groupnorm: Must not be ``None``. Selects
            ``InflatedGroupNorm`` (true) or ``torch.nn.GroupNorm`` (false).
        use_temporal_conv: Use ``nn.Conv3d`` with a ``(3, 1, 1)`` kernel
            instead of the per-frame ``InflatedConv3d``.
        use_temporal_mixer: Blend the block input and output through an
            ``AlphaBlender``.

    Raises:
        ValueError: For an unknown ``time_embedding_norm`` or
            ``non_linearity``.
    """

    def __init__(
        self,
        *,
        in_channels,
        out_channels=None,
        conv_shortcut=False,
        dropout=0.0,
        temb_channels=512,
        groups=32,
        groups_out=None,
        pre_norm=True,
        eps=1e-6,
        non_linearity="swish",
        time_embedding_norm="default",
        output_scale_factor=1.0,
        use_in_shortcut=None,
        use_inflated_groupnorm=None,
        use_temporal_conv=False,
        use_temporal_mixer=False,
    ):
        super().__init__()
        # The original code overwrote the argument with True right away;
        # that behaviour is kept, only the dead first assignment is gone.
        self.pre_norm = True
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut
        self.time_embedding_norm = time_embedding_norm
        self.output_scale_factor = output_scale_factor
        self.use_temporal_mixer = use_temporal_mixer
        if use_temporal_mixer:
            self.temporal_mixer = AlphaBlender(0.3, "learned", None)

        if groups_out is None:
            groups_out = groups

        assert use_inflated_groupnorm is not None
        if use_inflated_groupnorm:
            self.norm1 = InflatedGroupNorm(
                num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
        else:
            self.norm1 = torch.nn.GroupNorm(
                num_groups=groups, num_channels=in_channels, eps=eps, affine=True)

        if use_temporal_conv:
            self.conv1 = nn.Conv3d(
                in_channels, out_channels, kernel_size=(3, 1, 1), stride=1, padding=(1, 0, 0))
        else:
            self.conv1 = InflatedConv3d(
                in_channels, out_channels, kernel_size=3, stride=1, padding=1)

        if temb_channels is not None:
            if self.time_embedding_norm == "default":
                time_emb_proj_out_channels = out_channels
            elif self.time_embedding_norm == "scale_shift":
                time_emb_proj_out_channels = out_channels * 2
            else:
                raise ValueError(
                    f"unknown time_embedding_norm : {self.time_embedding_norm} ")

            self.time_emb_proj = torch.nn.Linear(
                temb_channels, time_emb_proj_out_channels)
        else:
            self.time_emb_proj = None

        if use_inflated_groupnorm:
            self.norm2 = InflatedGroupNorm(
                num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
        else:
            self.norm2 = torch.nn.GroupNorm(
                num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)

        self.dropout = torch.nn.Dropout(dropout)
        if use_temporal_conv:
            # conv2 consumes the output of conv1/norm2, which already has
            # `out_channels` channels; using `in_channels` here failed
            # whenever the two differed.
            self.conv2 = nn.Conv3d(
                out_channels, out_channels, kernel_size=(3, 1, 1), stride=1, padding=(1, 0, 0))
        else:
            self.conv2 = InflatedConv3d(
                out_channels, out_channels, kernel_size=3, stride=1, padding=1)

        if non_linearity == "swish":
            self.nonlinearity = F.silu
        elif non_linearity == "mish":
            self.nonlinearity = Mish()
        elif non_linearity == "silu":
            self.nonlinearity = nn.SiLU()
        else:
            # Fail at construction instead of with an AttributeError on the
            # first forward pass.
            raise ValueError(f"unknown non_linearity : {non_linearity} ")

        self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut

        self.conv_shortcut = None
        if self.use_in_shortcut:
            self.conv_shortcut = InflatedConv3d(
                in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, input_tensor, temb):
        """Run the residual block.

        Args:
            input_tensor: Input features with ``in_channels`` channels.
            temb: Time embedding, or ``None`` to skip time conditioning.

        Returns:
            ``(shortcut(input) + residual) / output_scale_factor``, optionally
            blended with the raw input when the temporal mixer is enabled.
        """
        if self.use_temporal_mixer:
            residual = input_tensor

        hidden_states = input_tensor

        hidden_states = self.norm1(hidden_states)
        hidden_states = self.nonlinearity(hidden_states)

        hidden_states = self.conv1(hidden_states)

        if temb is not None:
            # Append three singleton axes so the projected embedding
            # broadcasts over the trailing three axes of the features.
            temb = self.time_emb_proj(self.nonlinearity(temb))[
                :, :, None, None, None]

        if temb is not None and self.time_embedding_norm == "default":
            hidden_states = hidden_states + temb

        hidden_states = self.norm2(hidden_states)

        if temb is not None and self.time_embedding_norm == "scale_shift":
            scale, shift = torch.chunk(temb, 2, dim=1)
            hidden_states = hidden_states * (1 + scale) + shift

        hidden_states = self.nonlinearity(hidden_states)

        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states)

        if self.conv_shortcut is not None:
            input_tensor = self.conv_shortcut(input_tensor)

        output_tensor = (input_tensor + hidden_states) / self.output_scale_factor

        if self.use_temporal_mixer:
            output_tensor = self.temporal_mixer(residual, output_tensor, None)
        return output_tensor
248
+
249
+
250
class Mish(torch.nn.Module):
    """Mish activation: ``x * tanh(softplus(x))``."""

    def forward(self, hidden_states):
        gate = torch.tanh(torch.nn.functional.softplus(hidden_states))
        return hidden_states * gate
253
+
254
+
255
class AlphaBlender(nn.Module):
    """Blend two tensors as ``alpha * x_spatial + (1 - alpha) * x_temporal``.

    How ``alpha`` is obtained depends on ``merge_strategy``:

    * ``"fixed"``: the constructor value, stored as a buffer.
    * ``"learned"``: ``sigmoid(mix_factor * scaler)`` with a trainable
      ``mix_factor``.
    * ``"learned_with_images"``: ``sigmoid(mix_factor)``, replaced by 1
      wherever ``image_only_indicator`` is true, then reshaped with
      ``rearrange_pattern``.
    """

    strategies = ["learned", "fixed", "learned_with_images"]

    def __init__(
        self,
        alpha: float,
        merge_strategy: str = "learned_with_images",
        rearrange_pattern: str = "b t -> (b t) 1 1",
    ):
        super().__init__()
        self.merge_strategy = merge_strategy
        self.rearrange_pattern = rearrange_pattern
        self.scaler = 10.

        assert (
            merge_strategy in self.strategies
        ), f"merge_strategy needs to be in {self.strategies}"

        initial = torch.Tensor([alpha])
        if self.merge_strategy == "fixed":
            self.register_buffer("mix_factor", initial)
        elif self.merge_strategy in ("learned", "learned_with_images"):
            self.register_parameter("mix_factor", torch.nn.Parameter(initial))
        else:
            raise ValueError(f"unknown merge strategy {self.merge_strategy}")

    def get_alpha(self, image_only_indicator: torch.Tensor) -> torch.Tensor:
        """Return the blending weight for the configured strategy."""
        strategy = self.merge_strategy

        if strategy == "fixed":
            return self.mix_factor

        if strategy == "learned":
            return torch.sigmoid(self.mix_factor*self.scaler)

        if strategy == "learned_with_images":
            assert image_only_indicator is not None, "need image_only_indicator ..."
            learned_weight = rearrange(
                torch.sigmoid(self.mix_factor), "... -> ... 1")
            full_weight = torch.ones(1, 1, device=image_only_indicator.device)
            per_item = torch.where(
                image_only_indicator.bool(), full_weight, learned_weight)
            return rearrange(per_item, self.rearrange_pattern)

        raise NotImplementedError

    def forward(
        self,
        x_spatial: torch.Tensor,
        x_temporal: torch.Tensor,
        image_only_indicator: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Mix ``x_spatial`` and ``x_temporal`` using the current alpha."""
        alpha = self.get_alpha(image_only_indicator)
        spatial_part = alpha.to(x_spatial.dtype) * x_spatial
        temporal_part = (1.0 - alpha).to(x_spatial.dtype) * x_temporal
        return spatial_part + temporal_part
animatelcm/models/unet.py ADDED
@@ -0,0 +1,568 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py
2
+
3
+ from dataclasses import dataclass
4
+ from typing import List, Optional, Tuple, Union
5
+
6
+ import os
7
+ import json
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.utils.checkpoint
12
+
13
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
14
+ from diffusers.modeling_utils import ModelMixin
15
+ from diffusers.utils import BaseOutput, logging
16
+ from animatelcm.models.embeddings import TimestepEmbedding, Timesteps
17
+ from .unet_blocks import (
18
+ CrossAttnDownBlock3D,
19
+ CrossAttnUpBlock3D,
20
+ DownBlock3D,
21
+ UNetMidBlock3DCrossAttn,
22
+ UpBlock3D,
23
+ get_down_block,
24
+ get_up_block,
25
+ )
26
+ from .resnet import InflatedConv3d, InflatedGroupNorm
27
+ # from .adapter import Adapter, PixelAdapter # Not ready
28
+ from einops import repeat
29
+
30
+
31
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
32
+
33
+
34
@dataclass
class UNet3DConditionOutput(BaseOutput):
    """Return container of `UNet3DConditionModel.forward` when `return_dict=True`."""

    # Tensor produced by the model's final `conv_out` layer.
    sample: torch.FloatTensor
37
+
38
+
39
class UNet3DConditionModel(ModelMixin, ConfigMixin):
    """Conditional UNet operating on 5D video tensors ``(b, c, f, h, w)``.

    The layout follows the diffusers 2D conditional UNet this file was adapted
    from, with convolutions/normalisations replaced by their inflated
    (per-frame) counterparts and optional motion modules inside the blocks.
    """

    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        sample_size: Optional[int] = None,
        in_channels: int = 4,
        out_channels: int = 4,
        center_input_sample: bool = False,
        flip_sin_to_cos: bool = True,
        freq_shift: int = 0,
        down_block_types: Tuple[str] = (
            "CrossAttnDownBlock3D",
            "CrossAttnDownBlock3D",
            "CrossAttnDownBlock3D",
            "DownBlock3D",
        ),
        mid_block_type: str = "UNetMidBlock3DCrossAttn",
        up_block_types: Tuple[str] = (
            "UpBlock3D",
            "CrossAttnUpBlock3D",
            "CrossAttnUpBlock3D",
            "CrossAttnUpBlock3D"
        ),
        only_cross_attention: Union[bool, Tuple[bool]] = False,
        block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
        layers_per_block: int = 2,
        downsample_padding: int = 1,
        mid_block_scale_factor: float = 1,
        act_fn: str = "silu",
        norm_num_groups: int = 32,
        norm_eps: float = 1e-5,
        cross_attention_dim: int = 1280,
        attention_head_dim: Union[int, Tuple[int]] = 8,
        dual_cross_attention: bool = False,
        use_linear_projection: bool = False,
        class_embed_type: Optional[str] = None,
        num_class_embeds: Optional[int] = None,
        upcast_attention: bool = False,
        resnet_time_scale_shift: str = "default",

        use_inflated_groupnorm=False,

        # Additional
        use_motion_module=False,
        use_motion_resnet=False,
        motion_module_resolutions=(1, 2, 4, 8),
        motion_module_mid_block=False,
        motion_module_decoder_only=False,
        motion_module_type=None,
        motion_module_kwargs={},
        unet_use_cross_frame_attention=None,
        unet_use_temporal_attention=None,
        time_cond_proj_dim=None,  # not ready
        use_img_encoder=False,
        use_pixel_encoder=False,
    ):
        super().__init__()

        self.sample_size = sample_size
        time_embed_dim = block_out_channels[0] * 4

        # The image/pixel encoders are not implemented yet, so both stay
        # None regardless of `use_img_encoder` / `use_pixel_encoder`.
        self.img_encoder = None
        self.pixel_encoder = None

        # input
        self.conv_in = InflatedConv3d(
            in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))

        # time
        self.time_proj = Timesteps(
            block_out_channels[0], flip_sin_to_cos, freq_shift)
        timestep_input_dim = block_out_channels[0]

        self.time_embedding = TimestepEmbedding(
            timestep_input_dim, time_embed_dim, time_cond_proj_dim=time_cond_proj_dim)

        # class embedding
        if class_embed_type is None and num_class_embeds is not None:
            self.class_embedding = nn.Embedding(
                num_class_embeds, time_embed_dim)
        elif class_embed_type == "timestep":
            self.class_embedding = TimestepEmbedding(
                timestep_input_dim, time_embed_dim)
        elif class_embed_type == "identity":
            self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
        else:
            self.class_embedding = None

        self.down_blocks = nn.ModuleList([])
        self.mid_block = None
        self.up_blocks = nn.ModuleList([])

        if isinstance(only_cross_attention, bool):
            only_cross_attention = [
                only_cross_attention] * len(down_block_types)

        if isinstance(attention_head_dim, int):
            attention_head_dim = (attention_head_dim,) * len(down_block_types)

        # down
        output_channel = block_out_channels[0]
        for i, down_block_type in enumerate(down_block_types):
            # Downsampling factor of this level; matched against
            # `motion_module_resolutions`.
            res = 2 ** i
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1

            down_block = get_down_block(
                down_block_type,
                num_layers=layers_per_block,
                in_channels=input_channel,
                out_channels=output_channel,
                temb_channels=time_embed_dim,
                add_downsample=not is_final_block,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                cross_attention_dim=cross_attention_dim,
                attn_num_head_channels=attention_head_dim[i],
                downsample_padding=downsample_padding,
                dual_cross_attention=dual_cross_attention,
                use_linear_projection=use_linear_projection,
                only_cross_attention=only_cross_attention[i],
                upcast_attention=upcast_attention,
                resnet_time_scale_shift=resnet_time_scale_shift,

                unet_use_cross_frame_attention=unet_use_cross_frame_attention,
                unet_use_temporal_attention=unet_use_temporal_attention,
                use_inflated_groupnorm=use_inflated_groupnorm,

                use_motion_module=use_motion_module and (
                    res in motion_module_resolutions) and (not motion_module_decoder_only),
                use_motion_resnet=use_motion_resnet and (
                    res in motion_module_resolutions) and (not motion_module_decoder_only),
                motion_module_type=motion_module_type,
                motion_module_kwargs=motion_module_kwargs,
            )
            self.down_blocks.append(down_block)

        # mid
        if mid_block_type == "UNetMidBlock3DCrossAttn":
            self.mid_block = UNetMidBlock3DCrossAttn(
                in_channels=block_out_channels[-1],
                temb_channels=time_embed_dim,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                output_scale_factor=mid_block_scale_factor,
                resnet_time_scale_shift=resnet_time_scale_shift,
                cross_attention_dim=cross_attention_dim,
                attn_num_head_channels=attention_head_dim[-1],
                resnet_groups=norm_num_groups,
                dual_cross_attention=dual_cross_attention,
                use_linear_projection=use_linear_projection,
                upcast_attention=upcast_attention,

                unet_use_cross_frame_attention=unet_use_cross_frame_attention,
                unet_use_temporal_attention=unet_use_temporal_attention,
                use_inflated_groupnorm=use_inflated_groupnorm,

                use_motion_module=use_motion_module and motion_module_mid_block,
                use_motion_resnet=use_motion_resnet and motion_module_mid_block,

                motion_module_type=motion_module_type,
                motion_module_kwargs=motion_module_kwargs,
            )
        else:
            raise ValueError(f"unknown mid_block_type : {mid_block_type}")

        # count how many layers upsample the videos
        self.num_upsamplers = 0

        # up
        reversed_block_out_channels = list(reversed(block_out_channels))
        reversed_attention_head_dim = list(reversed(attention_head_dim))
        only_cross_attention = list(reversed(only_cross_attention))
        output_channel = reversed_block_out_channels[0]
        for i, up_block_type in enumerate(up_block_types):
            # NOTE(review): the literal 3 mirrors `2 ** i` of the down path
            # only for the default four-level layout - confirm before using
            # a different number of blocks.
            res = 2 ** (3 - i)
            is_final_block = i == len(block_out_channels) - 1

            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]
            input_channel = reversed_block_out_channels[min(
                i + 1, len(block_out_channels) - 1)]

            # add upsample block for all BUT final layer
            if not is_final_block:
                add_upsample = True
                self.num_upsamplers += 1
            else:
                add_upsample = False

            up_block = get_up_block(
                up_block_type,
                num_layers=layers_per_block + 1,
                in_channels=input_channel,
                out_channels=output_channel,
                prev_output_channel=prev_output_channel,
                temb_channels=time_embed_dim,
                add_upsample=add_upsample,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                cross_attention_dim=cross_attention_dim,
                attn_num_head_channels=reversed_attention_head_dim[i],
                dual_cross_attention=dual_cross_attention,
                use_linear_projection=use_linear_projection,
                only_cross_attention=only_cross_attention[i],
                upcast_attention=upcast_attention,
                resnet_time_scale_shift=resnet_time_scale_shift,

                unet_use_cross_frame_attention=unet_use_cross_frame_attention,
                unet_use_temporal_attention=unet_use_temporal_attention,
                use_inflated_groupnorm=use_inflated_groupnorm,

                use_motion_module=use_motion_module and (
                    res in motion_module_resolutions),
                use_motion_resnet=use_motion_resnet and (
                    res in motion_module_resolutions),

                motion_module_type=motion_module_type,
                motion_module_kwargs=motion_module_kwargs,
            )
            self.up_blocks.append(up_block)
            prev_output_channel = output_channel

        # out
        if use_inflated_groupnorm:
            self.conv_norm_out = InflatedGroupNorm(
                num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
        else:
            self.conv_norm_out = nn.GroupNorm(
                num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
        self.conv_act = nn.SiLU()
        self.conv_out = InflatedConv3d(
            block_out_channels[0], out_channels, kernel_size=3, padding=1)

    def set_attention_slice(self, slice_size):
        r"""
        Enable sliced attention computation.

        When this option is enabled, the attention module will split the input tensor in slices, to compute attention
        in several steps. This is useful to save some memory in exchange for a small speed decrease.

        Args:
            slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
                When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
                `"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is
                provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
                must be a multiple of `slice_size`.

        Raises:
            ValueError: If the number of slice sizes does not match the number of sliceable layers, or a slice size
                exceeds the corresponding head dimension.
        """
        sliceable_head_dims = []

        def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module):
            if hasattr(module, "set_attention_slice"):
                sliceable_head_dims.append(module.sliceable_head_dim)

            for child in module.children():
                fn_recursive_retrieve_slicable_dims(child)

        # retrieve number of attention layers
        for module in self.children():
            fn_recursive_retrieve_slicable_dims(module)

        num_slicable_layers = len(sliceable_head_dims)

        if slice_size == "auto":
            # half the attention head size is usually a good trade-off between
            # speed and memory
            slice_size = [dim // 2 for dim in sliceable_head_dims]
        elif slice_size == "max":
            # make smallest slice possible
            slice_size = num_slicable_layers * [1]

        slice_size = num_slicable_layers * \
            [slice_size] if not isinstance(slice_size, list) else slice_size

        if len(slice_size) != len(sliceable_head_dims):
            raise ValueError(
                f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
                f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
            )

        for size, dim in zip(slice_size, sliceable_head_dims):
            if size is not None and size > dim:
                raise ValueError(
                    f"size {size} has to be smaller or equal to {dim}.")

        # Recursively walk through all the children.
        # Any children which exposes the set_attention_slice method
        # gets the message
        def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
            if hasattr(module, "set_attention_slice"):
                module.set_attention_slice(slice_size.pop())

            for child in module.children():
                fn_recursive_set_attention_slice(child, slice_size)

        # Reversed so that `pop()` hands out sizes in traversal order.
        reversed_slice_size = list(reversed(slice_size))
        for module in self.children():
            fn_recursive_set_attention_slice(module, reversed_slice_size)

    def _set_gradient_checkpointing(self, module, value=False):
        """Toggle gradient checkpointing on the down/up blocks."""
        if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
            module.gradient_checkpointing = value

    def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        img_latent: torch.FloatTensor = None,
        control: torch.FloatTensor = None,
        time_cond: torch.FloatTensor = None,  # not ready
        class_labels: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ) -> Union[UNet3DConditionOutput, Tuple]:
        r"""
        Args:
            sample (`torch.FloatTensor`): (batch, channel, frames, height, width) noisy inputs tensor
            timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
            encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
            img_latent (`torch.FloatTensor`, *optional*): only used when an image encoder is attached; a 4D tensor
                is repeated along the frame axis first.
            control (`torch.FloatTensor`, *optional*): only used when a pixel encoder is attached.
            time_cond (`torch.FloatTensor`, *optional*): accepted but currently unused ("not ready").
            class_labels (`torch.Tensor`, *optional*): required when the model has a class embedding.
            attention_mask (`torch.Tensor`, *optional*): converted to an additive bias (0 / -10000).
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`UNet3DConditionOutput`] instead of a plain tuple.

        Returns:
            [`UNet3DConditionOutput`] or `tuple`:
            [`UNet3DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
            returning a tuple, the first element is the sample tensor.

        Raises:
            ValueError: If the model has a class embedding and `class_labels` is missing.
        """

        if img_latent is not None and self.img_encoder is not None:
            f = sample.shape[2]
            img_latent = repeat(img_latent, "b c h w -> b c f h w",
                                f=f) if img_latent.ndim == 4 else img_latent
            img_features = self.img_encoder(img_latent)
        else:
            img_features = None

        if control is not None and self.pixel_encoder is not None:
            ctrl_features = self.pixel_encoder(control)
        else:
            ctrl_features = None

        # By default samples have to be AT least a multiple of the overall upsampling factor.
        # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
        # However, the upsampling interpolation output size can be forced to fit any upsampling size
        # on the fly if necessary.
        default_overall_up_factor = 2**self.num_upsamplers

        # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
        forward_upsample_size = False
        upsample_size = None

        if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
            logger.info(
                "Forward upsample size to force interpolation output size.")
            forward_upsample_size = True

        # prepare attention_mask
        if attention_mask is not None:
            attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
            attention_mask = attention_mask.unsqueeze(1)

        # center input if necessary
        if self.config.center_input_sample:
            sample = 2 * sample - 1.0

        # time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            # This would be a good case for the `match` statement (Python 3.10+)
            is_mps = sample.device.type == "mps"
            if isinstance(timestep, float):
                dtype = torch.float32 if is_mps else torch.float64
            else:
                dtype = torch.int32 if is_mps else torch.int64
            timesteps = torch.tensor(
                [timesteps], dtype=dtype, device=sample.device)
        elif len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)

        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps.expand(sample.shape[0])

        t_emb = self.time_proj(timesteps)

        # timesteps does not contain any weights and will always return f32 tensors
        # but time_embedding might actually be running in fp16. so we need to cast here.
        # there might be better ways to encapsulate this.
        t_emb = t_emb.to(dtype=self.dtype)

        emb = self.time_embedding(t_emb)

        if self.class_embedding is not None:
            if class_labels is None:
                raise ValueError(
                    "class_labels should be provided when num_class_embeds > 0")

            if self.config.class_embed_type == "timestep":
                class_labels = self.time_proj(class_labels)

            class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
            emb = emb + class_emb

        # pre-process
        sample = self.conv_in(sample)

        # down
        down_block_res_samples = (sample,)

        for img_feature_idx, downsample_block in enumerate(self.down_blocks):
            # Combine the optional per-level image and control features.
            # When neither encoder produced anything there is nothing to add,
            # so no placeholder tensor (and no device sync) is needed.
            added_feature = None
            if img_features is not None:
                added_feature = img_features[img_feature_idx]
            if ctrl_features is not None:
                ctrl_feature = ctrl_features[img_feature_idx]
                added_feature = ctrl_feature if added_feature is None else added_feature + ctrl_feature
            # An all-zero feature is treated the same as "no feature".
            if added_feature is not None and added_feature.abs().mean() == 0:
                added_feature = None

            if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
                sample, res_samples = downsample_block(
                    hidden_states=sample,
                    temb=emb,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=attention_mask,
                    img_feature=added_feature
                )
            else:
                sample, res_samples = downsample_block(
                    hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states, img_feature=added_feature)

            down_block_res_samples += res_samples

        # mid
        sample = self.mid_block(
            sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
        )

        # up
        for i, upsample_block in enumerate(self.up_blocks):
            is_final_block = i == len(self.up_blocks) - 1

            # Each up block consumes one skip connection per resnet.
            res_samples = down_block_res_samples[-len(upsample_block.resnets):]
            down_block_res_samples = down_block_res_samples[: -len(
                upsample_block.resnets)]

            # if we have not reached the final block and need to forward the
            # upsample size, we do it here
            if not is_final_block and forward_upsample_size:
                upsample_size = down_block_res_samples[-1].shape[2:]

            if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
                sample = upsample_block(
                    hidden_states=sample,
                    temb=emb,
                    res_hidden_states_tuple=res_samples,
                    encoder_hidden_states=encoder_hidden_states,
                    upsample_size=upsample_size,
                    attention_mask=attention_mask,
                )
            else:
                sample = upsample_block(
                    hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size, encoder_hidden_states=encoder_hidden_states,
                )

        # post-process
        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        if not return_dict:
            return (sample,)

        return UNet3DConditionOutput(sample=sample)

    @classmethod
    def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, unet_additional_kwargs=None):
        """Build the 3D UNet from a 2D UNet checkpoint directory.

        Reads ``config.json`` from the directory, swaps in the 3D block types,
        instantiates the model and loads the 2D weights non-strictly (keys
        that exist only in the 3D model stay at their initial values).

        Args:
            pretrained_model_path: Directory containing the pretrained model.
            subfolder: Optional sub-directory holding the UNet files.
            unet_additional_kwargs: Extra keyword arguments forwarded to
                ``from_config``; ``None`` means no extra arguments.

        Returns:
            The instantiated model with the pretrained weights loaded.

        Raises:
            RuntimeError: If the config or the weights file does not exist.
        """
        if subfolder is not None:
            pretrained_model_path = os.path.join(
                pretrained_model_path, subfolder)
        print(
            f"loaded temporal unet's pretrained weights from {pretrained_model_path} ...")

        config_file = os.path.join(pretrained_model_path, 'config.json')
        if not os.path.isfile(config_file):
            raise RuntimeError(f"{config_file} does not exist")
        with open(config_file, "r", encoding="utf-8") as f:
            config = json.load(f)
        config["_class_name"] = cls.__name__
        config["down_block_types"] = [
            "CrossAttnDownBlock3D",
            "CrossAttnDownBlock3D",
            "CrossAttnDownBlock3D",
            "DownBlock3D"
        ]
        config["up_block_types"] = [
            "UpBlock3D",
            "CrossAttnUpBlock3D",
            "CrossAttnUpBlock3D",
            "CrossAttnUpBlock3D"
        ]

        from diffusers.utils import WEIGHTS_NAME

        # The documented default is None; unpacking None with ** would raise
        # a TypeError, so fall back to an empty mapping.
        if unet_additional_kwargs is None:
            unet_additional_kwargs = {}
        model = cls.from_config(config, **unet_additional_kwargs)
        model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
        if not os.path.isfile(model_file):
            raise RuntimeError(f"{model_file} does not exist")
        # NOTE(security): torch.load unpickles the file - only load
        # checkpoints from trusted sources.
        state_dict = torch.load(model_file, map_location="cpu")
        if "state_dict" in state_dict.keys():
            state_dict = state_dict["state_dict"]
        state_dict = {k.replace("module.", ""): v for k,
                      v in state_dict.items()}
        m, u = model.load_state_dict(state_dict, strict=False)
        print("###load unet weights")
        print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")

        params = [p.numel() if "motion" in n else 0 for n,
                  p in model.named_parameters()]
        print(f"### Temporal Module Parameters: {sum(params) / 1e6} M")

        return model
animatelcm/models/unet_blocks.py ADDED
@@ -0,0 +1,904 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_blocks.py
2
+
3
+ import torch
4
+ from torch import nn
5
+
6
+ from .attention import Transformer3DModel
7
+ from .resnet import Downsample3D, ResnetBlock3D, Upsample3D, AlphaBlender
8
+ from .motion_module import get_motion_module
9
+
10
+
11
def get_down_block(
    down_block_type,
    num_layers,
    in_channels,
    out_channels,
    temb_channels,
    add_downsample,
    resnet_eps,
    resnet_act_fn,
    attn_num_head_channels,
    resnet_groups=None,
    cross_attention_dim=None,
    downsample_padding=None,
    dual_cross_attention=False,
    use_linear_projection=False,
    only_cross_attention=False,
    upcast_attention=False,
    resnet_time_scale_shift="default",

    unet_use_cross_frame_attention=None,
    unet_use_temporal_attention=None,
    use_inflated_groupnorm=None,

    use_motion_module=None,
    use_motion_resnet=None,  # not used for current weight

    motion_module_type=None,
    motion_module_kwargs=None,
):
    """Create the down block selected by ``down_block_type``.

    A leading ``"UNetRes"`` in the type name is stripped before dispatch.
    ``"DownBlock3D"`` and ``"CrossAttnDownBlock3D"`` are supported; the
    remaining arguments are forwarded as keyword arguments to the chosen
    block's constructor (the attention-related ones only to the
    cross-attention variant).

    Raises:
        ValueError: If the type name is unknown, or if
            ``"CrossAttnDownBlock3D"`` is requested without
            ``cross_attention_dim``.
    """
    prefix = "UNetRes"
    if down_block_type.startswith(prefix):
        down_block_type = down_block_type[len(prefix):]

    if down_block_type not in ("DownBlock3D", "CrossAttnDownBlock3D"):
        raise ValueError(f"{down_block_type} does not exist.")

    # Keyword arguments accepted by both block variants.
    shared_kwargs = dict(
        num_layers=num_layers,
        in_channels=in_channels,
        out_channels=out_channels,
        temb_channels=temb_channels,
        add_downsample=add_downsample,
        resnet_eps=resnet_eps,
        resnet_act_fn=resnet_act_fn,
        resnet_groups=resnet_groups,
        downsample_padding=downsample_padding,
        resnet_time_scale_shift=resnet_time_scale_shift,
        use_inflated_groupnorm=use_inflated_groupnorm,
        use_motion_module=use_motion_module,
        motion_module_type=motion_module_type,
        motion_module_kwargs=motion_module_kwargs,
    )

    if down_block_type == "DownBlock3D":
        return DownBlock3D(**shared_kwargs)

    # Only "CrossAttnDownBlock3D" can reach this point.
    if cross_attention_dim is None:
        raise ValueError(
            "cross_attention_dim must be specified for CrossAttnDownBlock3D")
    return CrossAttnDownBlock3D(
        cross_attention_dim=cross_attention_dim,
        attn_num_head_channels=attn_num_head_channels,
        dual_cross_attention=dual_cross_attention,
        use_linear_projection=use_linear_projection,
        only_cross_attention=only_cross_attention,
        upcast_attention=upcast_attention,
        unet_use_cross_frame_attention=unet_use_cross_frame_attention,
        unet_use_temporal_attention=unet_use_temporal_attention,
        use_motion_resnet=use_motion_resnet,
        **shared_kwargs,
    )
93
+
94
+
95
def get_up_block(
    up_block_type,
    num_layers,
    in_channels,
    out_channels,
    prev_output_channel,
    temb_channels,
    add_upsample,
    resnet_eps,
    resnet_act_fn,
    attn_num_head_channels,
    resnet_groups=None,
    cross_attention_dim=None,
    dual_cross_attention=False,
    use_linear_projection=False,
    only_cross_attention=False,
    upcast_attention=False,
    resnet_time_scale_shift="default",

    unet_use_cross_frame_attention=None,
    unet_use_temporal_attention=None,
    use_inflated_groupnorm=None,

    use_motion_module=None,
    use_motion_resnet=None,
    motion_module_type=None,
    motion_module_kwargs=None,
):
    """Create the up block selected by ``up_block_type``.

    A leading ``"UNetRes"`` in the type name is stripped before dispatch.
    ``"UpBlock3D"`` and ``"CrossAttnUpBlock3D"`` are supported; the remaining
    arguments are forwarded as keyword arguments to the chosen block's
    constructor (the attention-related ones only to the cross-attention
    variant).

    Raises:
        ValueError: If the type name is unknown, or if
            ``"CrossAttnUpBlock3D"`` is requested without
            ``cross_attention_dim``.
    """
    prefix = "UNetRes"
    if up_block_type.startswith(prefix):
        up_block_type = up_block_type[len(prefix):]

    if up_block_type not in ("UpBlock3D", "CrossAttnUpBlock3D"):
        raise ValueError(f"{up_block_type} does not exist.")

    # Keyword arguments accepted by both block variants.
    shared_kwargs = dict(
        num_layers=num_layers,
        in_channels=in_channels,
        out_channels=out_channels,
        prev_output_channel=prev_output_channel,
        temb_channels=temb_channels,
        add_upsample=add_upsample,
        resnet_eps=resnet_eps,
        resnet_act_fn=resnet_act_fn,
        resnet_groups=resnet_groups,
        resnet_time_scale_shift=resnet_time_scale_shift,
        use_inflated_groupnorm=use_inflated_groupnorm,
        use_motion_module=use_motion_module,
        motion_module_type=motion_module_type,
        motion_module_kwargs=motion_module_kwargs,
    )

    if up_block_type == "UpBlock3D":
        return UpBlock3D(**shared_kwargs)

    # Only "CrossAttnUpBlock3D" can reach this point.
    if cross_attention_dim is None:
        raise ValueError(
            "cross_attention_dim must be specified for CrossAttnUpBlock3D")
    return CrossAttnUpBlock3D(
        cross_attention_dim=cross_attention_dim,
        attn_num_head_channels=attn_num_head_channels,
        dual_cross_attention=dual_cross_attention,
        use_linear_projection=use_linear_projection,
        only_cross_attention=only_cross_attention,
        upcast_attention=upcast_attention,
        unet_use_cross_frame_attention=unet_use_cross_frame_attention,
        unet_use_temporal_attention=unet_use_temporal_attention,
        use_motion_resnet=use_motion_resnet,
        **shared_kwargs,
    )
176
+
177
+
178
+ class UNetMidBlock3DCrossAttn(nn.Module):
179
+ def __init__(
180
+ self,
181
+ in_channels: int,
182
+ temb_channels: int,
183
+ dropout: float = 0.0,
184
+ num_layers: int = 1,
185
+ resnet_eps: float = 1e-6,
186
+ resnet_time_scale_shift: str = "default",
187
+ resnet_act_fn: str = "swish",
188
+ resnet_groups: int = 32,
189
+ resnet_pre_norm: bool = True,
190
+ attn_num_head_channels=1,
191
+ output_scale_factor=1.0,
192
+ cross_attention_dim=1280,
193
+ dual_cross_attention=False,
194
+ use_linear_projection=False,
195
+ upcast_attention=False,
196
+
197
+ unet_use_cross_frame_attention=None,
198
+ unet_use_temporal_attention=None,
199
+ use_inflated_groupnorm=None,
200
+
201
+ use_motion_module=None,
202
+ use_motion_resnet=None,
203
+
204
+ motion_module_type=None,
205
+ motion_module_kwargs=None,
206
+ ):
207
+ super().__init__()
208
+
209
+ self.has_cross_attention = True
210
+ self.attn_num_head_channels = attn_num_head_channels
211
+ resnet_groups = resnet_groups if resnet_groups is not None else min(
212
+ in_channels // 4, 32)
213
+
214
+ # there is always at least one resnet
215
+ resnets = [
216
+ ResnetBlock3D(
217
+ in_channels=in_channels,
218
+ out_channels=in_channels,
219
+ temb_channels=temb_channels,
220
+ eps=resnet_eps,
221
+ groups=resnet_groups,
222
+ dropout=dropout,
223
+ time_embedding_norm=resnet_time_scale_shift,
224
+ non_linearity=resnet_act_fn,
225
+ output_scale_factor=output_scale_factor,
226
+ pre_norm=resnet_pre_norm,
227
+ use_inflated_groupnorm=use_inflated_groupnorm,
228
+ )
229
+ ]
230
+ motion_resnets = [
231
+ ResnetBlock3D(
232
+ in_channels=in_channels,
233
+ out_channels=in_channels,
234
+ temb_channels=temb_channels,
235
+ eps=resnet_eps,
236
+ groups=resnet_groups,
237
+ dropout=dropout,
238
+ time_embedding_norm=resnet_time_scale_shift,
239
+ non_linearity=resnet_act_fn,
240
+ output_scale_factor=output_scale_factor,
241
+ pre_norm=resnet_pre_norm,
242
+ use_inflated_groupnorm=use_inflated_groupnorm,
243
+ use_temporal_conv=True,
244
+ use_temporal_mixer=True,
245
+ ) if use_motion_resnet else None
246
+ ]
247
+
248
+ attentions = []
249
+ motion_modules = []
250
+
251
+ for _ in range(num_layers):
252
+ if dual_cross_attention:
253
+ raise NotImplementedError
254
+ attentions.append(
255
+ Transformer3DModel(
256
+ attn_num_head_channels,
257
+ in_channels // attn_num_head_channels,
258
+ in_channels=in_channels,
259
+ num_layers=1,
260
+ cross_attention_dim=cross_attention_dim,
261
+ norm_num_groups=resnet_groups,
262
+ use_linear_projection=use_linear_projection,
263
+ upcast_attention=upcast_attention,
264
+
265
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
266
+ unet_use_temporal_attention=unet_use_temporal_attention,
267
+ )
268
+ )
269
+ motion_modules.append(
270
+ get_motion_module(
271
+ in_channels=in_channels,
272
+ motion_module_type=motion_module_type,
273
+ motion_module_kwargs=motion_module_kwargs,
274
+ ) if use_motion_module else None
275
+ )
276
+ resnets.append(
277
+ ResnetBlock3D(
278
+ in_channels=in_channels,
279
+ out_channels=in_channels,
280
+ temb_channels=temb_channels,
281
+ eps=resnet_eps,
282
+ groups=resnet_groups,
283
+ dropout=dropout,
284
+ time_embedding_norm=resnet_time_scale_shift,
285
+ non_linearity=resnet_act_fn,
286
+ output_scale_factor=output_scale_factor,
287
+ pre_norm=resnet_pre_norm,
288
+
289
+ use_inflated_groupnorm=use_inflated_groupnorm,
290
+ )
291
+ )
292
+ motion_resnets.append(
293
+ ResnetBlock3D(
294
+ in_channels=in_channels,
295
+ out_channels=in_channels,
296
+ temb_channels=temb_channels,
297
+ eps=resnet_eps,
298
+ groups=resnet_groups,
299
+ dropout=dropout,
300
+ time_embedding_norm=resnet_time_scale_shift,
301
+ non_linearity=resnet_act_fn,
302
+ output_scale_factor=output_scale_factor,
303
+ pre_norm=resnet_pre_norm,
304
+ use_inflated_groupnorm=use_inflated_groupnorm,
305
+ use_temporal_conv=True,
306
+ use_temporal_mixer=True,
307
+ ) if use_motion_resnet else None
308
+ )
309
+
310
+ self.attentions = nn.ModuleList(attentions)
311
+ self.resnets = nn.ModuleList(resnets)
312
+ self.motion_modules = nn.ModuleList(motion_modules)
313
+ self.motion_resnets = nn.ModuleList(motion_resnets)
314
+
315
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
316
+ hidden_states = self.resnets[0](hidden_states, temb)
317
+ hidden_states = self.motion_resnets[0](
318
+ hidden_states, temb) if self.motion_resnets[0] is not None else hidden_states
319
+
320
+ for attn, resnet, motion_module, motion_resnet in zip(self.attentions, self.resnets[1:], self.motion_modules, self.motion_resnets[1:]):
321
+ hidden_states = attn(
322
+ hidden_states, encoder_hidden_states=encoder_hidden_states).sample
323
+ hidden_states = motion_module(
324
+ hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states
325
+ hidden_states = resnet(hidden_states, temb)
326
+ hidden_states = motion_resnet(
327
+ hidden_states, temb) if motion_resnet is not None else hidden_states
328
+
329
+ return hidden_states
330
+
331
+
332
class CrossAttnDownBlock3D(nn.Module):
    """Down block whose layers each run resnet -> motion resnet -> attention -> motion module.

    The motion resnet and the motion module are optional and stored as
    ``None`` when disabled. An optional ``Downsample3D`` follows the layers.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        attn_num_head_channels=1,
        cross_attention_dim=1280,
        output_scale_factor=1.0,
        downsample_padding=1,
        add_downsample=True,
        dual_cross_attention=False,
        use_linear_projection=False,
        only_cross_attention=False,
        upcast_attention=False,

        unet_use_cross_frame_attention=None,
        unet_use_temporal_attention=None,
        use_inflated_groupnorm=None,

        use_motion_module=None,
        use_motion_resnet=None,

        motion_module_type=None,
        motion_module_kwargs=None,
    ):
        super().__init__()

        self.has_cross_attention = True
        self.attn_num_head_channels = attn_num_head_channels

        # Resnet settings that do not depend on the layer index.
        resnet_common = dict(
            out_channels=out_channels,
            temb_channels=temb_channels,
            eps=resnet_eps,
            groups=resnet_groups,
            dropout=dropout,
            time_embedding_norm=resnet_time_scale_shift,
            non_linearity=resnet_act_fn,
            output_scale_factor=output_scale_factor,
            pre_norm=resnet_pre_norm,
            use_inflated_groupnorm=use_inflated_groupnorm,
        )

        resnets, motion_resnets, attentions, motion_modules = [], [], [], []
        for layer_idx in range(num_layers):
            # Only the first layer sees the block's input width.
            layer_in = in_channels if layer_idx == 0 else out_channels
            resnets.append(ResnetBlock3D(in_channels=layer_in, **resnet_common))

            if use_motion_resnet:
                motion_resnet = ResnetBlock3D(
                    in_channels=out_channels,
                    use_temporal_conv=True,
                    use_temporal_mixer=True,
                    **resnet_common,
                )
            else:
                motion_resnet = None
            motion_resnets.append(motion_resnet)

            if dual_cross_attention:
                raise NotImplementedError
            attentions.append(
                Transformer3DModel(
                    attn_num_head_channels,
                    out_channels // attn_num_head_channels,
                    in_channels=out_channels,
                    num_layers=1,
                    cross_attention_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                    use_linear_projection=use_linear_projection,
                    only_cross_attention=only_cross_attention,
                    upcast_attention=upcast_attention,
                    unet_use_cross_frame_attention=unet_use_cross_frame_attention,
                    unet_use_temporal_attention=unet_use_temporal_attention,
                )
            )

            if use_motion_module:
                motion_module = get_motion_module(
                    in_channels=out_channels,
                    motion_module_type=motion_module_type,
                    motion_module_kwargs=motion_module_kwargs,
                )
            else:
                motion_module = None
            motion_modules.append(motion_module)

        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)
        self.motion_modules = nn.ModuleList(motion_modules)
        self.motion_resnets = nn.ModuleList(motion_resnets)

        self.downsamplers = None
        if add_downsample:
            self.downsamplers = nn.ModuleList(
                [
                    Downsample3D(
                        out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
                    )
                ]
            )

        self.gradient_checkpointing = False

    def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, img_feature=None):
        """Run all layers and the optional downsampler.

        ``img_feature``, when given, is added to the hidden states right after
        the attention of the second layer. ``attention_mask`` is accepted but
        not used.

        Returns:
            ``(hidden_states, output_states)`` where ``output_states`` is a
            tuple with one entry per layer plus one for the downsampled
            output when downsamplers exist.
        """
        def wrap(module, return_dict=None):
            # Adapter handed to torch.utils.checkpoint, which calls it with
            # positional inputs only.
            def run(*inputs):
                if return_dict is None:
                    return module(*inputs)
                return module(*inputs, return_dict=return_dict)

            return run

        output_states = ()
        layers = zip(self.resnets, self.attentions,
                     self.motion_modules, self.motion_resnets)
        for layer_no, (resnet, attn, motion_module, motion_resnet) in enumerate(layers, start=1):
            inject_img = img_feature is not None and layer_no == 2

            if self.training and self.gradient_checkpointing:
                ckpt = torch.utils.checkpoint.checkpoint

                hidden_states = ckpt(
                    wrap(resnet), hidden_states.requires_grad_(), temb, use_reentrant=False)
                if motion_resnet is not None:
                    hidden_states = ckpt(
                        wrap(motion_resnet), hidden_states.requires_grad_(), temb, use_reentrant=False)

                hidden_states = ckpt(
                    wrap(attn, return_dict=False),
                    hidden_states.requires_grad_(),
                    encoder_hidden_states,
                    use_reentrant=False,
                )[0]

                if inject_img:
                    hidden_states = hidden_states + img_feature

                if motion_module is not None:
                    hidden_states = ckpt(
                        wrap(motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states, use_reentrant=False)
            else:
                hidden_states = resnet(hidden_states, temb)
                if motion_resnet is not None:
                    hidden_states = motion_resnet(hidden_states, temb)

                hidden_states = attn(
                    hidden_states, encoder_hidden_states=encoder_hidden_states).sample

                if inject_img:
                    hidden_states = hidden_states + img_feature

                # add motion module
                if motion_module is not None:
                    hidden_states = motion_module(
                        hidden_states, temb, encoder_hidden_states=encoder_hidden_states)

            output_states += (hidden_states,)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)

            output_states += (hidden_states,)

        return hidden_states, output_states
516
+
517
+
518
class DownBlock3D(nn.Module):
    """Down block without attention: each layer is a resnet plus an optional motion module.

    Disabled motion modules are stored as ``None``. An optional
    ``Downsample3D`` follows the layers.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        output_scale_factor=1.0,
        add_downsample=True,
        downsample_padding=1,

        use_inflated_groupnorm=None,

        use_motion_module=None,
        motion_module_type=None,
        motion_module_kwargs=None,
    ):
        super().__init__()

        resnets, motion_modules = [], []
        for layer_idx in range(num_layers):
            # Only the first layer sees the block's input width.
            layer_in = in_channels if layer_idx == 0 else out_channels
            resnets.append(
                ResnetBlock3D(
                    in_channels=layer_in,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                    use_inflated_groupnorm=use_inflated_groupnorm,
                )
            )

            if use_motion_module:
                motion_module = get_motion_module(
                    in_channels=out_channels,
                    motion_module_type=motion_module_type,
                    motion_module_kwargs=motion_module_kwargs,
                )
            else:
                motion_module = None
            motion_modules.append(motion_module)

        self.resnets = nn.ModuleList(resnets)
        self.motion_modules = nn.ModuleList(motion_modules)

        self.downsamplers = None
        if add_downsample:
            self.downsamplers = nn.ModuleList(
                [
                    Downsample3D(
                        out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
                    )
                ]
            )

        self.gradient_checkpointing = False

    def forward(self, hidden_states, temb=None, encoder_hidden_states=None, img_feature=None):
        """Run all layers and the optional downsampler.

        ``img_feature``, when given, is added to the hidden states right after
        the resnet of the second layer.

        Returns:
            ``(hidden_states, output_states)`` where ``output_states`` is a
            tuple with one entry per layer plus one for the downsampled
            output when downsamplers exist.
        """
        def wrap(module):
            # Adapter handed to torch.utils.checkpoint, which calls it with
            # positional inputs only.
            def run(*inputs):
                return module(*inputs)

            return run

        output_states = ()
        layers = zip(self.resnets, self.motion_modules)
        for layer_no, (resnet, motion_module) in enumerate(layers, start=1):
            inject_img = img_feature is not None and layer_no == 2

            if self.training and self.gradient_checkpointing:
                ckpt = torch.utils.checkpoint.checkpoint
                hidden_states = ckpt(
                    wrap(resnet), hidden_states.requires_grad_(), temb, use_reentrant=False)
                if inject_img:
                    hidden_states = hidden_states + img_feature
                if motion_module is not None:
                    hidden_states = ckpt(
                        wrap(motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states, use_reentrant=False)
            else:
                hidden_states = resnet(hidden_states, temb)
                if inject_img:
                    hidden_states = hidden_states + img_feature
                # add motion module
                if motion_module is not None:
                    hidden_states = motion_module(
                        hidden_states, temb, encoder_hidden_states=encoder_hidden_states)

            output_states += (hidden_states,)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)

            output_states += (hidden_states,)

        return hidden_states, output_states
626
+
627
+
628
class CrossAttnUpBlock3D(nn.Module):
    """Up block whose layers each run resnet -> motion resnet -> attention -> motion module.

    Before every layer the most recent skip tensor is concatenated onto the
    hidden states along dim 1. The motion resnet and the motion module are
    optional and stored as ``None`` when disabled. An optional ``Upsample3D``
    follows the layers.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        prev_output_channel: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        attn_num_head_channels=1,
        cross_attention_dim=1280,
        output_scale_factor=1.0,
        add_upsample=True,
        dual_cross_attention=False,
        use_linear_projection=False,
        only_cross_attention=False,
        upcast_attention=False,

        unet_use_cross_frame_attention=None,
        unet_use_temporal_attention=None,
        use_inflated_groupnorm=None,

        use_motion_module=None,
        use_motion_resnet=None,

        motion_module_type=None,
        motion_module_kwargs=None,
    ):
        super().__init__()

        self.has_cross_attention = True
        self.attn_num_head_channels = attn_num_head_channels

        # Resnet settings that do not depend on the layer index.
        resnet_common = dict(
            out_channels=out_channels,
            temb_channels=temb_channels,
            eps=resnet_eps,
            groups=resnet_groups,
            dropout=dropout,
            time_embedding_norm=resnet_time_scale_shift,
            non_linearity=resnet_act_fn,
            output_scale_factor=output_scale_factor,
            pre_norm=resnet_pre_norm,
            use_inflated_groupnorm=use_inflated_groupnorm,
        )

        resnets, attentions, motion_modules, motion_resnets = [], [], [], []
        last_layer = num_layers - 1
        for layer_idx in range(num_layers):
            # Width of the skip tensor concatenated in front of this layer.
            skip_width = in_channels if layer_idx == last_layer else out_channels
            # Width of the hidden states entering this layer.
            hidden_width = prev_output_channel if layer_idx == 0 else out_channels

            resnets.append(
                ResnetBlock3D(in_channels=hidden_width + skip_width, **resnet_common))

            if use_motion_resnet:
                motion_resnet = ResnetBlock3D(
                    in_channels=out_channels,
                    use_temporal_conv=True,
                    use_temporal_mixer=True,
                    **resnet_common,
                )
            else:
                motion_resnet = None
            motion_resnets.append(motion_resnet)

            if dual_cross_attention:
                raise NotImplementedError
            attentions.append(
                Transformer3DModel(
                    attn_num_head_channels,
                    out_channels // attn_num_head_channels,
                    in_channels=out_channels,
                    num_layers=1,
                    cross_attention_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                    use_linear_projection=use_linear_projection,
                    only_cross_attention=only_cross_attention,
                    upcast_attention=upcast_attention,
                    unet_use_cross_frame_attention=unet_use_cross_frame_attention,
                    unet_use_temporal_attention=unet_use_temporal_attention,
                )
            )

            if use_motion_module:
                motion_module = get_motion_module(
                    in_channels=out_channels,
                    motion_module_type=motion_module_type,
                    motion_module_kwargs=motion_module_kwargs,
                )
            else:
                motion_module = None
            motion_modules.append(motion_module)

        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)
        self.motion_modules = nn.ModuleList(motion_modules)
        self.motion_resnets = nn.ModuleList(motion_resnets)

        self.upsamplers = None
        if add_upsample:
            self.upsamplers = nn.ModuleList(
                [Upsample3D(out_channels, use_conv=True, out_channels=out_channels)])

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states,
        res_hidden_states_tuple,
        temb=None,
        encoder_hidden_states=None,
        upsample_size=None,
        attention_mask=None,
    ):
        """Consume one skip tensor per layer, then optionally upsample.

        Skip tensors are taken from the end of ``res_hidden_states_tuple``.
        ``attention_mask`` is accepted but not used.

        Returns:
            The hidden states after all layers and upsamplers.
        """
        def wrap(module, return_dict=None):
            # Adapter handed to torch.utils.checkpoint, which calls it with
            # positional inputs only.
            def run(*inputs):
                if return_dict is None:
                    return module(*inputs)
                return module(*inputs, return_dict=return_dict)

            return run

        layers = zip(self.resnets, self.attentions,
                     self.motion_modules, self.motion_resnets)
        for resnet, attn, motion_module, motion_resnet in layers:
            # pop res hidden states
            skip = res_hidden_states_tuple[-1]
            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
            hidden_states = torch.cat([hidden_states, skip], dim=1)

            if self.training and self.gradient_checkpointing:
                ckpt = torch.utils.checkpoint.checkpoint

                hidden_states = ckpt(
                    wrap(resnet), hidden_states.requires_grad_(), temb, use_reentrant=False)
                if motion_resnet is not None:
                    hidden_states = ckpt(
                        wrap(motion_resnet), hidden_states.requires_grad_(), temb, use_reentrant=False)

                hidden_states = ckpt(
                    wrap(attn, return_dict=False),
                    hidden_states.requires_grad_(),
                    encoder_hidden_states,
                    use_reentrant=False,
                )[0]
                if motion_module is not None:
                    hidden_states = ckpt(
                        wrap(motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states, use_reentrant=False)
            else:
                hidden_states = resnet(hidden_states, temb)
                if motion_resnet is not None:
                    hidden_states = motion_resnet(hidden_states, temb)
                hidden_states = attn(
                    hidden_states, encoder_hidden_states=encoder_hidden_states).sample

                # add motion module
                if motion_module is not None:
                    hidden_states = motion_module(
                        hidden_states, temb, encoder_hidden_states=encoder_hidden_states)

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states, upsample_size)

        return hidden_states
806
+
807
+
808
class UpBlock3D(nn.Module):
    """Up block without attention: each layer is a resnet plus an optional motion module.

    Before every layer the most recent skip tensor is concatenated onto the
    hidden states along dim 1. Disabled motion modules are stored as
    ``None``. An optional ``Upsample3D`` follows the layers.
    """

    def __init__(
        self,
        in_channels: int,
        prev_output_channel: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        output_scale_factor=1.0,
        add_upsample=True,

        use_inflated_groupnorm=None,

        use_motion_module=None,
        motion_module_type=None,
        motion_module_kwargs=None,
    ):
        super().__init__()

        resnets, motion_modules = [], []
        last_layer = num_layers - 1
        for layer_idx in range(num_layers):
            # Width of the skip tensor concatenated in front of this layer.
            skip_width = in_channels if layer_idx == last_layer else out_channels
            # Width of the hidden states entering this layer.
            hidden_width = prev_output_channel if layer_idx == 0 else out_channels

            resnets.append(
                ResnetBlock3D(
                    in_channels=hidden_width + skip_width,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                    use_inflated_groupnorm=use_inflated_groupnorm,
                )
            )

            if use_motion_module:
                motion_module = get_motion_module(
                    in_channels=out_channels,
                    motion_module_type=motion_module_type,
                    motion_module_kwargs=motion_module_kwargs,
                )
            else:
                motion_module = None
            motion_modules.append(motion_module)

        self.resnets = nn.ModuleList(resnets)
        self.motion_modules = nn.ModuleList(motion_modules)

        self.upsamplers = None
        if add_upsample:
            self.upsamplers = nn.ModuleList(
                [Upsample3D(out_channels, use_conv=True, out_channels=out_channels)])

        self.gradient_checkpointing = False

    def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, encoder_hidden_states=None,):
        """Consume one skip tensor per layer, then optionally upsample.

        Skip tensors are taken from the end of ``res_hidden_states_tuple``.

        Returns:
            The hidden states after all layers and upsamplers.
        """
        def wrap(module):
            # Adapter handed to torch.utils.checkpoint, which calls it with
            # positional inputs only.
            def run(*inputs):
                return module(*inputs)

            return run

        for resnet, motion_module in zip(self.resnets, self.motion_modules):
            # pop res hidden states
            skip = res_hidden_states_tuple[-1]
            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
            hidden_states = torch.cat([hidden_states, skip], dim=1)

            if self.training and self.gradient_checkpointing:
                ckpt = torch.utils.checkpoint.checkpoint
                hidden_states = ckpt(
                    wrap(resnet), hidden_states.requires_grad_(), temb, use_reentrant=False)
                if motion_module is not None:
                    hidden_states = ckpt(
                        wrap(motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states, use_reentrant=False)
            else:
                hidden_states = resnet(hidden_states, temb)
                if motion_module is not None:
                    hidden_states = motion_module(
                        hidden_states, temb, encoder_hidden_states=encoder_hidden_states)

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states, upsample_size)

        return hidden_states
animatelcm/pipelines/pipeline_animation.py ADDED
@@ -0,0 +1,456 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/showlab/Tune-A-Video/blob/main/tuneavideo/pipelines/pipeline_tuneavideo.py
2
+
3
+ import inspect
4
+ from typing import Callable, List, Optional, Union
5
+ from dataclasses import dataclass
6
+
7
+ import numpy as np
8
+ import torch
9
+ from tqdm import tqdm
10
+
11
+ from diffusers.utils import is_accelerate_available
12
+ from packaging import version
13
+ from transformers import CLIPTextModel, CLIPTokenizer
14
+
15
+ from diffusers.configuration_utils import FrozenDict
16
+ from diffusers.models import AutoencoderKL
17
+ from diffusers.pipeline_utils import DiffusionPipeline
18
+ from diffusers.schedulers import (
19
+ DDIMScheduler,
20
+ DPMSolverMultistepScheduler,
21
+ EulerAncestralDiscreteScheduler,
22
+ EulerDiscreteScheduler,
23
+ LMSDiscreteScheduler,
24
+ PNDMScheduler,
25
+ )
26
+ from diffusers.utils import deprecate, logging, BaseOutput
27
+
28
+ from einops import rearrange
29
+
30
+ from ..models.unet import UNet3DConditionModel
31
+
32
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
33
+
34
+
35
@dataclass
class AnimationPipelineOutput(BaseOutput):
    """Result container returned by `AnimationPipeline.__call__` when `return_dict` is true."""

    # Decoded videos: a torch tensor when `output_type == "tensor"`, otherwise
    # the numpy array produced by `decode_latents`.
    videos: Union[torch.Tensor, np.ndarray]
38
+
39
+
40
+ class AnimationPipeline(DiffusionPipeline):
41
+ _optional_components = []
42
+
43
def __init__(
    self,
    vae: AutoencoderKL,
    text_encoder: CLIPTextModel,
    tokenizer: CLIPTokenizer,
    unet: UNet3DConditionModel,
    scheduler: Union[
        DDIMScheduler,
        PNDMScheduler,
        LMSDiscreteScheduler,
        EulerDiscreteScheduler,
        EulerAncestralDiscreteScheduler,
        DPMSolverMultistepScheduler,
    ],
):
    """Register the pipeline components, patching outdated configs on the way.

    Three legacy settings are detected. Each triggers a deprecation notice and
    is overridden in place on the component:

    * scheduler `steps_offset` different from 1 is forced to 1;
    * scheduler `clip_sample` set to True is forced to False;
    * UNet `sample_size` below 64 on a pre-0.9.0 config is forced to 64.

    Args:
        vae: Autoencoder used to decode latents into frames.
        text_encoder: CLIP text encoder producing the prompt embeddings.
        tokenizer: Tokenizer paired with `text_encoder`.
        unet: Denoising UNet.
        scheduler: Scheduler driving the denoising loop.
    """
    super().__init__()

    def _override_config(component, key, value):
        # Replace the component's config with a copy in which `key` is overridden.
        patched = dict(component.config)
        patched[key] = value
        component._internal_dict = FrozenDict(patched)

    if getattr(scheduler.config, "steps_offset", 1) != 1:
        deprecation_message = (
            f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
            f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
            "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
            " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
            " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
            " file"
        )
        deprecate("steps_offset!=1", "1.0.0",
                  deprecation_message, standard_warn=False)
        _override_config(scheduler, "steps_offset", 1)

    if getattr(scheduler.config, "clip_sample", False) is True:
        deprecation_message = (
            f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
            " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
            " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
            " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
            " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
        )
        deprecate("clip_sample not set", "1.0.0",
                  deprecation_message, standard_warn=False)
        _override_config(scheduler, "clip_sample", False)

    unet_cfg = unet.config
    older_than_0_9_0 = hasattr(unet_cfg, "_diffusers_version") and version.parse(
        version.parse(unet_cfg._diffusers_version).base_version
    ) < version.parse("0.9.0.dev0")
    small_sample_size = hasattr(
        unet_cfg, "sample_size") and unet_cfg.sample_size < 64
    if older_than_0_9_0 and small_sample_size:
        deprecation_message = (
            "The configuration file of the unet has set the default `sample_size` to smaller than"
            " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
            " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
            " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
            " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
            " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
            " in the config might lead to incorrect results in future versions. If you have downloaded this"
            " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
            " the `unet/config.json` file"
        )
        deprecate("sample_size<64", "1.0.0",
                  deprecation_message, standard_warn=False)
        _override_config(unet, "sample_size", 64)

    self.register_modules(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        scheduler=scheduler,
    )
    # Spatial ratio between pixel space and latent space, derived from the
    # number of VAE blocks.
    num_vae_blocks = len(self.vae.config.block_out_channels)
    self.vae_scale_factor = 2 ** (num_vae_blocks - 1)
121
+
122
def enable_vae_slicing(self):
    """Turn on slicing in the VAE by delegating to `vae.enable_slicing()`."""
    vae = self.vae
    vae.enable_slicing()
124
+
125
def disable_vae_slicing(self):
    """Turn off slicing in the VAE by delegating to `vae.disable_slicing()`."""
    vae = self.vae
    vae.disable_slicing()
127
+
128
def enable_sequential_cpu_offload(self, gpu_id=0):
    """Register accelerate's `cpu_offload` on the UNet, text encoder and VAE.

    Args:
        gpu_id: Index of the CUDA device used as the execution device.

    Raises:
        ImportError: If `accelerate` is not available.
    """
    if not is_accelerate_available():
        raise ImportError(
            "Please install accelerate via `pip install accelerate`")
    from accelerate import cpu_offload

    execution_device = torch.device(f"cuda:{gpu_id}")

    # Components that are unset (None) are left alone.
    for component in (self.unet, self.text_encoder, self.vae):
        if component is None:
            continue
        cpu_offload(component, execution_device)
140
+
141
@property
def _execution_device(self):
    """Device on which the pipeline's modules actually execute.

    Returns `self.device` unless the pipeline sits on the `meta` device and
    the UNet carries an `_hf_hook`. In that case the first non-None
    `execution_device` found on a UNet submodule hook is returned, with
    `self.device` as the fallback.
    """
    hooked_on_meta = self.device == torch.device("meta") and hasattr(self.unet, "_hf_hook")
    if not hooked_on_meta:
        return self.device

    for submodule in self.unet.modules():
        hook = getattr(submodule, "_hf_hook", None)
        target = getattr(hook, "execution_device", None)
        if target is not None:
            return torch.device(target)
    return self.device
153
+
154
def _encode_prompt(self, prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt):
    """Encode the prompt (and optionally the negative prompt) into text embeddings.

    Args:
        prompt: A string or a list of strings to encode.
        device: Device the token ids (and attention mask) are moved to.
        num_videos_per_prompt: How many times each prompt's embedding is repeated.
        do_classifier_free_guidance: When true, unconditional embeddings are
            computed as well and placed in front of the prompt embeddings.
        negative_prompt: Optional negative prompt(s); must have the same type
            as `prompt` and, when a list, the same length.

    Returns:
        The prompt embeddings, or the unconditional embeddings followed by the
        prompt embeddings when classifier-free guidance is on.

    Raises:
        TypeError: If `negative_prompt` and `prompt` are of different types.
        ValueError: If a list `negative_prompt` does not match the batch size.
    """

    def _attention_mask(tokenized):
        # The mask is only passed when the text encoder config asks for it.
        if getattr(self.text_encoder.config, "use_attention_mask", False):
            return tokenized.attention_mask.to(device)
        return None

    def _tile(embeddings, rows):
        # Repeat each embedding per requested video (mps-friendly repeat + view).
        length = embeddings.shape[1]
        tiled = embeddings.repeat(1, num_videos_per_prompt, 1)
        return tiled.view(rows * num_videos_per_prompt, length, -1)

    batch_size = len(prompt) if isinstance(prompt, list) else 1

    tokenized = self.tokenizer(
        prompt,
        padding="max_length",
        max_length=self.tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    token_ids = tokenized.input_ids
    full_ids = self.tokenizer(
        prompt, padding="longest", return_tensors="pt").input_ids

    # Warn when the untruncated prompt differs from the truncated one.
    was_truncated = full_ids.shape[-1] >= token_ids.shape[-1] and not torch.equal(token_ids, full_ids)
    if was_truncated:
        removed_text = self.tokenizer.batch_decode(
            full_ids[:, self.tokenizer.model_max_length - 1: -1])
        logger.warning(
            "The following part of your input was truncated because CLIP can only handle sequences up to"
            f" {self.tokenizer.model_max_length} tokens: {removed_text}"
        )

    encoded = self.text_encoder(
        token_ids.to(device),
        attention_mask=_attention_mask(tokenized),
    )
    text_embeddings = encoded[0]
    text_embeddings = _tile(text_embeddings, text_embeddings.shape[0])

    if not do_classifier_free_guidance:
        return text_embeddings

    # Build the unconditional (negative) prompts for classifier-free guidance.
    uncond_tokens: List[str]
    if negative_prompt is None:
        uncond_tokens = [""] * batch_size
    elif type(prompt) is not type(negative_prompt):
        raise TypeError(
            f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
            f" {type(prompt)}."
        )
    elif isinstance(negative_prompt, str):
        uncond_tokens = [negative_prompt]
    elif batch_size != len(negative_prompt):
        raise ValueError(
            f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
            f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
            " the batch size of `prompt`."
        )
    else:
        uncond_tokens = negative_prompt

    uncond_tokenized = self.tokenizer(
        uncond_tokens,
        padding="max_length",
        max_length=token_ids.shape[-1],
        truncation=True,
        return_tensors="pt",
    )
    uncond_encoded = self.text_encoder(
        uncond_tokenized.input_ids.to(device),
        attention_mask=_attention_mask(uncond_tokenized),
    )
    uncond_embeddings = _tile(uncond_encoded[0], batch_size)

    # One concatenated batch (unconditional first) lets the caller run a
    # single forward pass instead of two.
    return torch.cat([uncond_embeddings, text_embeddings])
247
+
248
def decode_latents(self, latents):
    """Decode video latents into frames with the VAE, one frame at a time.

    Args:
        latents: Latents laid out as "b c f h w".

    Returns:
        A float32 numpy array laid out as "b c f h w" with values clamped to [0, 1].
    """
    num_frames = latents.shape[2]
    scaled = 1 / 0.18215 * latents
    # Fold the frame axis into the batch axis so the VAE sees individual frames.
    flat = rearrange(scaled, "b c f h w -> (b f) c h w")

    # Frames are decoded one by one rather than in a single VAE call.
    decoded = [
        self.vae.decode(flat[idx:idx + 1]).sample
        for idx in tqdm(range(flat.shape[0]))
    ]
    video = rearrange(torch.cat(decoded), "(b f) c h w -> b c f h w", f=num_frames)

    # Map from [-1, 1] to [0, 1].
    video = (video / 2 + 0.5).clamp(0, 1)
    # Always cast to float32: the overhead is negligible and it stays compatible with bfloat16.
    return video.cpu().float().numpy()
263
+
264
def prepare_extra_step_kwargs(self, generator, eta):
    """Collect the optional keyword arguments accepted by `scheduler.step`.

    Schedulers do not share one `step` signature, so `eta` and `generator`
    are forwarded only when the current scheduler's `step` declares them.
    `eta` corresponds to η in the DDIM paper (https://arxiv.org/abs/2010.02502)
    and should lie in [0, 1].

    Args:
        generator: Random generator(s) to forward, if supported.
        eta: η value to forward, if supported.

    Returns:
        A dict containing the supported subset of `eta` and `generator`.
    """
    step_params = inspect.signature(self.scheduler.step).parameters
    candidates = (("eta", eta), ("generator", generator))
    return {name: value for name, value in candidates if name in step_params}
282
+
283
def check_inputs(self, prompt, height, width, callback_steps):
    """Validate the user-facing arguments of the pipeline call.

    Args:
        prompt: Must be a `str` or a `list`.
        height: Must be divisible by 8.
        width: Must be divisible by 8.
        callback_steps: Must be a positive integer.

    Raises:
        ValueError: If any of the checks above fails.
    """
    if not isinstance(prompt, (str, list)):
        raise ValueError(
            f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    if height % 8 != 0 or width % 8 != 0:
        raise ValueError(
            f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

    callback_steps_ok = (
        callback_steps is not None
        and isinstance(callback_steps, int)
        and callback_steps > 0
    )
    if not callback_steps_ok:
        raise ValueError(
            f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
            f" {type(callback_steps)}."
        )
300
+
301
def prepare_latents(self, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator, latents=None):
    """Create (or validate) the initial noise latents for the denoising loop.

    Args:
        batch_size: Effective batch size (number of videos to generate).
        num_channels_latents: Number of latent channels.
        video_length: Number of frames.
        height: Output height in pixels; divided by `self.vae_scale_factor`.
        width: Output width in pixels; divided by `self.vae_scale_factor`.
        dtype: dtype of the sampled noise.
        device: Target device; sampling falls back to CPU when it is "mps".
        generator: A `torch.Generator`, a list with one generator per sample,
            or None.
        latents: Optional pre-made latents; they must already have the
            expected shape and are only moved to `device`.

    Returns:
        Latents of shape (batch_size, num_channels_latents, video_length,
        height // vae_scale_factor, width // vae_scale_factor), multiplied by
        the scheduler's `init_noise_sigma`.

    Raises:
        ValueError: If the generator list length differs from `batch_size`,
            or if provided `latents` have an unexpected shape.
    """
    shape = (
        batch_size,
        num_channels_latents,
        video_length,
        height // self.vae_scale_factor,
        width // self.vae_scale_factor,
    )
    if isinstance(generator, list) and len(generator) != batch_size:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )

    if latents is None:
        rand_device = "cpu" if device.type == "mps" else device

        if isinstance(generator, list):
            # Each generator seeds exactly one sample. Drawing the full batch
            # shape per generator would give batch_size ** 2 samples after the
            # concatenation below.
            sample_shape = (1,) + shape[1:]
            latents = torch.cat(
                [
                    torch.randn(sample_shape, generator=sample_generator,
                                device=rand_device, dtype=dtype)
                    for sample_generator in generator
                ],
                dim=0,
            ).to(device)
        else:
            latents = torch.randn(
                shape, generator=generator, device=rand_device, dtype=dtype).to(device)
    else:
        if latents.shape != shape:
            raise ValueError(
                f"Unexpected latents shape, got {latents.shape}, expected {shape}")
        latents = latents.to(device)

    # Scale the initial noise by the standard deviation required by the scheduler.
    return latents * self.scheduler.init_noise_sigma
333
+
334
@torch.no_grad()
def __call__(
    self,
    prompt: Union[str, List[str]],
    video_length: Optional[int],
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 50,
    guidance_scale: float = 7.5,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_videos_per_prompt: Optional[int] = 1,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator,
                              List[torch.Generator]]] = None,
    latents: Optional[torch.FloatTensor] = None,
    output_type: Optional[str] = "tensor",
    return_dict: bool = True,
    callback: Optional[Callable[[
        int, int, torch.FloatTensor], None]] = None,
    callback_steps: Optional[int] = 1,
    do_classifier_free_guidance: bool = True,
    image_path: str = None,  # not ready
    control_path: str = None,  # not ready
    sparse_control: str = False,  # not ready
    **kwargs,
):
    """Generate videos from text prompts.

    Args:
        prompt: Prompt or list of prompts.
        video_length: Number of frames to generate.
        height: Output height; defaults to the UNet sample size times the
            VAE scale factor.
        width: Output width; same default rule as `height`.
        num_inference_steps: Number of scheduler steps.
        guidance_scale: Classifier-free guidance weight; guidance is only
            applied when it is greater than 1.0.
        negative_prompt: Optional negative prompt(s).
        num_videos_per_prompt: Videos generated per prompt.
        eta: Forwarded to the scheduler step when it accepts `eta`.
        generator: Random generator(s) for the initial noise and the scheduler.
        latents: Optional pre-made initial latents.
        output_type: "tensor" converts the decoded video to a torch tensor;
            any other value leaves it as a numpy array.
        return_dict: When false, the raw video is returned instead of an
            `AnimationPipelineOutput`.
        callback: Called as `callback(i, t, latents)` on completed steps
            where `i % callback_steps == 0`.
        callback_steps: Callback frequency; must be a positive integer.
        do_classifier_free_guidance: Extra switch that can disable guidance
            even when `guidance_scale` is above 1.0.
        image_path: Not used by this method (marked "not ready").
        control_path: Not used by this method (marked "not ready").
        sparse_control: Not used by this method (marked "not ready").
        **kwargs: Accepted and ignored.

    Returns:
        `AnimationPipelineOutput(videos=video)`, or the bare video when
        `return_dict` is false.
    """
    # Fall back to the UNet's native resolution when no size is given.
    height = height or self.unet.config.sample_size * self.vae_scale_factor
    width = width or self.unet.config.sample_size * self.vae_scale_factor

    self.check_inputs(prompt, height, width, callback_steps)

    # A list prompt fixes the batch size; otherwise provided latents do;
    # otherwise a single sample is generated.
    if isinstance(prompt, list):
        batch_size = len(prompt)
    elif latents is not None:
        batch_size = latents.shape[0]
    else:
        batch_size = 1

    device = self._execution_device
    # `guidance_scale` is analogous to the guidance weight `w` of equation (2)
    # in the Imagen paper (https://arxiv.org/pdf/2205.11487.pdf);
    # `guidance_scale = 1` means no classifier-free guidance.
    do_classifier_free_guidance = (
        guidance_scale > 1.0) & do_classifier_free_guidance

    # Broadcast string prompts to one entry per sample.
    if not isinstance(prompt, list):
        prompt = [prompt] * batch_size
    if negative_prompt is not None and not isinstance(negative_prompt, list):
        negative_prompt = [negative_prompt] * batch_size
    text_embeddings = self._encode_prompt(
        prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt
    )

    # Timesteps.
    self.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps = self.scheduler.timesteps

    # Initial latents.
    latents = self.prepare_latents(
        batch_size * num_videos_per_prompt,
        self.unet.in_channels,
        video_length,
        height,
        width,
        text_embeddings.dtype,
        device,
        generator,
        latents,
    )
    latents_dtype = latents.dtype

    w_embedding = None  # not ready

    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    # Denoising loop.
    order = self.scheduler.order
    num_warmup_steps = len(timesteps) - num_inference_steps * order
    last_index = len(timesteps) - 1
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            # Duplicate the latents so the unconditional and conditional
            # branches share one UNet call.
            if do_classifier_free_guidance:
                model_input = torch.cat([latents] * 2)
            else:
                model_input = latents
            model_input = self.scheduler.scale_model_input(model_input, t)

            # Predict the noise residual.
            unet_output = self.unet(
                model_input, t, encoder_hidden_states=text_embeddings, time_cond=w_embedding)
            noise_pred = unet_output.sample.to(dtype=latents_dtype)

            if do_classifier_free_guidance:
                uncond_pred, text_pred = noise_pred.chunk(2)
                noise_pred = uncond_pred + guidance_scale * \
                    (text_pred - uncond_pred)

            # x_t -> x_{t-1}
            latents = self.scheduler.step(
                noise_pred, t, latents, **extra_step_kwargs).prev_sample

            step_completed = i == last_index or (
                (i + 1) > num_warmup_steps and (i + 1) % order == 0)
            if step_completed:
                progress_bar.update()
                if callback is not None and i % callback_steps == 0:
                    callback(i, t, latents)

    # Post-processing.
    video = self.decode_latents(latents)
    if output_type == "tensor":
        video = torch.from_numpy(video)

    if return_dict:
        return AnimationPipelineOutput(videos=video)
    return video
animatelcm/scheduler/lcm_scheduler.py ADDED
@@ -0,0 +1,722 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 Stanford University Team and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
16
+ # and https://github.com/hojonathanho/diffusion
17
+
18
+ import math
19
+ from dataclasses import dataclass
20
+ from typing import List, Optional, Tuple, Union
21
+
22
+ import numpy as np
23
+ import torch
24
+
25
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
26
+ from diffusers.utils import BaseOutput, logging
27
+ from diffusers.schedulers.scheduling_utils import SchedulerMixin
28
+
29
+
30
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
31
+
32
+
33
+
34
@dataclass
class LCMSchedulerOutput(BaseOutput):
    """
    Output class for the scheduler's `step` function output.

    Args:
        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
        denoised (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images, *optional*):
            The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
            `denoised` can be used to preview progress or for guidance. Defaults to `None`.
    """

    prev_sample: torch.FloatTensor
    denoised: Optional[torch.FloatTensor] = None
50
+
51
+
52
+ # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
53
def betas_for_alpha_bar(
    num_diffusion_timesteps,
    max_beta=0.999,
    alpha_transform_type="cosine",
):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
    (1-beta) over time from t = [0,1].

    Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
    to that part of the diffusion process.


    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): the maximum beta to use; use values lower than 1 to
                     prevent singularities.
        alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
                     Choose from `cosine` or `exp`

    Returns:
        betas (`torch.Tensor`): float32 tensor of length `num_diffusion_timesteps` holding the betas used by the
        scheduler to step the model outputs

    Raises:
        ValueError: if `alpha_transform_type` is neither `cosine` nor `exp`.
    """
    if alpha_transform_type == "cosine":

        def alpha_bar_fn(t):
            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    elif alpha_transform_type == "exp":

        def alpha_bar_fn(t):
            return math.exp(t * -12.0)

    else:
        raise ValueError(
            f"Unsupported alpha_transform_type: {alpha_transform_type}")

    # beta_i = 1 - alpha_bar(t_{i+1}) / alpha_bar(t_i), capped at max_beta.
    betas = [
        min(
            1 - alpha_bar_fn((i + 1) / num_diffusion_timesteps) / alpha_bar_fn(i / num_diffusion_timesteps),
            max_beta,
        )
        for i in range(num_diffusion_timesteps)
    ]
    return torch.tensor(betas, dtype=torch.float32)
96
+
97
+
98
+ # Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
99
def rescale_zero_terminal_snr(betas: torch.FloatTensor) -> torch.FloatTensor:
    """
    Rescale a beta schedule so that its terminal SNR is zero, following Algorithm 1 of
    https://arxiv.org/pdf/2305.08891.pdf


    Args:
        betas (`torch.FloatTensor`):
            the betas that the scheduler is being initialized with.

    Returns:
        `torch.FloatTensor`: rescaled betas with zero terminal SNR
    """
    # betas -> sqrt(alpha_bar)
    sqrt_alpha_bar = torch.cumprod(1.0 - betas, dim=0).sqrt()

    # Remember the endpoints before they are modified.
    first = sqrt_alpha_bar[0].clone()
    last = sqrt_alpha_bar[-1].clone()

    # Shift so the final timestep becomes zero, then rescale so the first
    # timestep keeps its original value.
    shifted = sqrt_alpha_bar - last
    rescaled = shifted * (first / (first - last))

    # sqrt(alpha_bar) -> betas: undo the sqrt, then undo the cumulative product.
    alpha_bar = rescaled**2
    step_alphas = alpha_bar[1:] / alpha_bar[:-1]
    alphas = torch.cat([alpha_bar[0:1], step_alphas])
    return 1 - alphas
134
+
135
+
136
def randn_tensor(
    shape: Union[Tuple, List],
    generator: Optional[Union[List["torch.Generator"],
                              "torch.Generator"]] = None,
    device: Optional["torch.device"] = None,
    dtype: Optional["torch.dtype"] = None,
    layout: Optional["torch.layout"] = None,
):
    """A helper function to create random tensors on the desired `device` with the desired `dtype`. When
    passing a list of generators, you can seed each batch size individually. If CPU generators are passed, the tensor
    is always created on the CPU.

    Args:
        shape: Shape of the tensor to create (tuple or list); `shape[0]` is the batch size.
        generator: A single generator, or a list with one generator per batch entry.
        device: Device the result is moved to (a `torch.device` or a device string); defaults to CPU.
        dtype: dtype of the result.
        layout: Tensor layout; defaults to `torch.strided`.

    Returns:
        A tensor of standard-normal samples with the requested shape on `device`.

    Raises:
        ValueError: If a CUDA generator is passed for a tensor on a different device type.
    """
    # Accept device strings as well, so that `.type` below is always available.
    if isinstance(device, str):
        device = torch.device(device)

    # device on which tensor is created defaults to device
    rand_device = device
    batch_size = shape[0]

    layout = layout or torch.strided
    device = device or torch.device("cpu")

    if generator is not None:
        gen_device_type = generator.device.type if not isinstance(
            generator, list) else generator[0].device.type
        if gen_device_type != device.type and gen_device_type == "cpu":
            rand_device = "cpu"
            # Compare the device *type*: a torch.device compared against a
            # plain string would not single out mps here.
            if device.type != "mps":
                logger.info(
                    f"The passed generator was created on 'cpu' even though a tensor on {device} was expected."
                    f" Tensors will be created on 'cpu' and then moved to {device}. Note that one can probably"
                    f" slighly speed up this function by passing a generator that was created on the {device} device."
                )
        elif gen_device_type != device.type and gen_device_type == "cuda":
            raise ValueError(
                f"Cannot generate a {device} tensor from a generator of type {gen_device_type}.")

    # make sure generator list of length 1 is treated like a non-list
    if isinstance(generator, list) and len(generator) == 1:
        generator = generator[0]

    if isinstance(generator, list):
        # One sample per generator; tuple() keeps this working for list shapes.
        sample_shape = (1,) + tuple(shape[1:])
        latents = [
            torch.randn(
                sample_shape, generator=generator[i], device=rand_device, dtype=dtype, layout=layout)
            for i in range(batch_size)
        ]
        latents = torch.cat(latents, dim=0).to(device)
    else:
        latents = torch.randn(shape, generator=generator,
                              device=rand_device, dtype=dtype, layout=layout).to(device)

    return latents
187
+
188
+
189
+ class LCMScheduler(SchedulerMixin, ConfigMixin):
190
+ """
191
+ `LCMScheduler` extends the denoising procedure introduced in denoising diffusion probabilistic models (DDPMs) with
192
+ non-Markovian guidance.
193
+
194
+ This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. [`~ConfigMixin`] takes care of storing all config
195
+ attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be
196
+ accessed via `scheduler.config.num_train_timesteps`. [`SchedulerMixin`] provides general loading and saving
197
+ functionality via the [`SchedulerMixin.save_pretrained`] and [`~SchedulerMixin.from_pretrained`] functions.
198
+
199
+ Args:
200
+ num_train_timesteps (`int`, defaults to 1000):
201
+ The number of diffusion steps to train the model.
202
+ beta_start (`float`, defaults to 0.0001):
203
+ The starting `beta` value of inference.
204
+ beta_end (`float`, defaults to 0.02):
205
+ The final `beta` value.
206
+ beta_schedule (`str`, defaults to `"linear"`):
207
+ The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
208
+ `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
209
+ trained_betas (`np.ndarray`, *optional*):
210
+ Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
211
+ original_inference_steps (`int`, *optional*, defaults to 50):
212
+ The default number of inference steps used to generate a linearly-spaced timestep schedule, from which we
213
+ will ultimately take `num_inference_steps` evenly spaced timesteps to form the final timestep schedule.
214
+ clip_sample (`bool`, defaults to `True`):
215
+ Clip the predicted sample for numerical stability.
216
+ clip_sample_range (`float`, defaults to 1.0):
217
+ The maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
218
+ set_alpha_to_one (`bool`, defaults to `True`):
219
+ Each diffusion step uses the alphas product value at that step and at the previous one. For the final step
220
+ there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
221
+ otherwise it uses the alpha value at step 0.
222
+ steps_offset (`int`, defaults to 0):
223
+ An offset added to the inference steps. You can use a combination of `offset=1` and
224
+ `set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable
225
+ Diffusion.
226
+ prediction_type (`str`, defaults to `epsilon`, *optional*):
227
+ Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
228
+ `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
229
+ Video](https://imagen.research.google/video/paper.pdf) paper).
230
+ thresholding (`bool`, defaults to `False`):
231
+ Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
232
+ as Stable Diffusion.
233
+ dynamic_thresholding_ratio (`float`, defaults to 0.995):
234
+ The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
235
+ sample_max_value (`float`, defaults to 1.0):
236
+ The threshold value for dynamic thresholding. Valid only when `thresholding=True`.
237
+ timestep_spacing (`str`, defaults to `"leading"`):
238
+ The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
239
+ Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
240
+ timestep_scaling (`float`, defaults to 10.0):
241
+ The factor the timesteps will be multiplied by when calculating the consistency model boundary conditions
242
+ `c_skip` and `c_out`. Increasing this will decrease the approximation error (although the approximation
243
+ error at the default of `10.0` is already pretty small).
244
+ rescale_betas_zero_snr (`bool`, defaults to `False`):
245
+ Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
246
+ dark samples instead of limiting it to samples with medium brightness. Loosely related to
247
+ [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
248
+ """
249
+
250
+ order = 1
251
+
252
+ @register_to_config
253
def __init__(
    self,
    num_train_timesteps: int = 1000,
    beta_start: float = 0.00085,
    beta_end: float = 0.012,
    beta_schedule: str = "scaled_linear",
    trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
    original_inference_steps: int = 50,
    clip_sample: bool = False,
    clip_sample_range: float = 1.0,
    set_alpha_to_one: bool = True,
    steps_offset: int = 0,
    prediction_type: str = "epsilon",
    thresholding: bool = False,
    dynamic_thresholding_ratio: float = 0.995,
    sample_max_value: float = 1.0,
    timestep_spacing: str = "leading",
    timestep_scaling: float = 10.0,
    rescale_betas_zero_snr: bool = False,
):
    """
    Build the beta / alpha schedule and reset the mutable inference state.

    Only `num_train_timesteps`, `beta_start`, `beta_end`, `beta_schedule`, `trained_betas`,
    `set_alpha_to_one` and `rescale_betas_zero_snr` are read here. The remaining arguments are not
    used in this body; presumably they are recorded on `self.config` by the `register_to_config`
    decorator (TODO confirm).

    Raises:
        NotImplementedError: If `beta_schedule` is not one of `linear`, `scaled_linear` or
            `squaredcos_cap_v2` and no `trained_betas` were given.
    """
    if trained_betas is not None:
        # Explicit betas take precedence over any named schedule.
        self.betas = torch.tensor(trained_betas, dtype=torch.float32)
    elif beta_schedule == "linear":
        self.betas = torch.linspace(
            beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
    elif beta_schedule == "scaled_linear":
        # this schedule is very specific to the latent diffusion model.
        self.betas = torch.linspace(
            beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
    elif beta_schedule == "squaredcos_cap_v2":
        # Glide cosine schedule
        self.betas = betas_for_alpha_bar(num_train_timesteps)
    else:
        raise NotImplementedError(
            f"{beta_schedule} is not implemented for {self.__class__}")

    # Rescale for zero SNR
    if rescale_betas_zero_snr:
        self.betas = rescale_zero_terminal_snr(self.betas)

    self.alphas = 1.0 - self.betas
    self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

    # At every step in ddim, we are looking into the previous alphas_cumprod
    # For the final step, there is no previous alphas_cumprod because we are already at 0
    # `set_alpha_to_one` decides whether we set this parameter simply to one or
    # whether we use the final alpha of the "non-previous" one.
    self.final_alpha_cumprod = torch.tensor(
        1.0) if set_alpha_to_one else self.alphas_cumprod[0]

    # standard deviation of the initial noise distribution
    self.init_noise_sigma = 1.0

    # Mutable state, overwritten by `set_timesteps`: until then the schedule is simply every
    # training timestep in descending order.
    self.num_inference_steps = None
    self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[
                                      ::-1].copy().astype(np.int64))
    self.custom_timesteps = False

    self._step_index = None
313
+
314
+ # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
315
def _init_step_index(self, timestep):
    """
    Find `timestep` inside `self.timesteps` and remember its position in `self._step_index`.

    Args:
        timestep: The timestep to look up; a tensor is first moved to the device of
            `self.timesteps`.
    """
    if isinstance(timestep, torch.Tensor):
        timestep = timestep.to(self.timesteps.device)

    matches = (self.timesteps == timestep).nonzero()

    # When the timestep occurs more than once, start from its second occurrence; otherwise use
    # the only match. This avoids skipping an entry when denoising starts in the middle of the
    # schedule (e.g. for image-to-image).
    position = 1 if len(matches) > 1 else 0
    self._step_index = matches[position].item()
331
+
332
@property
def step_index(self):
    """Current value of `self._step_index` (`None` until a step index has been initialised)."""
    current = self._step_index
    return current
335
+
336
def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
    """
    Return the model input unchanged.

    This scheduler applies no input scaling; the method exists so it can be swapped with
    schedulers that do scale the denoising model input per timestep.

    Args:
        sample (`torch.FloatTensor`):
            The input sample.
        timestep (`int`, *optional*):
            The current timestep in the diffusion chain. Ignored.
    Returns:
        `torch.FloatTensor`:
            The very same `sample` object that was passed in.
    """
    return sample
351
+
352
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
353
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
    """
    "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
    prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
    s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
    pixels from saturation at each step. We find that dynamic thresholding results in significantly better
    photorealism as well as better image-text alignment, especially when using very large guidance weights."

    https://arxiv.org/abs/2205.11487

    Args:
        sample (`torch.FloatTensor`):
            Predicted `x_0`; the first dimension is treated as the batch and all remaining
            dimensions are flattened together for the per-item quantile.
    Returns:
        `torch.FloatTensor`:
            The thresholded sample, with the shape and dtype of the input.
    """
    original_dtype = sample.dtype
    original_shape = sample.shape
    batch_size = original_shape[0]

    if original_dtype not in (torch.float32, torch.float64):
        # upcast for quantile calculation, and clamp not implemented for cpu half
        sample = sample.float()

    # Flatten everything but the batch dimension. Using -1 (instead of multiplying the dims
    # with np.prod) keeps the target size an integer even when there are no trailing dims,
    # where np.prod([]) would return the float 1.0 and make reshape fail.
    sample = sample.reshape(batch_size, -1)

    abs_sample = sample.abs()  # "a certain percentile absolute pixel value"

    s = torch.quantile(
        abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
    s = torch.clamp(
        s, min=1, max=self.config.sample_max_value
    )  # When clamped to min=1, equivalent to standard clipping to [-1, 1]
    # (batch_size, 1) so that the per-item threshold broadcasts over the flattened values
    s = s.unsqueeze(1)
    # "we threshold xt0 to the range [-s, s] and then divide by s"
    sample = torch.clamp(sample, -s, s) / s

    return sample.reshape(original_shape).to(original_dtype)
389
+
390
def set_timesteps(
    self,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    original_inference_steps: Optional[int] = None,
    timesteps: Optional[List[int]] = None,
    strength: float = 1.0,
):
    """
    Sets the discrete timesteps used for the diffusion chain (to be run before inference).

    Args:
        num_inference_steps (`int`, *optional*):
            The number of diffusion steps used when generating samples with a pre-trained model. If used,
            `timesteps` must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        original_inference_steps (`int`, *optional*):
            The original number of inference steps, which will be used to generate a linearly-spaced timestep
            schedule (which is different from the standard `diffusers` implementation). We will then take
            `num_inference_steps` timesteps from this schedule, evenly spaced in terms of indices, and use that as
            our final timestep schedule. If not set, this will default to the `original_inference_steps` attribute.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
            timestep spacing strategy of equal spacing between timesteps on the training/distillation timestep
            schedule is used. If `timesteps` is passed, `num_inference_steps` must be `None`.
        strength (`float`, defaults to 1.0):
            Fraction of the schedule to keep. It truncates the training/distillation schedule to
            `int(original_steps * strength)` entries and, for custom schedules, drops the leading
            `len(timesteps) - int(len(timesteps) * strength)` timesteps.

    Raises:
        ValueError: If neither or both of `num_inference_steps` / `timesteps` are given, if
            `num_inference_steps` is not positive or too large for the available schedule, if the
            custom schedule is empty, not strictly descending, or starts at or beyond
            `num_train_timesteps`.
    """
    # 0. Check inputs
    if num_inference_steps is None and timesteps is None:
        raise ValueError(
            "Must pass exactly one of `num_inference_steps` or `custom_timesteps`.")

    if num_inference_steps is not None and timesteps is not None:
        raise ValueError(
            "Can only pass one of `num_inference_steps` or `custom_timesteps`.")

    # Fail with a clear message instead of a ZeroDivisionError further down.
    if num_inference_steps is not None and num_inference_steps < 1:
        raise ValueError(
            f"`num_inference_steps` must be a positive integer, got {num_inference_steps}.")

    # Fail with a clear message instead of an IndexError on `timesteps[0]`.
    if timesteps is not None and len(timesteps) == 0:
        raise ValueError(
            "`custom_timesteps` must contain at least one timestep.")

    # 1. Calculate the LCM original training/distillation timestep schedule.
    original_steps = (
        original_inference_steps if original_inference_steps is not None else self.config.original_inference_steps
    )

    if original_steps > self.config.num_train_timesteps:
        raise ValueError(
            f"`original_steps`: {original_steps} cannot be larger than `self.config.train_timesteps`:"
            f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
            f" maximal {self.config.num_train_timesteps} timesteps."
        )

    # LCM Timesteps Setting
    # The skipping step parameter k from the paper.
    k = self.config.num_train_timesteps // original_steps
    # LCM Training/Distillation Steps Schedule
    # Currently, only a linearly-spaced schedule is supported (same as in the LCM distillation scripts).
    lcm_origin_timesteps = np.asarray(
        list(range(1, int(original_steps * strength) + 1))) * k - 1

    # 2. Calculate the LCM inference timestep schedule.
    if timesteps is not None:
        # 2.1 Handle custom timestep schedules.
        train_timesteps = set(lcm_origin_timesteps)
        non_train_timesteps = []
        for i, custom_t in enumerate(timesteps):
            if i > 0 and custom_t >= timesteps[i - 1]:
                raise ValueError(
                    "`custom_timesteps` must be in descending order.")

            # Every entry, including the first one, is compared with the training schedule.
            if custom_t not in train_timesteps:
                non_train_timesteps.append(custom_t)

        if timesteps[0] >= self.config.num_train_timesteps:
            raise ValueError(
                f"`timesteps` must start before `self.config.train_timesteps`:"
                f" {self.config.num_train_timesteps}."
            )

        # Raise warning if timestep schedule does not start with self.config.num_train_timesteps - 1
        if strength == 1.0 and timesteps[0] != self.config.num_train_timesteps - 1:
            logger.warning(
                f"The first timestep on the custom timestep schedule is {timesteps[0]}, not"
                f" `self.config.num_train_timesteps - 1`: {self.config.num_train_timesteps - 1}. You may get"
                f" unexpected results when using this timestep schedule."
            )

        # Raise warning if custom timestep schedule contains timesteps not on original timestep schedule
        if non_train_timesteps:
            logger.warning(
                f"The custom timestep schedule contains the following timesteps which are not on the original"
                f" training/distillation timestep schedule: {non_train_timesteps}. You may get unexpected results"
                f" when using this timestep schedule."
            )

        # Raise warning if custom timestep schedule is longer than original_steps
        if len(timesteps) > original_steps:
            logger.warning(
                f"The number of timesteps in the custom timestep schedule is {len(timesteps)}, which exceeds"
                f" the length of the timestep schedule used for training: {original_steps}. You may get some"
                f" unexpected results when using this timestep schedule."
            )

        timesteps = np.array(timesteps, dtype=np.int64)
        self.num_inference_steps = len(timesteps)
        self.custom_timesteps = True

        # Apply strength (e.g. for img2img pipelines) (see StableDiffusionImg2ImgPipeline.get_timesteps)
        init_timestep = min(
            int(self.num_inference_steps * strength), self.num_inference_steps)
        t_start = max(self.num_inference_steps - init_timestep, 0)
        timesteps = timesteps[t_start * self.order:]
        # TODO: also reset self.num_inference_steps?
    else:
        # 2.2 Create the "standard" LCM inference timestep schedule.
        if num_inference_steps > self.config.num_train_timesteps:
            raise ValueError(
                f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
                f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
                f" maximal {self.config.num_train_timesteps} timesteps."
            )

        skipping_step = len(lcm_origin_timesteps) // num_inference_steps

        if skipping_step < 1:
            raise ValueError(
                f"The combination of `original_steps x strength`: {original_steps} x {strength} is smaller than `num_inference_steps`: {num_inference_steps}. Make sure to either reduce `num_inference_steps` to a value smaller than {int(original_steps * strength)} or increase `strength` to a value higher than {float(num_inference_steps / original_steps)}."
            )

        self.num_inference_steps = num_inference_steps

        if num_inference_steps > original_steps:
            raise ValueError(
                f"`num_inference_steps`: {num_inference_steps} cannot be larger than `original_inference_steps`:"
                f" {original_steps} because the final timestep schedule will be a subset of the"
                f" `original_inference_steps`-sized initial timestep schedule."
            )

        # LCM Inference Steps Schedule
        lcm_origin_timesteps = lcm_origin_timesteps[::-1].copy()
        # Select (approximately) evenly spaced indices from lcm_origin_timesteps.
        inference_indices = np.linspace(
            0, len(lcm_origin_timesteps), num=num_inference_steps, endpoint=False)
        # Example with 1000 training timesteps and 50 original steps:
        #   1 step : 999 (a single jump straight from the last timestep)
        #   2 steps: 999, 499
        #   4 steps: 999, 759, 499, 259
        inference_indices = np.floor(inference_indices).astype(np.int64)
        timesteps = lcm_origin_timesteps[inference_indices]

    self.timesteps = torch.from_numpy(timesteps).to(
        device=device, dtype=torch.long)

    # A new schedule invalidates any previously initialised step position.
    self._step_index = None
543
+
544
+
545
def get_scalings_for_boundary_condition_discrete(self, timestep):
    """
    Compute the consistency-model boundary-condition coefficients for `timestep`.

    The timestep is first multiplied by `self.config.timestep_scaling`. As a side effect
    `self.sigma_data` is (re)set to 0.5 on every call.

    Returns:
        The pair `(c_skip, c_out)`.
    """
    self.sigma_data = 0.5  # Default: 0.5
    sigma_sq = self.sigma_data**2
    scaled_timestep = timestep * self.config.timestep_scaling

    # Both coefficients share the same denominator term: scaled_t^2 + sigma_data^2.
    denom = scaled_timestep**2 + sigma_sq
    c_skip = sigma_sq / denom
    c_out = scaled_timestep / denom ** 0.5
    return c_skip, c_out
553
+
554
def step(
    self,
    model_output: torch.FloatTensor,
    timestep: int,
    sample: torch.FloatTensor,
    generator: Optional[torch.Generator] = None,
    return_dict: bool = True,
    use_ddim: bool = False,
) -> Union[LCMSchedulerOutput, Tuple]:
    """
    Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
    process from the learned model outputs (most often the predicted noise).

    Args:
        model_output (`torch.FloatTensor`):
            The direct output from learned diffusion model.
        timestep (`int`):
            The current discrete timestep in the diffusion chain. It is used to index
            `self.alphas_cumprod`.
        sample (`torch.FloatTensor`):
            A current instance of a sample created by the diffusion process.
        generator (`torch.Generator`, *optional*):
            A random number generator, forwarded to `randn_tensor`.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] or `tuple`.
        use_ddim (`bool`, *optional*, defaults to `False`):
            When `False`, fresh Gaussian noise is drawn on every non-final step. NOTE(review): when
            `True`, this method never assigns `noise` although a non-final step still reads it --
            confirm the intended behaviour before passing `True`.
    Returns:
        [`~schedulers.scheduling_utils.LCMSchedulerOutput`] or `tuple`:
            If return_dict is `True`, [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] is returned, otherwise a
            tuple `(prev_sample, denoised)` is returned.

    Raises:
        ValueError: If `set_timesteps` has not been run (`self.num_inference_steps` is `None`) or
            if `self.config.prediction_type` is not one of the supported values.
    """
    if self.num_inference_steps is None:
        raise ValueError(
            "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
        )

    # Lazily locate the starting position in the schedule on the first call.
    if self.step_index is None:
        self._init_step_index(timestep)

    # 1. get previous step value
    prev_step_index = self.step_index + 1
    if prev_step_index < len(self.timesteps):
        prev_timestep = self.timesteps[prev_step_index]
    else:
        # Already at the last schedule entry: there is no later entry, so reuse the current timestep.
        prev_timestep = timestep

    # 2. compute alphas, betas
    alpha_prod_t = self.alphas_cumprod[timestep]
    alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod

    beta_prod_t = 1 - alpha_prod_t
    beta_prod_t_prev = 1 - alpha_prod_t_prev

    # 3. Get scalings for boundary conditions
    c_skip, c_out = self.get_scalings_for_boundary_condition_discrete(
        timestep)

    # 4. Compute the predicted original sample x_0 based on the model parameterization
    if self.config.prediction_type == "epsilon":  # noise-prediction
        predicted_original_sample = (
            sample - beta_prod_t.sqrt() * model_output) / alpha_prod_t.sqrt()
    elif self.config.prediction_type == "sample":  # x-prediction
        predicted_original_sample = model_output
    elif self.config.prediction_type == "v_prediction":  # v-prediction
        predicted_original_sample = alpha_prod_t.sqrt(
        ) * sample - beta_prod_t.sqrt() * model_output
    else:
        raise ValueError(
            f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` or"
            " `v_prediction` for `LCMScheduler`."
        )

    # 5. Clip or threshold "predicted x_0"
    if self.config.thresholding:
        predicted_original_sample = self._threshold_sample(
            predicted_original_sample)
    elif self.config.clip_sample:
        predicted_original_sample = predicted_original_sample.clamp(
            -self.config.clip_sample_range, self.config.clip_sample_range
        )

    # 6. Denoise model output using boundary conditions
    denoised = c_out * predicted_original_sample + c_skip * sample
    # denoised = predicted_original_sample

    # 7. Sample and inject noise z ~ N(0, I) for MultiStep Inference
    # Noise is not used on the final timestep of the timestep schedule.
    # This also means that noise is not used for one-step sampling.
    if self.step_index != self.num_inference_steps - 1:
        if not use_ddim:
            noise = randn_tensor(
                model_output.shape, generator=generator, device=model_output.device, dtype=denoised.dtype
            )
        # NOTE(review): `noise` is only assigned in the branch above, so `use_ddim=True` reaches
        # this line without it being defined -- confirm the intended DDIM update.
        prev_sample = alpha_prod_t_prev.sqrt() * denoised + beta_prod_t_prev.sqrt() * noise
    else:
        prev_sample = denoised

    # upon completion increase step index by one
    self._step_index += 1

    if not return_dict:
        return (prev_sample, denoised)

    return LCMSchedulerOutput(prev_sample=prev_sample, denoised=denoised)
656
+
657
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
658
def add_noise(
    self,
    original_samples: torch.FloatTensor,
    noise: torch.FloatTensor,
    timesteps: torch.IntTensor,
) -> torch.FloatTensor:
    """
    Mix `noise` into `original_samples` according to the cumulative alphas at `timesteps`.

    Computes `sqrt(alpha_bar_t) * original_samples + sqrt(1 - alpha_bar_t) * noise`, where the
    per-timestep coefficients are flattened and given trailing singleton dimensions until they
    have as many dimensions as `original_samples`.
    """
    # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
    alphas_cumprod = self.alphas_cumprod.to(
        device=original_samples.device, dtype=original_samples.dtype)
    timesteps = timesteps.to(original_samples.device)

    signal_scale = (alphas_cumprod[timesteps] ** 0.5).flatten()
    noise_scale = ((1 - alphas_cumprod[timesteps]) ** 0.5).flatten()

    # Both coefficients are 1-D here; pad them with trailing axes for broadcasting.
    for _ in range(len(original_samples.shape) - len(signal_scale.shape)):
        signal_scale = signal_scale.unsqueeze(-1)
        noise_scale = noise_scale.unsqueeze(-1)

    return signal_scale * original_samples + noise_scale * noise
682
+
683
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity
684
def get_velocity(
    self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
) -> torch.FloatTensor:
    """
    Compute `sqrt(alpha_bar_t) * noise - sqrt(1 - alpha_bar_t) * sample` for the given `timesteps`.

    The per-timestep coefficients are flattened and given trailing singleton dimensions until
    they have as many dimensions as `sample`.
    """
    # Make sure alphas_cumprod and timestep have same device and dtype as sample
    alphas_cumprod = self.alphas_cumprod.to(
        device=sample.device, dtype=sample.dtype)
    timesteps = timesteps.to(sample.device)

    noise_weight = (alphas_cumprod[timesteps] ** 0.5).flatten()
    sample_weight = ((1 - alphas_cumprod[timesteps]) ** 0.5).flatten()

    # Both coefficients are 1-D here; pad them with trailing axes for broadcasting.
    for _ in range(len(sample.shape) - len(noise_weight.shape)):
        noise_weight = noise_weight.unsqueeze(-1)
        sample_weight = sample_weight.unsqueeze(-1)

    return noise_weight * noise - sample_weight * sample
704
+
705
def __len__(self):
    """Return the configured number of training timesteps."""
    config = self.config
    return config.num_train_timesteps
707
+
708
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.previous_timestep
709
def previous_timestep(self, timestep):
    """
    Return the timestep that follows `timestep` in the denoising order.

    With a custom schedule this is the next entry of `self.timesteps` (or `tensor(-1)` after the
    last entry). Otherwise it is `timestep` minus the uniform stride
    `num_train_timesteps // num_inference_steps`, where the number of inference steps falls back
    to `num_train_timesteps` when it is unset.
    """
    if not self.custom_timesteps:
        steps = self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps
        return timestep - self.config.num_train_timesteps // steps

    # First position at which the schedule equals the requested timestep.
    position = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0]
    if position == self.timesteps.shape[0] - 1:
        return torch.tensor(-1)
    return self.timesteps[position + 1]
animatelcm/utils/convert_from_ckpt.py ADDED
@@ -0,0 +1,951 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ Conversion script for the Stable Diffusion checkpoints."""
16
+
17
+ import re
18
+ from io import BytesIO
19
+ from typing import Optional
20
+
21
+ import requests
22
+ import torch
23
+ from transformers import (
24
+ AutoFeatureExtractor,
25
+ BertTokenizerFast,
26
+ CLIPImageProcessor,
27
+ CLIPTextModel,
28
+ CLIPTextModelWithProjection,
29
+ CLIPTokenizer,
30
+ CLIPVisionConfig,
31
+ CLIPVisionModelWithProjection,
32
+ )
33
+
34
+ from diffusers.models import (
35
+ AutoencoderKL,
36
+ PriorTransformer,
37
+ UNet2DConditionModel,
38
+ )
39
+ from diffusers.schedulers import (
40
+ DDIMScheduler,
41
+ DDPMScheduler,
42
+ DPMSolverMultistepScheduler,
43
+ EulerAncestralDiscreteScheduler,
44
+ EulerDiscreteScheduler,
45
+ HeunDiscreteScheduler,
46
+ LMSDiscreteScheduler,
47
+ PNDMScheduler,
48
+ UnCLIPScheduler,
49
+ )
50
+ from diffusers.utils.import_utils import BACKENDS_MAPPING
51
+
52
+
53
def shave_segments(path, n_shave_prefix_segments=1):
    """
    Drop dot-separated segments from `path`.

    A non-negative count removes that many leading segments; a negative count removes that many
    trailing segments.
    """
    segments = path.split(".")
    if n_shave_prefix_segments >= 0:
        kept = segments[n_shave_prefix_segments:]
    else:
        kept = segments[:n_shave_prefix_segments]
    return ".".join(kept)
61
+
62
+
63
def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
    """
    Map resnet parameter paths to the new naming scheme (local renaming).

    Returns a list of `{"old": ..., "new": ...}` dicts, one per entry of `old_list`, where the
    new path has the layer names substituted and is then passed through `shave_segments`.
    """
    # Applied in this order to every path.
    renames = (
        ("in_layers.0", "norm1"),
        ("in_layers.2", "conv1"),
        ("out_layers.0", "norm2"),
        ("out_layers.3", "conv2"),
        ("emb_layers.1", "time_emb_proj"),
        ("skip_connection", "conv_shortcut"),
    )

    mapping = []
    for old_item in old_list:
        new_item = old_item
        for old_name, new_name in renames:
            new_item = new_item.replace(old_name, new_name)

        new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
        mapping.append({"old": old_item, "new": new_item})

    return mapping
83
+
84
+
85
def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
    """
    Map VAE resnet parameter paths to the new naming scheme (local renaming).

    `nin_shortcut` becomes `conv_shortcut`, then the path is passed through `shave_segments`.
    Returns a list of `{"old": ..., "new": ...}` dicts, one per entry of `old_list`.
    """
    mapping = []
    for old_item in old_list:
        renamed = old_item.replace("nin_shortcut", "conv_shortcut")
        renamed = shave_segments(renamed, n_shave_prefix_segments=n_shave_prefix_segments)
        mapping.append({"old": old_item, "new": renamed})

    return mapping
99
+
100
+
101
def renew_attention_paths(old_list, n_shave_prefix_segments=0):
    """
    Build the old -> new mapping for attention paths (local renaming).

    No renaming is applied here: every path maps to itself, and `n_shave_prefix_segments` is
    accepted but not used.
    """
    return [{"old": item, "new": item} for item in old_list]
112
+
113
+
114
def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
    """
    Map VAE attention parameter paths to the new naming scheme (local renaming).

    Returns a list of `{"old": ..., "new": ...}` dicts, one per entry of `old_list`, where the
    new path has the norm / q / k / v / proj_out names substituted and is then passed through
    `shave_segments`.
    """
    # Applied in this order to every path.
    renames = (
        ("norm.weight", "group_norm.weight"),
        ("norm.bias", "group_norm.bias"),
        ("q.weight", "query.weight"),
        ("q.bias", "query.bias"),
        ("k.weight", "key.weight"),
        ("k.bias", "key.bias"),
        ("v.weight", "value.weight"),
        ("v.bias", "value.bias"),
        ("proj_out.weight", "proj_attn.weight"),
        ("proj_out.bias", "proj_attn.bias"),
    )

    mapping = []
    for old_item in old_list:
        new_item = old_item
        for old_name, new_name in renames:
            new_item = new_item.replace(old_name, new_name)

        new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
        mapping.append({"old": old_item, "new": new_item})

    return mapping
142
+
143
+
144
def assign_to_checkpoint(
    paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
):
    """
    Final conversion step: copy locally renamed weights into `checkpoint` under globally renamed keys.

    Args:
        paths: List of `{"old": ..., "new": ...}` dicts.
        checkpoint: Destination dict, modified in place.
        old_checkpoint: Source dict indexed by the `"old"` paths.
        attention_paths_to_split: Optional dict mapping a fused attention path in `old_checkpoint`
            to a dict with the destination keys `"query"`, `"key"` and `"value"`.
        additional_replacements: Optional iterable of `{"old": ..., "new": ...}` substring
            replacements applied after the built-in `middle_block` renames.
        config: Dict providing `"num_head_channels"`; only read when attention paths are split.
    """
    assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."

    split_requested = attention_paths_to_split is not None

    # Splits the attention layers into three variables.
    if split_requested:
        for path, path_map in attention_paths_to_split.items():
            fused = old_checkpoint[path]
            channels = fused.shape[0] // 3

            target_shape = (-1, channels) if len(fused.shape) == 3 else (-1)

            num_heads = fused.shape[0] // config["num_head_channels"] // 3

            # Regroup per head so that query/key/value can be split along dim 1.
            fused = fused.reshape((num_heads, 3 * channels // num_heads) + fused.shape[1:])
            query, key, value = fused.split(channels // num_heads, dim=1)

            for role, part in (("query", query), ("key", key), ("value", value)):
                checkpoint[path_map[role]] = part.reshape(target_shape)

    # Global renaming table, applied in this order.
    global_renames = (
        ("middle_block.0", "mid_block.resnets.0"),
        ("middle_block.1", "mid_block.attentions.0"),
        ("middle_block.2", "mid_block.resnets.1"),
    )
    extra_renames = additional_replacements if additional_replacements is not None else ()

    for path in paths:
        new_path = path["new"]

        # These have already been assigned
        if split_requested and new_path in attention_paths_to_split:
            continue

        for old_name, new_name in global_renames:
            new_path = new_path.replace(old_name, new_name)

        for replacement in extra_renames:
            new_path = new_path.replace(replacement["old"], replacement["new"])

        source = old_checkpoint[path["old"]]
        # proj_attn.weight has to be converted from conv 1D to linear
        checkpoint[new_path] = source[:, :, 0] if "proj_attn.weight" in new_path else source
193
+
194
+
195
def conv_attn_to_linear(checkpoint):
    """
    Strip the trailing singleton conv dimensions from attention weights, in place.

    Keys whose last two dot-separated segments are `query.weight`, `key.weight` or `value.weight`
    are indexed with `[:, :, 0, 0]`; other keys containing `proj_attn.weight` are indexed with
    `[:, :, 0]`. Tensors with two or fewer dimensions are left untouched.
    """
    qkv_suffixes = ("query.weight", "key.weight", "value.weight")
    for name in list(checkpoint.keys()):
        suffix = ".".join(name.rsplit(".", 2)[-2:])
        if suffix in qkv_suffixes:
            if checkpoint[name].ndim > 2:
                checkpoint[name] = checkpoint[name][:, :, 0, 0]
            continue
        if "proj_attn.weight" in name and checkpoint[name].ndim > 2:
            checkpoint[name] = checkpoint[name][:, :, 0]
205
+
206
+
207
def create_unet_diffusers_config(original_config, image_size: int, controlnet=False):
    """
    Creates a config for the diffusers based on the config of the LDM model.

    Args:
        original_config: LDM config object; read through attribute access, with `in` used to
            test for optional UNet parameters.
        image_size: Image size; divided by the VAE scale factor to obtain `sample_size`.
        controlnet: When true, the UNet parameters come from `control_stage_config` and the
            returned dict omits `out_channels` and `up_block_types`.

    Raises:
        NotImplementedError: If `num_classes` is present and is not `"sequential"`.
    """
    model_params = original_config.model.params
    if controlnet:
        unet_params = model_params.control_stage_config.params
    else:
        unet_params = model_params.unet_config.params

    vae_params = model_params.first_stage_config.params.ddconfig

    block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult]
    num_levels = len(block_out_channels)

    # Walk down: the resolution factor doubles after every level except the last one.
    down_block_types = []
    resolution = 1
    for level in range(num_levels):
        attends = resolution in unet_params.attention_resolutions
        down_block_types.append("CrossAttnDownBlock2D" if attends else "DownBlock2D")
        if level != num_levels - 1:
            resolution *= 2

    # Walk back up, starting from the resolution factor reached above and halving each level.
    up_block_types = []
    for _ in range(num_levels):
        attends = resolution in unet_params.attention_resolutions
        up_block_types.append("CrossAttnUpBlock2D" if attends else "UpBlock2D")
        resolution //= 2

    vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1)

    head_dim = unet_params.num_heads if "num_heads" in unet_params else None
    use_linear_projection = (
        unet_params.use_linear_in_transformer if "use_linear_in_transformer" in unet_params else False
    )
    if use_linear_projection and head_dim is None:
        # stable diffusion 2-base-512 and 2-768
        head_dim = [5, 10, 20, 20]

    class_embed_type = None
    projection_class_embeddings_input_dim = None

    if "num_classes" in unet_params:
        if unet_params.num_classes != "sequential":
            raise NotImplementedError(f"Unknown conditional unet num_classes config: {unet_params.num_classes}")
        class_embed_type = "projection"
        assert "adm_in_channels" in unet_params
        projection_class_embeddings_input_dim = unet_params.adm_in_channels

    config = {
        "sample_size": image_size // vae_scale_factor,
        "in_channels": unet_params.in_channels,
        "down_block_types": tuple(down_block_types),
        "block_out_channels": tuple(block_out_channels),
        "layers_per_block": unet_params.num_res_blocks,
        "cross_attention_dim": unet_params.context_dim,
        "attention_head_dim": head_dim,
        "use_linear_projection": use_linear_projection,
        "class_embed_type": class_embed_type,
        "projection_class_embeddings_input_dim": projection_class_embeddings_input_dim,
    }

    if not controlnet:
        config["out_channels"] = unet_params.out_channels
        config["up_block_types"] = tuple(up_block_types)

    return config
274
+
275
+
276
def create_vae_diffusers_config(original_config, image_size: int):
    """
    Build the diffusers VAE config dict from the first-stage section of an LDM config.

    ``image_size`` is passed through unchanged as ``sample_size``; channel
    widths are ``ddconfig.ch`` multiplied by each entry of ``ddconfig.ch_mult``.
    """
    first_stage_params = original_config.model.params.first_stage_config.params
    vae_params = first_stage_params.ddconfig
    # The value is not used; the attribute access itself is retained.
    _ = first_stage_params.embed_dim

    block_out_channels = tuple(vae_params.ch * mult for mult in vae_params.ch_mult)
    num_blocks = len(block_out_channels)

    return {
        "sample_size": image_size,
        "in_channels": vae_params.in_channels,
        "out_channels": vae_params.out_ch,
        "down_block_types": ("DownEncoderBlock2D",) * num_blocks,
        "up_block_types": ("UpDecoderBlock2D",) * num_blocks,
        "block_out_channels": block_out_channels,
        "latent_channels": vae_params.z_channels,
        "layers_per_block": vae_params.num_res_blocks,
    }
298
+
299
+
300
def create_diffusers_schedular(original_config):
    """Return a ``DDIMScheduler`` configured from ``original_config.model.params``.

    ``timesteps``, ``linear_start`` and ``linear_end`` are forwarded as
    ``num_train_timesteps``, ``beta_start`` and ``beta_end``; the beta schedule
    is always ``"scaled_linear"``.
    """
    model_params = original_config.model.params
    return DDIMScheduler(
        num_train_timesteps=model_params.timesteps,
        beta_start=model_params.linear_start,
        beta_end=model_params.linear_end,
        beta_schedule="scaled_linear",
    )
308
+
309
+
310
def create_ldm_bert_config(original_config):
    """Return an ``LDMBertConfig`` built from the conditioning-stage parameters.

    ``n_embed`` is used as the model width, ``n_layer`` as the number of
    encoder layers, and the feed-forward width is four times ``n_embed``.
    """
    # Fixed: the original read `model.parms`, a typo for `model.params`, which
    # is the attribute every other config accessor in this file uses.
    bert_params = original_config.model.params.cond_stage_config.params
    config = LDMBertConfig(
        d_model=bert_params.n_embed,
        encoder_layers=bert_params.n_layer,
        encoder_ffn_dim=bert_params.n_embed * 4,
    )
    return config
318
+
319
+
320
def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False, controlnet=False):
    """
    Takes a state dict and a config, and returns a converted checkpoint.

    Args:
        checkpoint: mapping of original parameter names to tensors. Entries that
            are extracted for the UNet are removed from it (via ``pop``).
        config: mapping read for ``"class_embed_type"`` and ``"layers_per_block"``;
            it is also forwarded to ``assign_to_checkpoint``.
        path: only used in the informational message printed for EMA checkpoints.
        extract_ema: when true and the checkpoint holds more than 100
            ``model_ema*`` keys, the EMA weights are extracted instead of the
            regular ones.
        controlnet: when true, keys are read from the ``control_model.`` prefix,
            the output head and up blocks' ``out.*`` entries are skipped, and the
            ControlNet-specific hint/zero-conv weights are converted as well.

    Returns:
        A new dict keyed by the converted parameter names.

    Raises:
        NotImplementedError: if ``config["class_embed_type"]`` is neither ``None``,
            ``"timestep"`` nor ``"projection"``.
    """

    # extract state_dict for UNet
    unet_state_dict = {}
    keys = list(checkpoint.keys())

    if controlnet:
        unet_key = "control_model."
    else:
        unet_key = "model.diffusion_model."

    # more than 100 keys have to start with `model_ema` in order for the checkpoint to be treated as EMA
    if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema:
        print(f"Checkpoint {path} has both EMA and non-EMA weights.")
        print(
            "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
            " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag."
        )
        for key in keys:
            if key.startswith("model.diffusion_model"):
                # The EMA key is the original key without its first segment and with all dots removed.
                flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
                unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
    else:
        if sum(k.startswith("model_ema") for k in keys) > 100:
            print(
                "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
                " weights (usually better for inference), please make sure to add the `--extract_ema` flag."
            )

        for key in keys:
            if key.startswith(unet_key):
                unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)

    new_checkpoint = {}

    # Time embedding MLP.
    new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
    new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
    new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
    new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]

    if config["class_embed_type"] is None:
        # No parameters to port
        ...
    elif config["class_embed_type"] == "timestep" or config["class_embed_type"] == "projection":
        new_checkpoint["class_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"]
        new_checkpoint["class_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"]
        new_checkpoint["class_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"]
        new_checkpoint["class_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"]
    else:
        raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}")

    new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
    new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]

    # The output head is only converted for the plain UNet.
    if not controlnet:
        new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
        new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
        new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
        new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]

    # Retrieves the keys for the input blocks only
    num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
    input_blocks = {
        layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
        for layer_id in range(num_input_blocks)
    }

    # Retrieves the keys for the middle blocks only
    num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
    middle_blocks = {
        layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
        for layer_id in range(num_middle_blocks)
    }

    # Retrieves the keys for the output blocks only
    num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
    output_blocks = {
        layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
        for layer_id in range(num_output_blocks)
    }

    # Input block 0 is the `conv_in` handled above, so the loop starts at 1.
    for i in range(1, num_input_blocks):
        # Each down block groups `layers_per_block + 1` consecutive input blocks.
        block_id = (i - 1) // (config["layers_per_block"] + 1)
        layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)

        resnets = [
            key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
        ]
        attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]

        # An `op` entry marks a downsampling convolution.
        if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
                f"input_blocks.{i}.0.op.weight"
            )
            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
                f"input_blocks.{i}.0.op.bias"
            )

        paths = renew_resnet_paths(resnets)
        meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
        assign_to_checkpoint(
            paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
        )

        if len(attentions):
            paths = renew_attention_paths(attentions)
            meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
            assign_to_checkpoint(
                paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
            )

    # Middle block: sub-blocks 0 and 2 are treated as resnets, sub-block 1 as attention.
    resnet_0 = middle_blocks[0]
    attentions = middle_blocks[1]
    resnet_1 = middle_blocks[2]

    resnet_0_paths = renew_resnet_paths(resnet_0)
    assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)

    resnet_1_paths = renew_resnet_paths(resnet_1)
    assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)

    attentions_paths = renew_attention_paths(attentions)
    meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
    assign_to_checkpoint(
        attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
    )

    for i in range(num_output_blocks):
        block_id = i // (config["layers_per_block"] + 1)
        layer_in_block_id = i % (config["layers_per_block"] + 1)
        output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
        output_block_list = {}

        # Group the remaining key parts by their leading sub-layer id.
        for layer in output_block_layers:
            layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
            if layer_id in output_block_list:
                output_block_list[layer_id].append(layer_name)
            else:
                output_block_list[layer_id] = [layer_name]

        if len(output_block_list) > 1:
            resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
            attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]

            # NOTE(review): `resnet_0_paths` is not read again in this branch; `paths` below
            # comes from an identical call. Confirm the first call can be dropped.
            resnet_0_paths = renew_resnet_paths(resnets)
            paths = renew_resnet_paths(resnets)

            meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
            assign_to_checkpoint(
                paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
            )

            output_block_list = {k: sorted(v) for k, v in output_block_list.items()}
            # A sub-layer consisting of exactly `conv.bias` and `conv.weight` is the upsampler.
            if ["conv.bias", "conv.weight"] in output_block_list.values():
                index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
                new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
                    f"output_blocks.{i}.{index}.conv.weight"
                ]
                new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
                    f"output_blocks.{i}.{index}.conv.bias"
                ]

                # Clear attentions as they have been attributed above.
                if len(attentions) == 2:
                    attentions = []

            if len(attentions):
                paths = renew_attention_paths(attentions)
                meta_path = {
                    "old": f"output_blocks.{i}.1",
                    "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
                }
                assign_to_checkpoint(
                    paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
                )
        else:
            # Only one sub-layer: copy its entries directly as a resnet.
            resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
            for path in resnet_0_paths:
                old_path = ".".join(["output_blocks", str(i), path["old"]])
                new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])

                new_checkpoint[new_path] = unet_state_dict[old_path]

    if controlnet:
        # conditioning embedding

        orig_index = 0

        new_checkpoint["controlnet_cond_embedding.conv_in.weight"] = unet_state_dict.pop(
            f"input_hint_block.{orig_index}.weight"
        )
        new_checkpoint["controlnet_cond_embedding.conv_in.bias"] = unet_state_dict.pop(
            f"input_hint_block.{orig_index}.bias"
        )

        # Source indices advance by 2 per converted layer while target indices advance by 1.
        orig_index += 2

        diffusers_index = 0

        while diffusers_index < 6:
            new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.weight"] = unet_state_dict.pop(
                f"input_hint_block.{orig_index}.weight"
            )
            new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.bias"] = unet_state_dict.pop(
                f"input_hint_block.{orig_index}.bias"
            )
            diffusers_index += 1
            orig_index += 2

        new_checkpoint["controlnet_cond_embedding.conv_out.weight"] = unet_state_dict.pop(
            f"input_hint_block.{orig_index}.weight"
        )
        new_checkpoint["controlnet_cond_embedding.conv_out.bias"] = unet_state_dict.pop(
            f"input_hint_block.{orig_index}.bias"
        )

        # down blocks
        for i in range(num_input_blocks):
            new_checkpoint[f"controlnet_down_blocks.{i}.weight"] = unet_state_dict.pop(f"zero_convs.{i}.0.weight")
            new_checkpoint[f"controlnet_down_blocks.{i}.bias"] = unet_state_dict.pop(f"zero_convs.{i}.0.bias")

        # mid block
        new_checkpoint["controlnet_mid_block.weight"] = unet_state_dict.pop("middle_block_out.0.weight")
        new_checkpoint["controlnet_mid_block.bias"] = unet_state_dict.pop("middle_block_out.0.bias")

    return new_checkpoint
549
+
550
+
551
def convert_ldm_vae_checkpoint(checkpoint, config):
    """
    Convert the ``first_stage_model.*`` entries of ``checkpoint`` to the new VAE key layout.

    Args:
        checkpoint: mapping of original parameter names to tensors. It is only read
            (``get``), not modified.
        config: forwarded unchanged to ``assign_to_checkpoint``.

    Returns:
        A new dict keyed by the converted parameter names, after
        ``conv_attn_to_linear`` has been applied to it.
    """
    # extract state dict for VAE
    vae_state_dict = {}
    vae_key = "first_stage_model."
    keys = list(checkpoint.keys())
    for key in keys:
        if key.startswith(vae_key):
            vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)

    new_checkpoint = {}

    # Encoder stem and head; only the norm layer is renamed (`norm_out` -> `conv_norm_out`).
    new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
    new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
    new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
    new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
    new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
    new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]

    # Decoder stem and head, same renaming as the encoder.
    new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
    new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
    new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
    new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
    new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
    new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]

    new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
    new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
    new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
    new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]

    # Retrieves the keys for the encoder down blocks only
    num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
    down_blocks = {
        layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
    }

    # Retrieves the keys for the decoder up blocks only
    num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
    up_blocks = {
        layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
    }

    for i in range(num_down_blocks):
        resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]

        # Downsampler weights are moved out of `vae_state_dict` (pop) when present.
        if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
            new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
                f"encoder.down.{i}.downsample.conv.weight"
            )
            new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
                f"encoder.down.{i}.downsample.conv.bias"
            )

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    # Encoder mid block: `block_1`/`block_2` become `resnets.0`/`resnets.1`.
    mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
    num_mid_res_blocks = 2
    for i in range(1, num_mid_res_blocks + 1):
        resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
    paths = renew_vae_attention_paths(mid_attentions)
    meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
    assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
    conv_attn_to_linear(new_checkpoint)

    for i in range(num_up_blocks):
        # Up blocks are written in reverse order: source block `block_id` becomes target block `i`.
        block_id = num_up_blocks - 1 - i
        resnets = [
            key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
        ]

        if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
            new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
                f"decoder.up.{block_id}.upsample.conv.weight"
            ]
            new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
                f"decoder.up.{block_id}.upsample.conv.bias"
            ]

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    # Decoder mid block, handled the same way as the encoder mid block.
    mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
    num_mid_res_blocks = 2
    for i in range(1, num_mid_res_blocks + 1):
        resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
    paths = renew_vae_attention_paths(mid_attentions)
    meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
    assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
    conv_attn_to_linear(new_checkpoint)
    return new_checkpoint
656
+
657
+
658
def convert_ldm_bert_checkpoint(checkpoint, config):
    """Create an ``LDMBertModel`` from ``config`` and fill it from ``checkpoint``.

    ``checkpoint`` is accessed through attributes (``checkpoint.transformer...``).
    The new model is put in eval mode, receives the embeddings, the final layer
    norm, every hidden layer and the logits projection, and is returned.
    """

    def _copy_linear(hf_linear, pt_linear):
        # Assign the source weight and bias attributes directly.
        hf_linear.weight = pt_linear.weight
        hf_linear.bias = pt_linear.bias

    def _copy_attn_layer(hf_attn_layer, pt_attn_layer):
        # q/k/v only have their weight data replaced.
        hf_attn_layer.q_proj.weight.data = pt_attn_layer.to_q.weight
        hf_attn_layer.k_proj.weight.data = pt_attn_layer.to_k.weight
        hf_attn_layer.v_proj.weight.data = pt_attn_layer.to_v.weight

        # The output projection takes both weight and bias.
        _copy_linear(hf_attn_layer.out_proj, pt_attn_layer.to_out)

    def _copy_layer(hf_layer, pt_layer):
        attn_part, mlp_part = pt_layer[0], pt_layer[1]

        # copy layer norms
        _copy_linear(hf_layer.self_attn_layer_norm, attn_part[0])
        _copy_linear(hf_layer.final_layer_norm, mlp_part[0])

        # copy attn
        _copy_attn_layer(hf_layer.self_attn, attn_part[1])

        # copy MLP
        feed_forward = mlp_part[1]
        _copy_linear(hf_layer.fc1, feed_forward.net[0][0])
        _copy_linear(hf_layer.fc2, feed_forward.net[2])

    def _copy_layers(hf_layers, pt_layers):
        # Target layer `n` is built from source entries `2n` and `2n + 1`.
        for position, hf_layer in enumerate(hf_layers):
            start = position * 2
            _copy_layer(hf_layer, pt_layers[start : start + 2])

    hf_model = LDMBertModel(config).eval()
    source = checkpoint.transformer

    # copy embeds
    hf_model.model.embed_tokens.weight = source.token_emb.weight
    hf_model.model.embed_positions.weight.data = source.pos_emb.emb.weight

    # copy layer norm
    _copy_linear(hf_model.model.layer_norm, source.norm)

    # copy hidden layers
    _copy_layers(hf_model.model.layers, source.attn_layers.layers)

    _copy_linear(hf_model.to_logits, source.to_logits)

    return hf_model
706
+
707
+
708
def convert_ldm_clip_checkpoint(checkpoint):
    """Load the CLIP text-encoder weights found in ``checkpoint`` into a ``CLIPTextModel``.

    The model is instantiated from ``openai/clip-vit-large-patch14``. Every
    checkpoint key starting with ``cond_stage_model.transformer`` is kept, with
    the first ``len("cond_stage_model.transformer.")`` characters removed, and
    the resulting dict is passed to ``load_state_dict``.
    """
    text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

    strip_len = len("cond_stage_model.transformer.")
    text_model_dict = {
        name[strip_len:]: checkpoint[name]
        for name in list(checkpoint.keys())
        if name.startswith("cond_stage_model.transformer")
    }

    text_model.load_state_dict(text_model_dict)
    return text_model
721
+
722
+
723
# Whole-key renames for the text encoder: (original key, converted key).
textenc_conversion_lst = [
    ("cond_stage_model.model.positional_embedding", "text_model.embeddings.position_embedding.weight"),
    ("cond_stage_model.model.token_embedding.weight", "text_model.embeddings.token_embedding.weight"),
    ("cond_stage_model.model.ln_final.weight", "text_model.final_layer_norm.weight"),
    ("cond_stage_model.model.ln_final.bias", "text_model.final_layer_norm.bias"),
]
textenc_conversion_map = dict(textenc_conversion_lst)

# Substring renames applied to transformer keys: (stable-diffusion, HF Diffusers).
textenc_transformer_conversion_lst = [
    ("resblocks.", "text_model.encoder.layers."),
    ("ln_1", "layer_norm1"),
    ("ln_2", "layer_norm2"),
    (".c_fc.", ".fc1."),
    (".c_proj.", ".fc2."),
    (".attn", ".self_attn"),
    ("ln_final.", "transformer.text_model.final_layer_norm."),
    ("token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"),
    ("positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"),
]
# Keys are regex-escaped so a match can be looked up again via `re.escape(match)`.
protected = {re.escape(old): new for old, new in textenc_transformer_conversion_lst}
textenc_pattern = re.compile("|".join(protected))
745
+
746
+
747
def convert_paint_by_example_checkpoint(checkpoint):
    """
    Build a ``PaintByExampleImageEncoder`` and load its weights from ``checkpoint``.

    Four groups of weights are loaded in turn: the CLIP vision model, the mapper,
    the final layer norm and the output projection. Finally the unconditional
    vector is set from ``checkpoint["learnable_vector"]``.

    Returns:
        The populated ``PaintByExampleImageEncoder`` instance.
    """
    config = CLIPVisionConfig.from_pretrained("openai/clip-vit-large-patch14")
    model = PaintByExampleImageEncoder(config)

    keys = list(checkpoint.keys())

    # NOTE: despite its name, this dict is loaded into the vision model below.
    text_model_dict = {}

    for key in keys:
        if key.startswith("cond_stage_model.transformer"):
            text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]

    # load clip vision
    model.model.load_state_dict(text_model_dict)

    # load mapper
    # Keys lose their first `len("cond_stage_model.mapper.res")` characters.
    keys_mapper = {
        k[len("cond_stage_model.mapper.res") :]: v
        for k, v in checkpoint.items()
        if k.startswith("cond_stage_model.mapper")
    }

    # Source sub-module name -> target sub-module name(s). An entry with several
    # targets has its tensor split evenly between them (see loop below).
    MAPPING = {
        "attn.c_qkv": ["attn1.to_q", "attn1.to_k", "attn1.to_v"],
        "attn.c_proj": ["attn1.to_out.0"],
        "ln_1": ["norm1"],
        "ln_2": ["norm3"],
        "mlp.c_fc": ["ff.net.0.proj"],
        "mlp.c_proj": ["ff.net.2"],
    }

    mapped_weights = {}
    for key, value in keys_mapper.items():
        # First 8 characters of the key; presumably `blocks.<digit>`, which would
        # assume a single-digit block index -- TODO confirm against real checkpoints.
        prefix = key[: len("blocks.i")]
        # Last dot-separated component of what follows the prefix.
        suffix = key.split(prefix)[-1].split(".")[-1]
        # Text between prefix and suffix, with one character trimmed from each end.
        name = key.split(prefix)[-1].split(suffix)[0][1:-1]
        mapped_names = MAPPING[name]

        num_splits = len(mapped_names)
        for i, mapped_name in enumerate(mapped_names):
            new_name = ".".join([prefix, mapped_name, suffix])
            # Equal-sized consecutive slices along the first axis.
            shape = value.shape[0] // num_splits
            mapped_weights[new_name] = value[i * shape : (i + 1) * shape]

    model.mapper.load_state_dict(mapped_weights)

    # load final layer norm
    model.final_layer_norm.load_state_dict(
        {
            "bias": checkpoint["cond_stage_model.final_ln.bias"],
            "weight": checkpoint["cond_stage_model.final_ln.weight"],
        }
    )

    # load final proj
    model.proj_out.load_state_dict(
        {
            "bias": checkpoint["proj_out.bias"],
            "weight": checkpoint["proj_out.weight"],
        }
    )

    # load uncond vector
    model.uncond_vector.data = torch.nn.Parameter(checkpoint["learnable_vector"])
    return model
812
+
813
+
814
def convert_open_clip_checkpoint(checkpoint):
    """Load the OpenCLIP text-encoder weights found in ``checkpoint`` into a ``CLIPTextModel``.

    The model is instantiated from the ``text_encoder`` subfolder of
    ``stabilityai/stable-diffusion-2``. Keys listed in
    ``textenc_conversion_map`` are renamed as a whole; keys under
    ``cond_stage_model.model.transformer.`` are renamed piecewise with
    ``textenc_pattern``, and fused ``in_proj`` weights/biases are sliced into
    separate q/k/v entries of ``d_model`` rows each. Keys containing
    ``resblocks.23`` are skipped.
    """
    text_model = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2", subfolder="text_encoder")

    transformer_prefix = "cond_stage_model.model.transformer."
    fused_weight_suffix = ".in_proj_weight"
    fused_bias_suffix = ".in_proj_bias"

    def _rename(name):
        # Replace every known fragment with its converted counterpart.
        return textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], name)

    source_keys = list(checkpoint.keys())

    text_model_dict = {}

    if "cond_stage_model.model.text_projection" in checkpoint:
        d_model = int(checkpoint["cond_stage_model.model.text_projection"].shape[0])
    else:
        d_model = 1024

    text_model_dict["text_model.embeddings.position_ids"] = text_model.text_model.embeddings.get_buffer("position_ids")

    for key in source_keys:
        if "resblocks.23" in key:  # Diffusers drops the final layer and only uses the penultimate layer
            continue
        if key in textenc_conversion_map:
            text_model_dict[textenc_conversion_map[key]] = checkpoint[key]
        if not key.startswith(transformer_prefix):
            continue

        new_key = key[len(transformer_prefix) :]
        if new_key.endswith(fused_weight_suffix):
            base = _rename(new_key[: -len(fused_weight_suffix)])
            fused = checkpoint[key]
            text_model_dict[base + ".q_proj.weight"] = fused[:d_model, :]
            text_model_dict[base + ".k_proj.weight"] = fused[d_model : d_model * 2, :]
            text_model_dict[base + ".v_proj.weight"] = fused[d_model * 2 :, :]
        elif new_key.endswith(fused_bias_suffix):
            base = _rename(new_key[: -len(fused_bias_suffix)])
            fused = checkpoint[key]
            text_model_dict[base + ".q_proj.bias"] = fused[:d_model]
            text_model_dict[base + ".k_proj.bias"] = fused[d_model : d_model * 2]
            text_model_dict[base + ".v_proj.bias"] = fused[d_model * 2 :]
        else:
            text_model_dict[_rename(new_key)] = checkpoint[key]

    text_model.load_state_dict(text_model_dict)

    return text_model
855
+
856
+
857
def stable_unclip_image_encoder(original_config):
    """
    Returns the image processor and clip image encoder for the img2img unclip pipeline.

    We currently know of two types of stable unclip models which separately use the clip and the openclip image
    encoders.

    Raises ``NotImplementedError`` for an unknown embedder class, or for a
    ``ClipImageEmbedder`` whose model name is not ``"ViT-L/14"``.
    """
    embedder_config = original_config.model.params.embedder_config
    # Only the last dotted component of the target path identifies the class.
    embedder_class = embedder_config.target.split(".")[-1]

    if embedder_class == "ClipImageEmbedder":
        clip_model_name = embedder_config.params.model
        if clip_model_name != "ViT-L/14":
            raise NotImplementedError(f"Unknown CLIP checkpoint name in stable diffusion checkpoint {clip_model_name}")
        pretrained_name = "openai/clip-vit-large-patch14"
    elif embedder_class == "FrozenOpenCLIPImageEmbedder":
        pretrained_name = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K"
    else:
        raise NotImplementedError(
            f"Unknown CLIP image embedder class in stable diffusion checkpoint {embedder_class}"
        )

    feature_extractor = CLIPImageProcessor()
    image_encoder = CLIPVisionModelWithProjection.from_pretrained(pretrained_name)

    return feature_extractor, image_encoder
888
+
889
+
890
def stable_unclip_image_noising_components(
    original_config, clip_stats_path: Optional[str] = None, device: Optional[str] = None
):
    """
    Returns the noising components for the img2img and txt2img unclip pipelines.

    Converts the stability noise augmentor into
    1. a `StableUnCLIPImageNormalizer` for holding the CLIP stats
    2. a `DDPMScheduler` for holding the noise schedule

    If the noise augmentor config specifies a clip stats path, the `clip_stats_path` must be provided.

    Raises `NotImplementedError` for an unknown noise augmentor class and
    `ValueError` when the stats path is required but missing.
    """
    augmentor = original_config.model.params.noise_aug_config
    # Only the last dotted component of the target path identifies the class.
    augmentor_class = augmentor.target.split(".")[-1]

    if augmentor_class != "CLIPEmbeddingNoiseAugmentation":
        raise NotImplementedError(f"Unknown noise augmentor class: {augmentor_class}")

    aug_params = augmentor.params
    embedding_dim = aug_params.timestep_dim
    max_noise_level = aug_params.noise_schedule_config.timesteps
    beta_schedule = aug_params.noise_schedule_config.beta_schedule

    image_normalizer = StableUnCLIPImageNormalizer(embedding_dim=embedding_dim)
    image_noising_scheduler = DDPMScheduler(num_train_timesteps=max_noise_level, beta_schedule=beta_schedule)

    if "clip_stats_path" in aug_params:
        if clip_stats_path is None:
            raise ValueError("This stable unclip config requires a `clip_stats_path`")

        clip_mean, clip_std = torch.load(clip_stats_path, map_location=device)
        # A leading axis is added to both statistics before loading them.
        image_normalizer.load_state_dict(
            {
                "mean": clip_mean[None, :],
                "std": clip_std[None, :],
            }
        )

    return image_normalizer, image_noising_scheduler
933
+
934
+
935
def convert_controlnet_checkpoint(
    checkpoint, original_config, checkpoint_path, image_size, upcast_attention, extract_ema
):
    """Build a ``ControlNetModel`` and load converted weights from ``checkpoint`` into it.

    The config comes from ``create_unet_diffusers_config(..., controlnet=True)``
    with ``upcast_attention`` added and ``sample_size`` removed; the same config
    is used for the weight conversion in ``convert_ldm_unet_checkpoint``.
    """
    ctrlnet_config = create_unet_diffusers_config(original_config, image_size=image_size, controlnet=True)
    ctrlnet_config["upcast_attention"] = upcast_attention
    # `sample_size` is dropped before the config is used as constructor kwargs.
    del ctrlnet_config["sample_size"]

    controlnet_model = ControlNetModel(**ctrlnet_config)

    converted_state = convert_ldm_unet_checkpoint(
        checkpoint,
        ctrlnet_config,
        path=checkpoint_path,
        extract_ema=extract_ema,
        controlnet=True,
    )
    controlnet_model.load_state_dict(converted_state)

    return controlnet_model
animatelcm/utils/convert_lora_safetensor_to_diffusers.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023, Haofan Wang, Qixun Wang, All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """ Conversion script for the LoRA's safetensors checkpoints. """
17
+
18
+ import argparse
19
+
20
+ import torch
21
+ from safetensors.torch import load_file
22
+
23
+ from diffusers import StableDiffusionPipeline
24
+
25
+
26
def convert_motion_lora_ckpt_to_diffusers(pipeline, state_dict, alpha=1.0):
    """Merge motion-LoRA weights directly into ``pipeline.unet`` (in place).

    For every "down" key in ``state_dict`` the matching "up" key is looked up,
    the target layer is located by walking the dotted key path from
    ``pipeline.unet``, and ``alpha * (up @ down)`` is added to that layer's
    weight.

    Args:
        pipeline: Object exposing the target model as ``pipeline.unet``.
        state_dict: Mapping from LoRA key to 2-D weight tensor.
        alpha: Scale applied to each LoRA delta before it is added.

    Returns:
        The same ``pipeline`` object that was passed in.
    """
    for down_name in state_dict:
        # "up" entries are consumed together with their "down" partner below.
        if "up." in down_name:
            continue

        up_name = down_name.replace(".down.", ".up.")

        # Strip the LoRA-specific parts of the key to get the module path.
        target_path = down_name.replace("processor.", "")
        target_path = target_path.replace("_lora", "")
        target_path = target_path.replace("down.", "")
        target_path = target_path.replace("up.", "")
        target_path = target_path.replace("to_out.", "to_out.0.")

        # Drop the trailing parameter name and descend module by module.
        layer = pipeline.unet
        for part in target_path.split(".")[:-1]:
            layer = layer.__getattr__(part)

        delta = torch.mm(state_dict[up_name], state_dict[down_name])
        layer.weight.data += alpha * delta.to(layer.weight.data.device)

    return pipeline
47
+
48
+
49
+
50
def convert_lora(pipeline, state_dict, LORA_PREFIX_UNET="lora_unet", LORA_PREFIX_TEXT_ENCODER="lora_te", alpha=0.6):
    """Merge LoRA weights from ``state_dict`` into the pipeline's models in place.

    Keys are expected to look like
    ``"lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight"``.
    Keys containing ``"text"`` target ``pipeline.text_encoder``; all others
    target ``pipeline.unet``. Each up/down pair is applied exactly once as
    ``W += alpha * (up @ down)``.

    Args:
        pipeline: Object exposing ``text_encoder`` and ``unet`` modules.
        state_dict: Mapping from LoRA key to weight tensor (2-D, or 4-D with
            trailing dimensions of size 1).
        LORA_PREFIX_UNET: Key prefix of UNet weights.
        LORA_PREFIX_TEXT_ENCODER: Key prefix of text-encoder weights.
        alpha: Merging ratio in ``W = W0 + alpha * deltaW``.

    Returns:
        The same ``pipeline`` object that was passed in.

    Raises:
        ValueError: If no module path matching a key can be found.
    """
    # A set gives O(1) membership checks for already-merged pairs.
    visited = set()

    for key in state_dict:
        # ".alpha" entries are not applied here; pairs are merged only once.
        if ".alpha" in key or key in visited:
            continue

        if "text" in key:
            layer_infos = key.split(".")[0].split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
            curr_layer = pipeline.text_encoder
        else:
            layer_infos = key.split(".")[0].split(LORA_PREFIX_UNET + "_")[-1].split("_")
            curr_layer = pipeline.unet

        # Find the target layer. Module names may themselves contain "_", so
        # when a lookup fails the next fragment is glued on and retried.
        # `__getattr__` (not `getattr`) is deliberate: it only resolves
        # registered parameters/buffers/submodules, so a fragment such as "to"
        # cannot accidentally match the `Module.to` method.
        temp_name = layer_infos.pop(0)
        while True:
            try:
                curr_layer = curr_layer.__getattr__(temp_name)
            except AttributeError as err:
                if not layer_infos:
                    raise ValueError(f"Cannot resolve target layer for LoRA key {key!r}") from err
                fragment = layer_infos.pop(0)
                temp_name = f"{temp_name}_{fragment}" if temp_name else fragment
                continue
            if not layer_infos:
                break
            temp_name = layer_infos.pop(0)

        # Work out both members of the pair regardless of which one `key` is.
        if "lora_down" in key:
            up_key, down_key = key.replace("lora_down", "lora_up"), key
        else:
            up_key, down_key = key, key.replace("lora_up", "lora_down")

        # update weight
        weight_up = state_dict[up_key]
        weight_down = state_dict[down_key]
        if len(weight_up.shape) == 4:
            # Conv-style weights: drop the two trailing dims for the matmul,
            # then restore them on the product.
            weight_up = weight_up.squeeze(3).squeeze(2).to(torch.float32)
            weight_down = weight_down.squeeze(3).squeeze(2).to(torch.float32)
            delta = torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
        else:
            delta = torch.mm(weight_up.to(torch.float32), weight_down.to(torch.float32))
        curr_layer.weight.data += alpha * delta.to(curr_layer.weight.data.device)

        visited.update((up_key, down_key))

    return pipeline
113
+
114
+
115
if __name__ == "__main__":
    # Command-line entry point: merge a LoRA .safetensors checkpoint into a
    # diffusers base model and save the resulting pipeline.
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--base_model_path", default=None, type=str, required=True, help="Path to the base model in diffusers format."
    )
    parser.add_argument(
        "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert."
    )
    parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
    parser.add_argument(
        "--lora_prefix_unet", default="lora_unet", type=str, help="The prefix of UNet weight in safetensors"
    )
    parser.add_argument(
        "--lora_prefix_text_encoder",
        default="lora_te",
        type=str,
        help="The prefix of text encoder weight in safetensors",
    )
    parser.add_argument("--alpha", default=0.75, type=float, help="The merging ratio in W = W0 + alpha * deltaW")
    parser.add_argument(
        "--to_safetensors", action="store_true", help="Whether to store pipeline in safetensors format or not."
    )
    parser.add_argument("--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)")

    args = parser.parse_args()

    base_model_path = args.base_model_path
    checkpoint_path = args.checkpoint_path
    dump_path = args.dump_path
    lora_prefix_unet = args.lora_prefix_unet
    lora_prefix_text_encoder = args.lora_prefix_text_encoder
    alpha = args.alpha

    # `convert_lora` takes an already-loaded pipeline and state dict, so both
    # are loaded here before merging (the previous call targeted a
    # non-existent `convert` function and failed with NameError).
    pipe = StableDiffusionPipeline.from_pretrained(base_model_path, torch_dtype=torch.float32)
    lora_state_dict = load_file(checkpoint_path)
    pipe = convert_lora(
        pipe,
        lora_state_dict,
        LORA_PREFIX_UNET=lora_prefix_unet,
        LORA_PREFIX_TEXT_ENCODER=lora_prefix_text_encoder,
        alpha=alpha,
    )

    pipe = pipe.to(args.device)
    pipe.save_pretrained(dump_path, safe_serialization=args.to_safetensors)
animatelcm/utils/lcm_utils.py ADDED
@@ -0,0 +1,237 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ from safetensors import safe_open
4
+
5
+
6
def guidance_scale_embedding(w, embedding_dim=512, dtype=torch.float32):
    """
    Compute sinusoidal embeddings for the values in `w`.

    See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298

    Args:
        w (`torch.Tensor`):
            1-D tensor of values to embed; each value is multiplied by 1000
            before the sinusoids are evaluated
        embedding_dim (`int`, *optional*, defaults to 512):
            dimension of the embeddings to generate
        dtype:
            data type of the generated embeddings

    Returns:
        `torch.FloatTensor`: Embedding vectors with shape `(len(w), embedding_dim)`,
        laid out as all sine terms followed by all cosine terms (plus one
        zero column when `embedding_dim` is odd)
    """
    assert len(w.shape) == 1
    scaled = w * 1000.0

    n_freqs = embedding_dim // 2
    # Geometric frequency ladder from 1 down to 1/10000.
    log_step = torch.log(torch.tensor(10000.0)) / (n_freqs - 1)
    freqs = torch.exp(torch.arange(n_freqs, dtype=dtype) * -log_step)
    angles = scaled.to(dtype)[:, None] * freqs[None, :]
    emb = torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)
    if embedding_dim % 2 == 1:
        # Odd width: append a single zero column.
        emb = torch.nn.functional.pad(emb, (0, 1))
    assert emb.shape == (w.shape[0], embedding_dim)
    return emb
33
+
34
+
35
def append_dims(x, target_dims):
    """Return `x` with trailing singleton axes added until it has `target_dims` dims.

    Raises:
        ValueError: If `x` already has more than `target_dims` dimensions.
    """
    missing = target_dims - x.ndim
    if missing < 0:
        raise ValueError(
            f"input has {x.ndim} dims but target_dims is {target_dims}, which is less")
    # Ellipsis keeps every existing axis; each None adds one new trailing axis.
    index = (...,) + (None,) * missing
    return x[index]
42
+
43
+
44
+ # From LCMScheduler.get_scalings_for_boundary_condition_discrete
45
def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=10.0):
    """Return the ``(c_skip, c_out)`` boundary-condition scalings for `timestep`.

    Args:
        timestep: Timestep value(s); a number or anything supporting
            elementwise arithmetic.
        sigma_data: Data standard deviation used in both scalings.
        timestep_scaling: Factor the timestep is multiplied by before the
            scalings are computed. The default of 10.0 matches the previously
            hard-coded ``timestep / 0.1``.

    Returns:
        Tuple ``(c_skip, c_out)``. At ``timestep == 0`` this is ``(1, 0)``.
    """
    # Honour `timestep_scaling` instead of a hard-coded constant.
    scaled_timestep = timestep * timestep_scaling
    denom = scaled_timestep**2 + sigma_data**2
    c_skip = sigma_data**2 / denom
    c_out = scaled_timestep / denom**0.5
    return c_skip, c_out
49
+
50
+
51
+ # Compare LCMScheduler.step, Step 4
52
def predicted_origin(model_output, timesteps, sample, prediction_type, alphas, sigmas):
    """Recover the predicted clean sample from a model output.

    `alphas` and `sigmas` are indexed at `timesteps` via `extract_into_tensor`
    and shaped to broadcast against `sample`.

    Args:
        model_output: Network prediction (noise or v, per `prediction_type`).
        timesteps: Index tensor selecting entries of `alphas` / `sigmas`.
        sample: Current noisy sample.
        prediction_type: Either ``"epsilon"`` or ``"v_prediction"``.
        alphas: Per-timestep alpha table.
        sigmas: Per-timestep sigma table.

    Raises:
        ValueError: For any other `prediction_type`.
    """
    if prediction_type not in ("epsilon", "v_prediction"):
        raise ValueError(
            f"Prediction type {prediction_type} currently not supported.")

    sigma_t = extract_into_tensor(sigmas, timesteps, sample.shape)
    alpha_t = extract_into_tensor(alphas, timesteps, sample.shape)

    if prediction_type == "epsilon":
        return (sample - sigma_t * model_output) / alpha_t
    return alpha_t * sample - sigma_t * model_output
66
+
67
+
68
def scale_for_loss(timesteps, sample, prediction_type, alphas, sigmas):
    """Rescale `sample` by ``alpha_t / sigma_t`` at the given timesteps.

    Only ``prediction_type == "epsilon"`` is supported.

    Raises:
        ValueError: For any other `prediction_type`.
    """
    if prediction_type != "epsilon":
        raise ValueError(
            f"Prediction type {prediction_type} currently not supported.")

    sigma_t = extract_into_tensor(sigmas, timesteps, sample.shape)
    alpha_t = extract_into_tensor(alphas, timesteps, sample.shape)
    return sample * alpha_t / sigma_t
78
+
79
+
80
def extract_into_tensor(a, t, x_shape):
    """Gather ``a[t]`` and reshape it to broadcast against a tensor of `x_shape`.

    The result has shape ``(t.shape[0], 1, ..., 1)`` with ``len(x_shape)`` dims.
    """
    batch, *_unused = t.shape
    picked = a.gather(-1, t)
    singleton_axes = (1,) * (len(x_shape) - 1)
    return picked.reshape(batch, *singleton_axes)
84
+
85
+
86
class DDIMSolver:
    """Precomputed DDIM timestep tables plus a single solver step.

    The tables are built with NumPy in ``__init__`` and stored as torch
    tensors; use :meth:`to` to move them to another device.
    """

    def __init__(self, alpha_cumprods, timesteps=1000, ddim_timesteps=50):
        """Build the DDIM timestep and alpha-cumprod tables.

        Args:
            alpha_cumprods: NumPy array of cumulative alpha products, indexed
                by training timestep.
            timesteps: Number of training timesteps.
            ddim_timesteps: Number of DDIM steps to sub-sample.
        """
        # DDIM sampling parameters
        step_ratio = timesteps // ddim_timesteps
        ddim_ts = (np.arange(1, ddim_timesteps + 1) * step_ratio).round().astype(np.int64) - 1
        # "Previous" tables are the current ones shifted by one step, seeded
        # with timestep 0 / alpha_cumprods[0].
        ddim_ts_prev = np.asarray([0] + ddim_ts[:-1].tolist())
        alphas = alpha_cumprods[ddim_ts]
        alphas_prev = np.asarray(
            [alpha_cumprods[0]] + alpha_cumprods[ddim_ts[:-1]].tolist()
        )

        # convert to torch tensors
        self.ddim_timesteps = torch.from_numpy(ddim_ts).long()
        self.ddim_timesteps_prev = torch.from_numpy(ddim_ts_prev).long()
        self.ddim_alpha_cumprods = torch.from_numpy(alphas)
        self.ddim_alpha_cumprods_prev = torch.from_numpy(alphas_prev)

    def to(self, device):
        """Move all tables to `device` and return ``self``."""
        self.ddim_timesteps = self.ddim_timesteps.to(device)
        self.ddim_timesteps_prev = self.ddim_timesteps_prev.to(device)
        self.ddim_alpha_cumprods = self.ddim_alpha_cumprods.to(device)
        self.ddim_alpha_cumprods_prev = self.ddim_alpha_cumprods_prev.to(
            device)
        return self

    def ddim_step(self, pred_x0, pred_noise, timestep_index):
        """Return the sample at the previous DDIM timestep.

        Computes ``sqrt(a_prev) * pred_x0 + sqrt(1 - a_prev) * pred_noise``
        where ``a_prev`` is looked up at `timestep_index`.
        """
        alpha_cumprod_prev = extract_into_tensor(
            self.ddim_alpha_cumprods_prev, timestep_index, pred_x0.shape)
        dir_xt = (1.0 - alpha_cumprod_prev).sqrt() * pred_noise
        x_prev = alpha_cumprod_prev.sqrt() * pred_x0 + dir_xt
        return x_prev
126
+
127
+
128
@torch.no_grad()
def update_ema(target_params, source_params, rate=0.99):
    """
    Move each target parameter toward its source parameter in place, as an
    exponential moving average: ``target = rate * target + (1 - rate) * source``.

    :param target_params: the target parameter sequence (modified in place).
    :param source_params: the source parameter sequence.
    :param rate: the EMA rate (closer to 1 means slower).
    """
    blend = 1 - rate
    for target, source in zip(target_params, source_params):
        target.detach().mul_(rate).add_(source, alpha=blend)
140
+
141
+
142
def convert_lcm_lora(unet, path, alpha=1.0):
    """Merge an LCM LoRA checkpoint into `unet`'s weights in place.

    The checkpoint is read from `path` (``torch.load`` for paths ending in
    ``"ckpt"``, safetensors otherwise). Keys starting with ``lora_unet_`` are
    renamed to dotted module paths, and for each down/up pair the product
    ``up @ down`` scaled by ``alpha / 8`` is added to the target layer weight.
    Keys with any other prefix are ignored.

    Args:
        unet: Module whose sub-layer weights are updated.
        path: Checkpoint file path.
        alpha: Extra multiplier applied to every LoRA delta.

    Returns:
        The same `unet` object that was passed in.
    """
    if path.endswith(("ckpt",)):
        # NOTE: torch.load unpickles the file -- only use trusted checkpoints.
        state_dict = torch.load(path, map_location="cpu")
    else:
        state_dict = {}
        with safe_open(path, framework="pt", device="cpu") as f:
            for key in f.keys():
                state_dict[key] = f.get_tensor(key)

    # (preferred spelling, fallback spelling, replacement): the fallback is
    # only rewritten when the preferred spelling is absent from the name.
    block_renames = (
        ("input.blocks", "down.blocks", "down_blocks"),
        ("middle.block", "mid.block", "mid_block"),
        ("output.blocks", "up.blocks", "up_blocks"),
    )
    # Applied in order after the block renames; restores the underscores that
    # the blanket "_" -> "." conversion removed from real module names.
    name_renames = (
        ("transformer.blocks", "transformer_blocks"),
        ("to.q.lora", "to_q_lora"),
        ("to.k.lora", "to_k_lora"),
        ("to.v.lora", "to_v_lora"),
        ("to.out.0.lora", "to_out_lora"),
        ("proj.in", "proj_in"),
        ("proj.out", "proj_out"),
        ("time.emb.proj", "time_emb_proj"),
        ("conv.shortcut", "conv_shortcut"),
    )

    updated_state_dict = {}
    for key in state_dict:
        if not key.endswith("lora_down.weight"):
            continue
        if not key.split(".")[0].startswith("lora_unet_"):
            continue

        diffusers_name = key.replace("lora_unet_", "").replace("_", ".")
        for preferred, fallback, new in block_renames:
            old = preferred if preferred in diffusers_name else fallback
            diffusers_name = diffusers_name.replace(old, new)
        for old, new in name_renames:
            diffusers_name = diffusers_name.replace(old, new)

        updated_state_dict[diffusers_name] = state_dict[key]
        up_diffusers_name = diffusers_name.replace(".down.", ".up.")
        up_key = key.replace("lora_down.weight", "lora_up.weight")
        updated_state_dict[up_diffusers_name] = state_dict[up_key]

    state_dict = updated_state_dict

    for key in state_dict:
        # "up" entries are consumed together with their "down" partner.
        if "up." in key:
            continue
        up_key = key.replace(".down.", ".up.")
        model_key = key.replace("processor.", "").replace("_lora", "").replace(
            "down.", "").replace("up.", "").replace(".lora", "")
        model_key = model_key.replace("to_out.", "to_out.0.")

        # Walk the dotted path (minus the trailing parameter name).
        curr_layer = unet
        for part in model_key.split(".")[:-1]:
            curr_layer = curr_layer.__getattr__(part)

        target = curr_layer.weight.data
        weight_down = state_dict[key].to(target.device, target.dtype)
        weight_up = state_dict[up_key].to(target.device, target.dtype)

        # NOTE(review): the fixed 1/8 factor presumably encodes the
        # checkpoint's lora_alpha / rank ratio -- confirm against training.
        if weight_up.ndim == 2:
            curr_layer.weight.data += 1/8 * alpha * \
                torch.mm(weight_up, weight_down)
        else:
            assert weight_up.ndim == 4
            # Conv weights: flatten to 2-D for the matmul, then restore the
            # target weight's shape.
            curr_layer.weight.data += 1/8 * alpha * torch.mm(weight_up.flatten(
                start_dim=1), weight_down.flatten(start_dim=1)).reshape(curr_layer.weight.data.shape)

    return unet
animatelcm/utils/util.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import imageio
3
+ import numpy as np
4
+ from typing import Union
5
+
6
+ import torch
7
+ import torchvision
8
+ import torch.distributed as dist
9
+
10
+ from safetensors import safe_open
11
+ from tqdm import tqdm
12
+ from einops import rearrange
13
+ from animatelcm.utils.convert_from_ckpt import convert_ldm_unet_checkpoint, convert_ldm_clip_checkpoint, convert_ldm_vae_checkpoint
14
+ from animatelcm.utils.convert_lora_safetensor_to_diffusers import convert_lora, convert_motion_lora_ckpt_to_diffusers
15
+
16
+
17
def zero_rank_print(s):
    """Print `s` prefixed with ``"### "`` on the main process only.

    The message is printed when torch.distributed is not initialized, or when
    it is initialized and the current rank is 0. (The previous condition
    joined both cases with ``and``, which can never be true, so nothing was
    ever printed.)
    """
    if (not dist.is_initialized()) or dist.get_rank() == 0:
        print("### " + s)
19
+
20
+
21
def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8):
    """Tile a batch of videos into a grid per frame and write them to `path`.

    Args:
        videos: Tensor laid out as ``(b, c, t, h, w)`` (per the rearrange
            pattern below).
        path: Output file path; its parent directory is created if needed.
        rescale: If true, map values from ``[-1, 1]`` to ``[0, 1]`` first.
        n_rows: Passed to ``torchvision.utils.make_grid`` as ``nrow``.
        fps: Frame rate passed to ``imageio.mimsave``.
    """
    videos = rearrange(videos, "b c t h w -> t b c h w")
    outputs = []
    for x in videos:
        x = torchvision.utils.make_grid(x, nrow=n_rows)
        # (c, h, w) -> (h, w, c) for image writers.
        x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
        if rescale:
            x = (x + 1.0) / 2.0  # -1,1 -> 0,1
        x = (x * 255).numpy().astype(np.uint8)
        outputs.append(x)

    # A bare filename has an empty dirname, which os.makedirs rejects.
    out_dir = os.path.dirname(path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    imageio.mimsave(path, outputs, fps=fps)
34
+
35
+
36
+ # DDIM Inversion
37
@torch.no_grad()
def init_prompt(prompt, pipeline):
    """Encode the empty prompt and `prompt`, returning them concatenated.

    Returns:
        ``torch.cat([uncond_embeddings, text_embeddings])`` where each part is
        the first output of ``pipeline.text_encoder``.
    """
    # Unconditional branch: the empty string padded to the maximum length.
    empty_tokens = pipeline.tokenizer(
        [""],
        padding="max_length",
        max_length=pipeline.tokenizer.model_max_length,
        return_tensors="pt",
    )
    empty_ids = empty_tokens.input_ids.to(pipeline.device)
    uncond_embeddings = pipeline.text_encoder(empty_ids)[0]

    # Conditional branch: the actual prompt, truncated if too long.
    prompt_tokens = pipeline.tokenizer(
        [prompt],
        padding="max_length",
        max_length=pipeline.tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    prompt_ids = prompt_tokens.input_ids.to(pipeline.device)
    text_embeddings = pipeline.text_encoder(prompt_ids)[0]

    return torch.cat([uncond_embeddings, text_embeddings])
55
+
56
+
57
def next_step(model_output: Union[torch.FloatTensor, np.ndarray], timestep: int,
              sample: Union[torch.FloatTensor, np.ndarray], ddim_scheduler):
    """Advance `sample` one DDIM inversion step, up to `timestep`.

    The sample is treated as sitting one scheduler stride below `timestep`
    (capped at 999); when that lower timestep is negative,
    ``ddim_scheduler.final_alpha_cumprod`` is used for it.
    """
    stride = ddim_scheduler.config.num_train_timesteps // ddim_scheduler.num_inference_steps
    source_t = min(timestep - stride, 999)
    target_t = timestep

    if source_t >= 0:
        alpha_source = ddim_scheduler.alphas_cumprod[source_t]
    else:
        alpha_source = ddim_scheduler.final_alpha_cumprod
    alpha_target = ddim_scheduler.alphas_cumprod[target_t]
    beta_source = 1 - alpha_source

    # Predict the clean sample, then re-noise it to the target timestep.
    predicted_x0 = (sample - beta_source ** 0.5 * model_output) / alpha_source ** 0.5
    direction = (1 - alpha_target) ** 0.5 * model_output
    return alpha_target ** 0.5 * predicted_x0 + direction
68
+
69
+
70
def get_noise_pred_single(latents, t, context, unet):
    """Call `unet` with `context` as ``encoder_hidden_states`` and return its ``"sample"`` entry."""
    return unet(latents, t, encoder_hidden_states=context)["sample"]
73
+
74
+
75
@torch.no_grad()
def ddim_loop(pipeline, ddim_scheduler, latent, num_inv_steps, prompt):
    """Run `num_inv_steps` DDIM inversion steps and collect every latent.

    Timesteps are taken from the end of ``ddim_scheduler.timesteps`` backwards.
    Only the prompt-conditioned embedding is used for noise prediction.

    Returns:
        List starting with the input `latent`, followed by the latent after
        each step.
    """
    context = init_prompt(prompt, pipeline)
    # First chunk is the unconditional embedding, which is not used here.
    _uncond_embeddings, cond_embeddings = context.chunk(2)

    trajectory = [latent]
    current = latent.clone().detach()
    for step in tqdm(range(num_inv_steps)):
        t = ddim_scheduler.timesteps[len(ddim_scheduler.timesteps) - step - 1]
        noise_pred = get_noise_pred_single(current, t, cond_embeddings, pipeline.unet)
        current = next_step(noise_pred, t, current, ddim_scheduler)
        trajectory.append(current)
    return trajectory
87
+
88
+
89
@torch.no_grad()
def ddim_inversion(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt=""):
    """Return the list of latents produced by `ddim_loop` for `video_latent`."""
    return ddim_loop(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt)
93
+
94
def load_weights(
    animation_pipeline,
    motion_module_path="",
    motion_module_lora_configs=None,
    dreambooth_model_path="",
    lora_model_path="",
    lora_alpha=0.8,
):
    """Load optional motion-module, DreamBooth, LoRA and motion-LoRA weights.

    Every stage is skipped when its path/config is empty.

    Args:
        animation_pipeline: Pipeline exposing ``unet``, ``vae`` and
            ``text_encoder``.
        motion_module_path: Checkpoint whose ``motion_modules.*`` entries are
            loaded into the UNet.
        motion_module_lora_configs: Iterable of dicts with ``"path"`` and
            ``"alpha"`` keys; ``None`` (the default) means no motion LoRAs.
        dreambooth_model_path: ``.safetensors`` or ``.ckpt`` checkpoint that is
            converted and loaded into the VAE, UNet and text encoder.
        lora_model_path: ``.safetensors`` LoRA merged via ``convert_lora``.
        lora_alpha: Merge ratio for `lora_model_path`.

    Returns:
        The updated pipeline.

    Raises:
        ValueError: If `dreambooth_model_path` has an unsupported extension.
    """
    # Avoid a shared mutable default; None stands for "no motion LoRAs".
    if motion_module_lora_configs is None:
        motion_module_lora_configs = []

    unet_state_dict = {}
    if motion_module_path != "":
        print(f"load motion module from {motion_module_path}")
        # NOTE: torch.load unpickles the file -- only use trusted checkpoints.
        motion_module_state_dict = torch.load(motion_module_path, map_location="cpu")
        motion_module_state_dict = motion_module_state_dict["state_dict"] if "state_dict" in motion_module_state_dict else motion_module_state_dict
        unet_state_dict.update({name: param for name, param in motion_module_state_dict.items() if "motion_modules." in name})

    missing, unexpected = animation_pipeline.unet.load_state_dict(unet_state_dict, strict=False)
    assert len(unexpected) == 0
    del unet_state_dict

    if dreambooth_model_path != "":
        print(f"load dreambooth model from {dreambooth_model_path}")
        if dreambooth_model_path.endswith(".safetensors"):
            dreambooth_state_dict = {}
            with safe_open(dreambooth_model_path, framework="pt", device="cpu") as f:
                for key in f.keys():
                    dreambooth_state_dict[key] = f.get_tensor(key)
        elif dreambooth_model_path.endswith(".ckpt"):
            dreambooth_state_dict = torch.load(dreambooth_model_path, map_location="cpu")
        else:
            # Previously this fell through and failed later with an
            # UnboundLocalError on `dreambooth_state_dict`.
            raise ValueError(
                f"Unsupported dreambooth checkpoint format: {dreambooth_model_path}"
            )

        converted_vae_checkpoint = convert_ldm_vae_checkpoint(dreambooth_state_dict, animation_pipeline.vae.config)
        animation_pipeline.vae.load_state_dict(converted_vae_checkpoint)

        converted_unet_checkpoint = convert_ldm_unet_checkpoint(dreambooth_state_dict, animation_pipeline.unet.config)
        # strict=False: the converted checkpoint need not cover every UNet key.
        animation_pipeline.unet.load_state_dict(converted_unet_checkpoint, strict=False)

        animation_pipeline.text_encoder = convert_ldm_clip_checkpoint(dreambooth_state_dict)
        del dreambooth_state_dict

    if lora_model_path != "":
        print(f"load lora model from {lora_model_path}")
        assert lora_model_path.endswith(".safetensors")
        lora_state_dict = {}
        with safe_open(lora_model_path, framework="pt", device="cpu") as f:
            for key in f.keys():
                lora_state_dict[key] = f.get_tensor(key)

        animation_pipeline = convert_lora(animation_pipeline, lora_state_dict, alpha=lora_alpha)
        del lora_state_dict

    for motion_module_lora_config in motion_module_lora_configs:
        path, alpha = motion_module_lora_config["path"], motion_module_lora_config["alpha"]
        print(f"load motion LoRA from {path}")

        motion_lora_state_dict = torch.load(path, map_location="cpu")
        motion_lora_state_dict = motion_lora_state_dict["state_dict"] if "state_dict" in motion_lora_state_dict else motion_lora_state_dict

        animation_pipeline = convert_motion_lora_ckpt_to_diffusers(animation_pipeline, motion_lora_state_dict, alpha)

    return animation_pipeline
app.py ADDED
@@ -0,0 +1,392 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import os
3
+ import json
4
+ import torch
5
+ import random
6
+
7
+ import gradio as gr
8
+ from glob import glob
9
+ from omegaconf import OmegaConf
10
+ from datetime import datetime
11
+ from safetensors import safe_open
12
+
13
+ from diffusers import AutoencoderKL
14
+ from diffusers.utils.import_utils import is_xformers_available
15
+ from transformers import CLIPTextModel, CLIPTokenizer
16
+
17
+ from animatelcm.scheduler.lcm_scheduler import LCMScheduler
18
+ from animatelcm.models.unet import UNet3DConditionModel
19
+ from animatelcm.pipelines.pipeline_animation import AnimationPipeline
20
+ from animatelcm.utils.util import save_videos_grid
21
+ from animatelcm.utils.convert_from_ckpt import convert_ldm_unet_checkpoint, convert_ldm_clip_checkpoint, convert_ldm_vae_checkpoint
22
+ from animatelcm.utils.convert_lora_safetensor_to_diffusers import convert_lora
23
+ from animatelcm.utils.lcm_utils import convert_lcm_lora
24
+ import copy
25
+
26
# Index used in the file name of the saved sample video (see
# AnimateController.animate); it is only read in this part of the file.
sample_idx = 0
# Maps the sampler name selected in the UI to the scheduler class to build.
scheduler_dict = {
    "LCM": LCMScheduler,
}

# Custom CSS handed to gr.Blocks; styles elements with the "toolbutton" class.
css = """
.toolbutton {
margin-buttom: 0em 0em 0em 0em;
max-width: 2.5em;
min-width: 2.5em !important;
height: 2.5em;
}
"""
39
+
40
+
41
class AnimateController:
    """Holds the loaded models and implements the callbacks used by the UI.

    Model directories are resolved relative to the current working directory.
    The model components start as ``None`` and are populated by the
    ``update_*`` callbacks.
    """

    def __init__(self):
        """Resolve directories, scan available checkpoints, load the inference config."""
        # config dirs
        self.basedir = os.getcwd()
        self.stable_diffusion_dir = os.path.join(
            self.basedir, "models", "StableDiffusion")
        self.motion_module_dir = os.path.join(
            self.basedir, "models", "Motion_Module")
        self.personalized_model_dir = os.path.join(
            self.basedir, "models", "DreamBooth_LoRA")
        self.savedir = os.path.join(
            self.basedir, "samples", datetime.now().strftime("Gradio-%Y-%m-%dT%H-%M-%S"))
        self.savedir_sample = os.path.join(self.savedir, "sample")
        self.lcm_lora_path = "models/LCM_LoRA/sd15_t2v_beta_lora.safetensors"
        os.makedirs(self.savedir, exist_ok=True)

        self.stable_diffusion_list = []
        self.motion_module_list = []
        self.personalized_model_list = []

        self.refresh_stable_diffusion()
        self.refresh_motion_module()
        self.refresh_personalized_model()

        # config models
        self.tokenizer = None
        self.text_encoder = None
        self.vae = None
        self.unet = None
        self.pipeline = None
        self.lora_model_state_dict = {}

        self.inference_config = OmegaConf.load("configs/inference.yaml")

    def refresh_stable_diffusion(self):
        """Rescan the Stable Diffusion directory for model sub-directories."""
        self.stable_diffusion_list = glob(
            os.path.join(self.stable_diffusion_dir, "*/"))

    def refresh_motion_module(self):
        """Rescan the motion-module directory for ``*.ckpt`` files (base names only)."""
        motion_module_list = glob(os.path.join(
            self.motion_module_dir, "*.ckpt"))
        self.motion_module_list = [
            os.path.basename(p) for p in motion_module_list]

    def refresh_personalized_model(self):
        """Rescan the personalized-model directory for ``*.safetensors`` files (base names only)."""
        personalized_model_list = glob(os.path.join(
            self.personalized_model_dir, "*.safetensors"))
        self.personalized_model_list = [
            os.path.basename(p) for p in personalized_model_list]

    def update_stable_diffusion(self, stable_diffusion_dropdown):
        """Load tokenizer, text encoder, VAE and UNet from the selected model directory."""
        stable_diffusion_dropdown = os.path.join(self.stable_diffusion_dir, stable_diffusion_dropdown)
        self.tokenizer = CLIPTokenizer.from_pretrained(
            stable_diffusion_dropdown, subfolder="tokenizer")
        self.text_encoder = CLIPTextModel.from_pretrained(
            stable_diffusion_dropdown, subfolder="text_encoder").cuda()
        self.vae = AutoencoderKL.from_pretrained(
            stable_diffusion_dropdown, subfolder="vae").cuda()
        self.unet = UNet3DConditionModel.from_pretrained_2d(
            stable_diffusion_dropdown, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(self.inference_config.unet_additional_kwargs)).cuda()
        return gr.Dropdown.update()

    def update_motion_module(self, motion_module_dropdown):
        """Load the selected motion-module checkpoint into the UNet.

        Resets the dropdown when no base model has been loaded yet.
        """
        if self.unet is None:
            gr.Info(f"Please select a pretrained model path.")
            return gr.Dropdown.update(value=None)
        else:
            motion_module_dropdown = os.path.join(
                self.motion_module_dir, motion_module_dropdown)
            # NOTE: torch.load unpickles the file -- only use trusted checkpoints.
            motion_module_state_dict = torch.load(
                motion_module_dropdown, map_location="cpu")
            missing, unexpected = self.unet.load_state_dict(
                motion_module_state_dict, strict=False)
            assert len(unexpected) == 0
            return gr.Dropdown.update()

    def update_base_model(self, base_model_dropdown):
        """Load the selected personalized checkpoint into the VAE and UNet.

        Resets the dropdown when no base model has been loaded yet.
        """
        if self.unet is None:
            gr.Info(f"Please select a pretrained model path.")
            return gr.Dropdown.update(value=None)
        else:
            base_model_dropdown = os.path.join(
                self.personalized_model_dir, base_model_dropdown)
            base_model_state_dict = {}
            with safe_open(base_model_dropdown, framework="pt", device="cpu") as f:
                for key in f.keys():
                    base_model_state_dict[key] = f.get_tensor(key)

            converted_vae_checkpoint = convert_ldm_vae_checkpoint(
                base_model_state_dict, self.vae.config)
            self.vae.load_state_dict(converted_vae_checkpoint)

            converted_unet_checkpoint = convert_ldm_unet_checkpoint(
                base_model_state_dict, self.unet.config)
            self.unet.load_state_dict(converted_unet_checkpoint, strict=False)

            # self.text_encoder = convert_ldm_clip_checkpoint(base_model_state_dict)
            return gr.Dropdown.update()

    def update_lora_model(self, lora_model_dropdown):
        """Load the selected LoRA state dict, or clear it when "none" is chosen."""
        self.lora_model_state_dict = {}
        # Test the raw choice: once joined with the directory it can never
        # equal "none", which made the old check dead code.
        if lora_model_dropdown != "none":
            lora_model_path = os.path.join(
                self.personalized_model_dir, lora_model_dropdown)
            with safe_open(lora_model_path, framework="pt", device="cpu") as f:
                for key in f.keys():
                    self.lora_model_state_dict[key] = f.get_tensor(key)
        return gr.Dropdown.update()

    def animate(
        self,
        stable_diffusion_dropdown,
        motion_module_dropdown,
        base_model_dropdown,
        lora_alpha_slider,
        spatial_lora_slider,
        prompt_textbox,
        negative_prompt_textbox,
        sampler_dropdown,
        sample_step_slider,
        width_slider,
        length_slider,
        height_slider,
        cfg_scale_slider,
        seed_textbox
    ):
        """Generate a video for the prompt, save it, log its config, and return a video update.

        Raises:
            gr.Error: If the base model, motion module or DreamBooth model has
                not been selected.
        """
        if self.unet is None:
            raise gr.Error(f"Please select a pretrained model path.")
        if motion_module_dropdown == "":
            raise gr.Error(f"Please select a motion module.")
        if base_model_dropdown == "":
            raise gr.Error(f"Please select a base DreamBooth model.")

        if is_xformers_available():
            self.unet.enable_xformers_memory_efficient_attention()

        pipeline = AnimationPipeline(
            vae=self.vae, text_encoder=self.text_encoder, tokenizer=self.tokenizer, unet=self.unet,
            scheduler=scheduler_dict[sampler_dropdown](
                **OmegaConf.to_container(self.inference_config.noise_scheduler_kwargs))
        ).to("cuda")

        if self.lora_model_state_dict != {}:
            # NOTE(review): convert_lora edits the shared self.unet and
            # self.text_encoder weights in place, so the LoRA delta appears to
            # accumulate across calls -- confirm whether that is intended.
            pipeline = convert_lora(
                pipeline, self.lora_model_state_dict, alpha=lora_alpha_slider)

        # The LCM LoRA is merged into a copy so self.unet is left untouched by it.
        pipeline.unet = convert_lcm_lora(copy.deepcopy(
            self.unet), self.lcm_lora_path, spatial_lora_slider)

        pipeline.to("cuda")

        if seed_textbox != -1 and seed_textbox != "":
            torch.manual_seed(int(seed_textbox))
        else:
            torch.seed()
        seed = torch.initial_seed()

        sample = pipeline(
            prompt_textbox,
            negative_prompt=negative_prompt_textbox,
            num_inference_steps=sample_step_slider,
            guidance_scale=cfg_scale_slider,
            width=width_slider,
            height=height_slider,
            video_length=length_slider,
        ).videos

        # NOTE(review): sample_idx is a module-level value that is not changed
        # in this class, so each run writes to the same file name.
        save_sample_path = os.path.join(
            self.savedir_sample, f"{sample_idx}.mp4")
        save_videos_grid(sample, save_sample_path)

        sample_config = {
            "prompt": prompt_textbox,
            "n_prompt": negative_prompt_textbox,
            "sampler": sampler_dropdown,
            "num_inference_steps": sample_step_slider,
            "guidance_scale": cfg_scale_slider,
            "width": width_slider,
            "height": height_slider,
            "video_length": length_slider,
            "seed": seed
        }
        json_str = json.dumps(sample_config, indent=4)
        with open(os.path.join(self.savedir, "logs.json"), "a") as f:
            f.write(json_str)
            f.write("\n\n")
        return gr.Video.update(value=save_sample_path)
231
+
232
+
233
# Module-level controller instance. ui() reads its model lists
# (stable_diffusion_list, motion_module_list, personalized_model_list) and
# wires its update_*/refresh_*/animate methods to the Gradio widgets.
controller = AnimateController()
234
+
235
+
236
def ui():
    """Build the Gradio Blocks interface for the AnimateLCM demo.

    The layout has two panels: one for choosing checkpoints (pretrained
    model path, motion module, DreamBooth base model, optional LoRA) and one
    for the generation settings (prompts, sampler, size, length, CFG scale,
    seed). Widget events are wired to the module-level ``controller``.

    Returns:
        The constructed ``gr.Blocks`` app; the caller is responsible for
        queueing and launching it.
    """
    with gr.Blocks(css=css) as demo:
        gr.Markdown(
            """
            # [AnimateLCM: Accelerating the Animation of Personalized Diffusion Models and Adapters with Decoupled Consistency Learning](https://arxiv.org/abs/2402.00769)
            Fu-Yun Wang, Zhaoyang Huang (*Corresponding Author), Xiaoyu Shi, Weikang Bian, Guanglu Song, Yu Liu, Hongsheng Li (*Corresponding Author)<br>
            [arXiv Report](https://arxiv.org/abs/2402.00769) | [Project Page](https://animatelcm.github.io/) | [Github](https://github.com/G-U-N/AnimateLCM)
            """
        )
        with gr.Column(variant="panel"):
            gr.Markdown(
                """
                ### 1. Model checkpoints (select pretrained model path first).
                """
            )
            with gr.Row():
                stable_diffusion_dropdown = gr.Dropdown(
                    label="Pretrained Model Path",
                    choices=controller.stable_diffusion_list,
                    interactive=True,
                )
                stable_diffusion_dropdown.change(
                    fn=controller.update_stable_diffusion,
                    inputs=[stable_diffusion_dropdown],
                    outputs=[stable_diffusion_dropdown],
                )

                stable_diffusion_refresh_button = gr.Button(
                    value="\U0001F503", elem_classes="toolbutton")

                def update_stable_diffusion():
                    """Refresh the controller's list and push it to the dropdown."""
                    controller.refresh_stable_diffusion()
                    return gr.Dropdown.update(
                        choices=controller.stable_diffusion_list)

                stable_diffusion_refresh_button.click(
                    fn=update_stable_diffusion,
                    inputs=[],
                    outputs=[stable_diffusion_dropdown],
                )

            with gr.Row():
                motion_module_dropdown = gr.Dropdown(
                    label="Select motion module",
                    choices=controller.motion_module_list,
                    interactive=True,
                )
                motion_module_dropdown.change(
                    fn=controller.update_motion_module,
                    inputs=[motion_module_dropdown],
                    outputs=[motion_module_dropdown],
                )

                motion_module_refresh_button = gr.Button(
                    value="\U0001F503", elem_classes="toolbutton")

                def update_motion_module():
                    """Refresh the controller's list and push it to the dropdown."""
                    controller.refresh_motion_module()
                    return gr.Dropdown.update(
                        choices=controller.motion_module_list)

                motion_module_refresh_button.click(
                    fn=update_motion_module,
                    inputs=[],
                    outputs=[motion_module_dropdown],
                )

                base_model_dropdown = gr.Dropdown(
                    label="Select base Dreambooth model (required)",
                    choices=controller.personalized_model_list,
                    interactive=True,
                )
                base_model_dropdown.change(
                    fn=controller.update_base_model,
                    inputs=[base_model_dropdown],
                    outputs=[base_model_dropdown],
                )

                lora_model_dropdown = gr.Dropdown(
                    label="Select LoRA model (optional)",
                    # Fixed: the comma after this argument was missing, which
                    # made the whole module fail to parse.
                    choices=["none"],
                    value="none",
                    interactive=True,
                )
                lora_model_dropdown.change(
                    fn=controller.update_lora_model,
                    inputs=[lora_model_dropdown],
                    outputs=[lora_model_dropdown],
                )

                lora_alpha_slider = gr.Slider(
                    label="LoRA alpha", value=0.8, minimum=0, maximum=2,
                    interactive=True)
                spatial_lora_slider = gr.Slider(
                    label="LCM LoRA alpha", value=0.8, minimum=0.0,
                    maximum=1.0, interactive=True)

                personalized_refresh_button = gr.Button(
                    value="\U0001F503", elem_classes="toolbutton")

                def update_personalized_model():
                    """Refresh the personalized models for both dropdowns.

                    Returns one update per output, in the order
                    (base model dropdown, LoRA dropdown); the LoRA choices
                    get an extra leading "none" entry.
                    """
                    controller.refresh_personalized_model()
                    return [
                        gr.Dropdown.update(
                            choices=controller.personalized_model_list),
                        gr.Dropdown.update(
                            choices=["none"] + controller.personalized_model_list),
                    ]

                personalized_refresh_button.click(
                    fn=update_personalized_model,
                    inputs=[],
                    outputs=[base_model_dropdown, lora_model_dropdown],
                )

        with gr.Column(variant="panel"):
            gr.Markdown(
                """
                ### 2. Configs for AnimateLCM.
                """
            )

            prompt_textbox = gr.Textbox(label="Prompt", lines=2)
            negative_prompt_textbox = gr.Textbox(
                label="Negative prompt", lines=2)

            with gr.Row().style(equal_height=False):
                with gr.Column():
                    with gr.Row():
                        sampler_names = list(scheduler_dict.keys())
                        sampler_dropdown = gr.Dropdown(
                            label="Sampling method",
                            choices=sampler_names,
                            value=sampler_names[0],
                        )
                        sample_step_slider = gr.Slider(
                            label="Sampling steps", value=4, minimum=1,
                            maximum=25, step=1)

                    width_slider = gr.Slider(
                        label="Width", value=512, minimum=256, maximum=1024,
                        step=64)
                    height_slider = gr.Slider(
                        label="Height", value=512, minimum=256, maximum=1024,
                        step=64)
                    length_slider = gr.Slider(
                        label="Animation length", value=16, minimum=12,
                        maximum=20, step=1)
                    cfg_scale_slider = gr.Slider(
                        label="CFG Scale", value=1, minimum=1, maximum=2)

                    with gr.Row():
                        seed_textbox = gr.Textbox(label="Seed", value=-1)
                        seed_button = gr.Button(
                            value="\U0001F3B2", elem_classes="toolbutton")
                        # Fixed: the upper bound must be an int. A float bound
                        # (1e8) is rejected by random.randint on Python 3.12+.
                        seed_button.click(
                            fn=lambda: gr.Textbox.update(
                                value=random.randint(1, 10**8)),
                            inputs=[],
                            outputs=[seed_textbox],
                        )

                    generate_button = gr.Button(
                        value="Generate", variant='primary')

                result_video = gr.Video(
                    label="Generated Animation", interactive=False)

            # NOTE(review): Gradio passes these values positionally, so the
            # order presumably mirrors controller.animate's parameters
            # (width, length, height in that order) -- verify against the
            # method signature before reordering.
            generate_button.click(
                fn=controller.animate,
                inputs=[
                    stable_diffusion_dropdown,
                    motion_module_dropdown,
                    base_model_dropdown,
                    lora_alpha_slider,
                    spatial_lora_slider,
                    prompt_textbox,
                    negative_prompt_textbox,
                    sampler_dropdown,
                    sample_step_slider,
                    width_slider,
                    length_slider,
                    height_slider,
                    cfg_scale_slider,
                    seed_textbox,
                ],
                outputs=[result_video]
            )

    return demo
386
+
387
+
388
if __name__ == "__main__":
    # Build the interface, enable request queueing, then start the server.
    demo = ui()
    # gr.close_all()
    demo.queue(max_size=20, concurrency_count=3)
    demo.launch(server_name="127.0.0.1", share=True)
models/.DS_Store ADDED
Binary file (6.15 kB). View file
 
models/DreamBooth_LoRA/cartoon2d.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cbfba64e662370f59d4aa2aa69bf16749fce93846ccce20506aee5df01169859
3
+ size 4244124028
models/DreamBooth_LoRA/cartoon3d.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a6b4c0392d7486bfa4fd1a31c7b7d2679f743f8ea8d9f219c82b5c33db31ddb9
3
+ size 2132625644
models/DreamBooth_LoRA/realistic1.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c0d1994c73d784a17a5b335ae8bda02dcc8dd2fc5f5dbf55169d5aab385e53f2
3
+ size 2132650523
models/DreamBooth_LoRA/realistic2.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a38fa861a24f4f4c6e0f68289101e645dd9ca1e93e1049cc8a4f2a77513fad52
3
+ size 2400040290
models/LCM_LoRA/Put LCMLoRA checkpoints here.txt ADDED
File without changes
models/LCM_LoRA/sd15_t2v_beta_lora.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8f90d840e075ff588a58e22c6586e2ae9a6f7922996ee6649a7f01072333afe4
3
+ size 134621556
models/Motion_Module/Put motion module checkpoints here.txt ADDED
File without changes
models/Motion_Module/sd15_t2v_beta_motion.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b46c3de62e5696af72c4056e3cdcbea12fbc19581c0aad7b6f2b027851148f5f
3
+ size 1813041929
models/StableDiffusion/Put diffusers stable-diffusion-v1-5 repo here.txt ADDED
File without changes
models/StableDiffusion/stable-diffusion-v1-5/.gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ftz filter=lfs diff=lfs merge=lfs -text
6
+ *.gz filter=lfs diff=lfs merge=lfs -text
7
+ *.h5 filter=lfs diff=lfs merge=lfs -text
8
+ *.joblib filter=lfs diff=lfs merge=lfs -text
9
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
10
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
11
+ *.model filter=lfs diff=lfs merge=lfs -text
12
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
13
+ *.npy filter=lfs diff=lfs merge=lfs -text
14
+ *.npz filter=lfs diff=lfs merge=lfs -text
15
+ *.onnx filter=lfs diff=lfs merge=lfs -text
16
+ *.ot filter=lfs diff=lfs merge=lfs -text
17
+ *.parquet filter=lfs diff=lfs merge=lfs -text
18
+ *.pb filter=lfs diff=lfs merge=lfs -text
19
+ *.pickle filter=lfs diff=lfs merge=lfs -text
20
+ *.pkl filter=lfs diff=lfs merge=lfs -text
21
+ *.pt filter=lfs diff=lfs merge=lfs -text
22
+ *.pth filter=lfs diff=lfs merge=lfs -text
23
+ *.rar filter=lfs diff=lfs merge=lfs -text
24
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
25
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
26
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
27
+ *.tflite filter=lfs diff=lfs merge=lfs -text
28
+ *.tgz filter=lfs diff=lfs merge=lfs -text
29
+ *.wasm filter=lfs diff=lfs merge=lfs -text
30
+ *.xz filter=lfs diff=lfs merge=lfs -text
31
+ *.zip filter=lfs diff=lfs merge=lfs -text
32
+ *.zst filter=lfs diff=lfs merge=lfs -text
33
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
34
+ v1-5-pruned-emaonly.ckpt filter=lfs diff=lfs merge=lfs -text
35
+ v1-5-pruned.ckpt filter=lfs diff=lfs merge=lfs -text
models/StableDiffusion/stable-diffusion-v1-5/README.md ADDED
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: creativeml-openrail-m
3
+ tags:
4
+ - stable-diffusion
5
+ - stable-diffusion-diffusers
6
+ - text-to-image
7
+ inference: true
8
+ extra_gated_prompt: |-
9
+ This model is open access and available to all, with a CreativeML OpenRAIL-M license further specifying rights and usage.
10
+ The CreativeML OpenRAIL License specifies:
11
+
12
+ 1. You can't use the model to deliberately produce nor share illegal or harmful outputs or content
13
+ 2. CompVis claims no rights on the outputs you generate, you are free to use them and are accountable for their use which must not go against the provisions set in the license
14
+ 3. You may re-distribute the weights and use the model commercially and/or as a service. If you do, please be aware you have to include the same use restrictions as the ones in the license and share a copy of the CreativeML OpenRAIL-M to all your users (please read the license entirely and carefully)
15
+ Please read the full license carefully here: https://huggingface.co/spaces/CompVis/stable-diffusion-license
16
+
17
+ extra_gated_heading: Please read the LICENSE to access this model
18
+ ---
19
+
20
+ # Stable Diffusion v1-5 Model Card
21
+
22
+ Stable Diffusion is a latent text-to-image diffusion model capable of generating photo-realistic images given any text input.
23
+ For more information about how Stable Diffusion functions, please have a look at [🤗's Stable Diffusion blog](https://huggingface.co/blog/stable_diffusion).
24
+
25
+ The **Stable-Diffusion-v1-5** checkpoint was initialized with the weights of the [Stable-Diffusion-v1-2](https://huggingface.co/CompVis/stable-diffusion-v1-2)
26
+ checkpoint and subsequently fine-tuned on 595k steps at resolution 512x512 on "laion-aesthetics v2 5+" and 10% dropping of the text-conditioning to improve [classifier-free guidance sampling](https://arxiv.org/abs/2207.12598).
27
+
28
+ You can use this both with the [🧨Diffusers library](https://github.com/huggingface/diffusers) and the [RunwayML GitHub repository](https://github.com/runwayml/stable-diffusion).
29
+
30
+ ### Diffusers
31
+ ```py
32
+ from diffusers import StableDiffusionPipeline
33
+ import torch
34
+
35
+ model_id = "runwayml/stable-diffusion-v1-5"
36
+ pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
37
+ pipe = pipe.to("cuda")
38
+
39
+ prompt = "a photo of an astronaut riding a horse on mars"
40
+ image = pipe(prompt).images[0]
41
+
42
+ image.save("astronaut_rides_horse.png")
43
+ ```
44
+ For more detailed instructions, use-cases and examples in JAX follow the instructions [here](https://github.com/huggingface/diffusers#text-to-image-generation-with-stable-diffusion)
45
+
46
+ ### Original GitHub Repository
47
+
48
+ 1. Download the weights
49
+ - [v1-5-pruned-emaonly.ckpt](https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.ckpt) - 4.27GB, ema-only weight. uses less VRAM - suitable for inference
50
+ - [v1-5-pruned.ckpt](https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned.ckpt) - 7.7GB, ema+non-ema weights. uses more VRAM - suitable for fine-tuning
51
+
52
+ 2. Follow instructions [here](https://github.com/runwayml/stable-diffusion).
53
+
54
+ ## Model Details
55
+ - **Developed by:** Robin Rombach, Patrick Esser
56
+ - **Model type:** Diffusion-based text-to-image generation model
57
+ - **Language(s):** English
58
+ - **License:** [The CreativeML OpenRAIL M license](https://huggingface.co/spaces/CompVis/stable-diffusion-license) is an [Open RAIL M license](https://www.licenses.ai/blog/2022/8/18/naming-convention-of-responsible-ai-licenses), adapted from the work that [BigScience](https://bigscience.huggingface.co/) and [the RAIL Initiative](https://www.licenses.ai/) are jointly carrying in the area of responsible AI licensing. See also [the article about the BLOOM Open RAIL license](https://bigscience.huggingface.co/blog/the-bigscience-rail-license) on which our license is based.
59
+ - **Model Description:** This is a model that can be used to generate and modify images based on text prompts. It is a [Latent Diffusion Model](https://arxiv.org/abs/2112.10752) that uses a fixed, pretrained text encoder ([CLIP ViT-L/14](https://arxiv.org/abs/2103.00020)) as suggested in the [Imagen paper](https://arxiv.org/abs/2205.11487).
60
+ - **Resources for more information:** [GitHub Repository](https://github.com/CompVis/stable-diffusion), [Paper](https://arxiv.org/abs/2112.10752).
61
+ - **Cite as:**
62
+
63
+ @InProceedings{Rombach_2022_CVPR,
64
+ author = {Rombach, Robin and Blattmann, Andreas and Lorenz, Dominik and Esser, Patrick and Ommer, Bj\"orn},
65
+ title = {High-Resolution Image Synthesis With Latent Diffusion Models},
66
+ booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
67
+ month = {June},
68
+ year = {2022},
69
+ pages = {10684-10695}
70
+ }
71
+
72
+ # Uses
73
+
74
+ ## Direct Use
75
+ The model is intended for research purposes only. Possible research areas and
76
+ tasks include
77
+
78
+ - Safe deployment of models which have the potential to generate harmful content.
79
+ - Probing and understanding the limitations and biases of generative models.
80
+ - Generation of artworks and use in design and other artistic processes.
81
+ - Applications in educational or creative tools.
82
+ - Research on generative models.
83
+
84
+ Excluded uses are described below.
85
+
86
+ ### Misuse, Malicious Use, and Out-of-Scope Use
87
+ _Note: This section is taken from the [DALLE-MINI model card](https://huggingface.co/dalle-mini/dalle-mini), but applies in the same way to Stable Diffusion v1_.
88
+
89
+
90
+ The model should not be used to intentionally create or disseminate images that create hostile or alienating environments for people. This includes generating images that people would foreseeably find disturbing, distressing, or offensive; or content that propagates historical or current stereotypes.
91
+
92
+ #### Out-of-Scope Use
93
+ The model was not trained to be factual or true representations of people or events, and therefore using the model to generate such content is out-of-scope for the abilities of this model.
94
+
95
+ #### Misuse and Malicious Use
96
+ Using the model to generate content that is cruel to individuals is a misuse of this model. This includes, but is not limited to:
97
+
98
+ - Generating demeaning, dehumanizing, or otherwise harmful representations of people or their environments, cultures, religions, etc.
99
+ - Intentionally promoting or propagating discriminatory content or harmful stereotypes.
100
+ - Impersonating individuals without their consent.
101
+ - Sexual content without consent of the people who might see it.
102
+ - Mis- and disinformation
103
+ - Representations of egregious violence and gore
104
+ - Sharing of copyrighted or licensed material in violation of its terms of use.
105
+ - Sharing content that is an alteration of copyrighted or licensed material in violation of its terms of use.
106
+
107
+ ## Limitations and Bias
108
+
109
+ ### Limitations
110
+
111
+ - The model does not achieve perfect photorealism
112
+ - The model cannot render legible text
113
+ - The model does not perform well on more difficult tasks which involve compositionality, such as rendering an image corresponding to “A red cube on top of a blue sphere”
114
+ - Faces and people in general may not be generated properly.
115
+ - The model was trained mainly with English captions and will not work as well in other languages.
116
+ - The autoencoding part of the model is lossy
117
+ - The model was trained on a large-scale dataset
118
+ [LAION-5B](https://laion.ai/blog/laion-5b/) which contains adult material
119
+ and is not fit for product use without additional safety mechanisms and
120
+ considerations.
121
+ - No additional measures were used to deduplicate the dataset. As a result, we observe some degree of memorization for images that are duplicated in the training data.
122
+ The training data can be searched at [https://rom1504.github.io/clip-retrieval/](https://rom1504.github.io/clip-retrieval/) to possibly assist in the detection of memorized images.
123
+
124
+ ### Bias
125
+
126
+ While the capabilities of image generation models are impressive, they can also reinforce or exacerbate social biases.
127
+ Stable Diffusion v1 was trained on subsets of [LAION-2B(en)](https://laion.ai/blog/laion-5b/),
128
+ which consists of images that are primarily limited to English descriptions.
129
+ Texts and images from communities and cultures that use other languages are likely to be insufficiently accounted for.
130
+ This affects the overall output of the model, as white and western cultures are often set as the default. Further, the
131
+ ability of the model to generate content with non-English prompts is significantly worse than with English-language prompts.
132
+
133
+ ### Safety Module
134
+
135
+ The intended use of this model is with the [Safety Checker](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/safety_checker.py) in Diffusers.
136
+ This checker works by checking model outputs against known hard-coded NSFW concepts.
137
+ The concepts are intentionally hidden to reduce the likelihood of reverse-engineering this filter.
138
+ Specifically, the checker compares the class probability of harmful concepts in the embedding space of the `CLIPTextModel` *after generation* of the images.
139
+ The concepts are passed into the model with the generated image and compared to a hand-engineered weight for each NSFW concept.
140
+
141
+
142
+ ## Training
143
+
144
+ **Training Data**
145
+ The model developers used the following dataset for training the model:
146
+
147
+ - LAION-2B (en) and subsets thereof (see next section)
148
+
149
+ **Training Procedure**
150
+ Stable Diffusion v1-5 is a latent diffusion model which combines an autoencoder with a diffusion model that is trained in the latent space of the autoencoder. During training,
151
+
152
+ - Images are encoded through an encoder, which turns images into latent representations. The autoencoder uses a relative downsampling factor of 8 and maps images of shape H x W x 3 to latents of shape H/f x W/f x 4
153
+ - Text prompts are encoded through a ViT-L/14 text-encoder.
154
+ - The non-pooled output of the text encoder is fed into the UNet backbone of the latent diffusion model via cross-attention.
155
+ - The loss is a reconstruction objective between the noise that was added to the latent and the prediction made by the UNet.
156
+
157
+ Currently six Stable Diffusion checkpoints are provided, which were trained as follows.
158
+ - [`stable-diffusion-v1-1`](https://huggingface.co/CompVis/stable-diffusion-v1-1): 237,000 steps at resolution `256x256` on [laion2B-en](https://huggingface.co/datasets/laion/laion2B-en).
159
+ 194,000 steps at resolution `512x512` on [laion-high-resolution](https://huggingface.co/datasets/laion/laion-high-resolution) (170M examples from LAION-5B with resolution `>= 1024x1024`).
160
+ - [`stable-diffusion-v1-2`](https://huggingface.co/CompVis/stable-diffusion-v1-2): Resumed from `stable-diffusion-v1-1`.
161
+ 515,000 steps at resolution `512x512` on "laion-improved-aesthetics" (a subset of laion2B-en,
162
+ filtered to images with an original size `>= 512x512`, estimated aesthetics score `> 5.0`, and an estimated watermark probability `< 0.5`. The watermark estimate is from the LAION-5B metadata, the aesthetics score is estimated using an [improved aesthetics estimator](https://github.com/christophschuhmann/improved-aesthetic-predictor)).
163
+ - [`stable-diffusion-v1-3`](https://huggingface.co/CompVis/stable-diffusion-v1-3): Resumed from `stable-diffusion-v1-2` - 195,000 steps at resolution `512x512` on "laion-improved-aesthetics" and 10 % dropping of the text-conditioning to improve [classifier-free guidance sampling](https://arxiv.org/abs/2207.12598).
164
+ - [`stable-diffusion-v1-4`](https://huggingface.co/CompVis/stable-diffusion-v1-4) Resumed from `stable-diffusion-v1-2` - 225,000 steps at resolution `512x512` on "laion-aesthetics v2 5+" and 10 % dropping of the text-conditioning to improve [classifier-free guidance sampling](https://arxiv.org/abs/2207.12598).
165
+ - [`stable-diffusion-v1-5`](https://huggingface.co/runwayml/stable-diffusion-v1-5) Resumed from `stable-diffusion-v1-2` - 595,000 steps at resolution `512x512` on "laion-aesthetics v2 5+" and 10 % dropping of the text-conditioning to improve [classifier-free guidance sampling](https://arxiv.org/abs/2207.12598).
166
+ - [`stable-diffusion-inpainting`](https://huggingface.co/runwayml/stable-diffusion-inpainting) Resumed from `stable-diffusion-v1-5` - then 440,000 steps of inpainting training at resolution 512x512 on “laion-aesthetics v2 5+” and 10% dropping of the text-conditioning. For inpainting, the UNet has 5 additional input channels (4 for the encoded masked-image and 1 for the mask itself) whose weights were zero-initialized after restoring the non-inpainting checkpoint. During training, we generate synthetic masks and in 25% mask everything.
167
+
168
+ - **Hardware:** 32 x 8 x A100 GPUs
169
+ - **Optimizer:** AdamW
170
+ - **Gradient Accumulations**: 2
171
+ - **Batch:** 32 x 8 x 2 x 4 = 2048
172
+ - **Learning rate:** warmup to 0.0001 for 10,000 steps and then kept constant
173
+
174
+ ## Evaluation Results
175
+ Evaluations with different classifier-free guidance scales (1.5, 2.0, 3.0, 4.0,
176
+ 5.0, 6.0, 7.0, 8.0) and 50 PNDM/PLMS sampling
177
+ steps show the relative improvements of the checkpoints:
178
+
179
+ ![pareto](https://huggingface.co/CompVis/stable-diffusion/resolve/main/v1-1-to-v1-5.png)
180
+
181
+ Evaluated using 50 PLMS steps and 10000 random prompts from the COCO2017 validation set, evaluated at 512x512 resolution. Not optimized for FID scores.
182
+ ## Environmental Impact
183
+
184
+ **Stable Diffusion v1** **Estimated Emissions**
185
+ Based on that information, we estimate the following CO2 emissions using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700). The hardware, runtime, cloud provider, and compute region were utilized to estimate the carbon impact.
186
+
187
+ - **Hardware Type:** A100 PCIe 40GB
188
+ - **Hours used:** 150000
189
+ - **Cloud Provider:** AWS
190
+ - **Compute Region:** US-east
191
+ - **Carbon Emitted (Power consumption x Time x Carbon produced based on location of power grid):** 11250 kg CO2 eq.
192
+
193
+
194
+ ## Citation
195
+
196
+ ```bibtex
197
+ @InProceedings{Rombach_2022_CVPR,
198
+ author = {Rombach, Robin and Blattmann, Andreas and Lorenz, Dominik and Esser, Patrick and Ommer, Bj\"orn},
199
+ title = {High-Resolution Image Synthesis With Latent Diffusion Models},
200
+ booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
201
+ month = {June},
202
+ year = {2022},
203
+ pages = {10684-10695}
204
+ }
205
+ ```
206
+
207
+ *This model card was written by: Robin Rombach and Patrick Esser and is based on the [DALL-E Mini model card](https://huggingface.co/dalle-mini/dalle-mini).*
models/StableDiffusion/stable-diffusion-v1-5/feature_extractor/preprocessor_config.json ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "crop_size": 224,
3
+ "do_center_crop": true,
4
+ "do_convert_rgb": true,
5
+ "do_normalize": true,
6
+ "do_resize": true,
7
+ "feature_extractor_type": "CLIPFeatureExtractor",
8
+ "image_mean": [
9
+ 0.48145466,
10
+ 0.4578275,
11
+ 0.40821073
12
+ ],
13
+ "image_std": [
14
+ 0.26862954,
15
+ 0.26130258,
16
+ 0.27577711
17
+ ],
18
+ "resample": 3,
19
+ "size": 224
20
+ }
models/StableDiffusion/stable-diffusion-v1-5/model_index.json ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "StableDiffusionPipeline",
3
+ "_diffusers_version": "0.6.0",
4
+ "feature_extractor": [
5
+ "transformers",
6
+ "CLIPImageProcessor"
7
+ ],
8
+ "safety_checker": [
9
+ "stable_diffusion",
10
+ "StableDiffusionSafetyChecker"
11
+ ],
12
+ "scheduler": [
13
+ "diffusers",
14
+ "PNDMScheduler"
15
+ ],
16
+ "text_encoder": [
17
+ "transformers",
18
+ "CLIPTextModel"
19
+ ],
20
+ "tokenizer": [
21
+ "transformers",
22
+ "CLIPTokenizer"
23
+ ],
24
+ "unet": [
25
+ "diffusers",
26
+ "UNet2DConditionModel"
27
+ ],
28
+ "vae": [
29
+ "diffusers",
30
+ "AutoencoderKL"
31
+ ]
32
+ }
models/StableDiffusion/stable-diffusion-v1-5/safety_checker/config.json ADDED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_commit_hash": "4bb648a606ef040e7685bde262611766a5fdd67b",
3
+ "_name_or_path": "CompVis/stable-diffusion-safety-checker",
4
+ "architectures": [
5
+ "StableDiffusionSafetyChecker"
6
+ ],
7
+ "initializer_factor": 1.0,
8
+ "logit_scale_init_value": 2.6592,
9
+ "model_type": "clip",
10
+ "projection_dim": 768,
11
+ "text_config": {
12
+ "_name_or_path": "",
13
+ "add_cross_attention": false,
14
+ "architectures": null,
15
+ "attention_dropout": 0.0,
16
+ "bad_words_ids": null,
17
+ "bos_token_id": 0,
18
+ "chunk_size_feed_forward": 0,
19
+ "cross_attention_hidden_size": null,
20
+ "decoder_start_token_id": null,
21
+ "diversity_penalty": 0.0,
22
+ "do_sample": false,
23
+ "dropout": 0.0,
24
+ "early_stopping": false,
25
+ "encoder_no_repeat_ngram_size": 0,
26
+ "eos_token_id": 2,
27
+ "exponential_decay_length_penalty": null,
28
+ "finetuning_task": null,
29
+ "forced_bos_token_id": null,
30
+ "forced_eos_token_id": null,
31
+ "hidden_act": "quick_gelu",
32
+ "hidden_size": 768,
33
+ "id2label": {
34
+ "0": "LABEL_0",
35
+ "1": "LABEL_1"
36
+ },
37
+ "initializer_factor": 1.0,
38
+ "initializer_range": 0.02,
39
+ "intermediate_size": 3072,
40
+ "is_decoder": false,
41
+ "is_encoder_decoder": false,
42
+ "label2id": {
43
+ "LABEL_0": 0,
44
+ "LABEL_1": 1
45
+ },
46
+ "layer_norm_eps": 1e-05,
47
+ "length_penalty": 1.0,
48
+ "max_length": 20,
49
+ "max_position_embeddings": 77,
50
+ "min_length": 0,
51
+ "model_type": "clip_text_model",
52
+ "no_repeat_ngram_size": 0,
53
+ "num_attention_heads": 12,
54
+ "num_beam_groups": 1,
55
+ "num_beams": 1,
56
+ "num_hidden_layers": 12,
57
+ "num_return_sequences": 1,
58
+ "output_attentions": false,
59
+ "output_hidden_states": false,
60
+ "output_scores": false,
61
+ "pad_token_id": 1,
62
+ "prefix": null,
63
+ "problem_type": null,
64
+ "pruned_heads": {},
65
+ "remove_invalid_values": false,
66
+ "repetition_penalty": 1.0,
67
+ "return_dict": true,
68
+ "return_dict_in_generate": false,
69
+ "sep_token_id": null,
70
+ "task_specific_params": null,
71
+ "temperature": 1.0,
72
+ "tf_legacy_loss": false,
73
+ "tie_encoder_decoder": false,
74
+ "tie_word_embeddings": true,
75
+ "tokenizer_class": null,
76
+ "top_k": 50,
77
+ "top_p": 1.0,
78
+ "torch_dtype": null,
79
+ "torchscript": false,
80
+ "transformers_version": "4.22.0.dev0",
81
+ "typical_p": 1.0,
82
+ "use_bfloat16": false,
83
+ "vocab_size": 49408
84
+ },
85
+ "text_config_dict": {
86
+ "hidden_size": 768,
87
+ "intermediate_size": 3072,
88
+ "num_attention_heads": 12,
89
+ "num_hidden_layers": 12
90
+ },
91
+ "torch_dtype": "float32",
92
+ "transformers_version": null,
93
+ "vision_config": {
94
+ "_name_or_path": "",
95
+ "add_cross_attention": false,
96
+ "architectures": null,
97
+ "attention_dropout": 0.0,
98
+ "bad_words_ids": null,
99
+ "bos_token_id": null,
100
+ "chunk_size_feed_forward": 0,
101
+ "cross_attention_hidden_size": null,
102
+ "decoder_start_token_id": null,
103
+ "diversity_penalty": 0.0,
104
+ "do_sample": false,
105
+ "dropout": 0.0,
106
+ "early_stopping": false,
107
+ "encoder_no_repeat_ngram_size": 0,
108
+ "eos_token_id": null,
109
+ "exponential_decay_length_penalty": null,
110
+ "finetuning_task": null,
111
+ "forced_bos_token_id": null,
112
+ "forced_eos_token_id": null,
113
+ "hidden_act": "quick_gelu",
114
+ "hidden_size": 1024,
115
+ "id2label": {
116
+ "0": "LABEL_0",
117
+ "1": "LABEL_1"
118
+ },
119
+ "image_size": 224,
120
+ "initializer_factor": 1.0,
121
+ "initializer_range": 0.02,
122
+ "intermediate_size": 4096,
123
+ "is_decoder": false,
124
+ "is_encoder_decoder": false,
125
+ "label2id": {
126
+ "LABEL_0": 0,
127
+ "LABEL_1": 1
128
+ },
129
+ "layer_norm_eps": 1e-05,
130
+ "length_penalty": 1.0,
131
+ "max_length": 20,
132
+ "min_length": 0,
133
+ "model_type": "clip_vision_model",
134
+ "no_repeat_ngram_size": 0,
135
+ "num_attention_heads": 16,
136
+ "num_beam_groups": 1,
137
+ "num_beams": 1,
138
+ "num_channels": 3,
139
+ "num_hidden_layers": 24,
140
+ "num_return_sequences": 1,
141
+ "output_attentions": false,
142
+ "output_hidden_states": false,
143
+ "output_scores": false,
144
+ "pad_token_id": null,
145
+ "patch_size": 14,
146
+ "prefix": null,
147
+ "problem_type": null,
148
+ "pruned_heads": {},
149
+ "remove_invalid_values": false,
150
+ "repetition_penalty": 1.0,
151
+ "return_dict": true,
152
+ "return_dict_in_generate": false,
153
+ "sep_token_id": null,
154
+ "task_specific_params": null,
155
+ "temperature": 1.0,
156
+ "tf_legacy_loss": false,
157
+ "tie_encoder_decoder": false,
158
+ "tie_word_embeddings": true,
159
+ "tokenizer_class": null,
160
+ "top_k": 50,
161
+ "top_p": 1.0,
162
+ "torch_dtype": null,
163
+ "torchscript": false,
164
+ "transformers_version": "4.22.0.dev0",
165
+ "typical_p": 1.0,
166
+ "use_bfloat16": false
167
+ },
168
+ "vision_config_dict": {
169
+ "hidden_size": 1024,
170
+ "intermediate_size": 4096,
171
+ "num_attention_heads": 16,
172
+ "num_hidden_layers": 24,
173
+ "patch_size": 14
174
+ }
175
+ }
models/StableDiffusion/stable-diffusion-v1-5/scheduler/scheduler_config.json ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "PNDMScheduler",
3
+ "_diffusers_version": "0.6.0",
4
+ "beta_end": 0.012,
5
+ "beta_schedule": "scaled_linear",
6
+ "beta_start": 0.00085,
7
+ "num_train_timesteps": 1000,
8
+ "set_alpha_to_one": false,
9
+ "skip_prk_steps": true,
10
+ "steps_offset": 1,
11
+ "trained_betas": null,
12
+ "clip_sample": false
13
+ }
models/StableDiffusion/stable-diffusion-v1-5/text_encoder/config.json ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "openai/clip-vit-large-patch14",
3
+ "architectures": [
4
+ "CLIPTextModel"
5
+ ],
6
+ "attention_dropout": 0.0,
7
+ "bos_token_id": 0,
8
+ "dropout": 0.0,
9
+ "eos_token_id": 2,
10
+ "hidden_act": "quick_gelu",
11
+ "hidden_size": 768,
12
+ "initializer_factor": 1.0,
13
+ "initializer_range": 0.02,
14
+ "intermediate_size": 3072,
15
+ "layer_norm_eps": 1e-05,
16
+ "max_position_embeddings": 77,
17
+ "model_type": "clip_text_model",
18
+ "num_attention_heads": 12,
19
+ "num_hidden_layers": 12,
20
+ "pad_token_id": 1,
21
+ "projection_dim": 768,
22
+ "torch_dtype": "float32",
23
+ "transformers_version": "4.22.0.dev0",
24
+ "vocab_size": 49408
25
+ }
models/StableDiffusion/stable-diffusion-v1-5/text_encoder/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d008943c017f0092921106440254dbbe00b6a285f7883ec8ba160c3faad88334
3
+ size 492265874
models/StableDiffusion/stable-diffusion-v1-5/tokenizer/merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
models/StableDiffusion/stable-diffusion-v1-5/tokenizer/special_tokens_map.json ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": {
3
+ "content": "<|startoftext|>",
4
+ "lstrip": false,
5
+ "normalized": true,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "<|endoftext|>",
11
+ "lstrip": false,
12
+ "normalized": true,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "pad_token": "<|endoftext|>",
17
+ "unk_token": {
18
+ "content": "<|endoftext|>",
19
+ "lstrip": false,
20
+ "normalized": true,
21
+ "rstrip": false,
22
+ "single_word": false
23
+ }
24
+ }
models/StableDiffusion/stable-diffusion-v1-5/tokenizer/tokenizer_config.json ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_prefix_space": false,
3
+ "bos_token": {
4
+ "__type": "AddedToken",
5
+ "content": "<|startoftext|>",
6
+ "lstrip": false,
7
+ "normalized": true,
8
+ "rstrip": false,
9
+ "single_word": false
10
+ },
11
+ "do_lower_case": true,
12
+ "eos_token": {
13
+ "__type": "AddedToken",
14
+ "content": "<|endoftext|>",
15
+ "lstrip": false,
16
+ "normalized": true,
17
+ "rstrip": false,
18
+ "single_word": false
19
+ },
20
+ "errors": "replace",
21
+ "model_max_length": 77,
22
+ "name_or_path": "openai/clip-vit-large-patch14",
23
+ "pad_token": "<|endoftext|>",
24
+ "special_tokens_map_file": "./special_tokens_map.json",
25
+ "tokenizer_class": "CLIPTokenizer",
26
+ "unk_token": {
27
+ "__type": "AddedToken",
28
+ "content": "<|endoftext|>",
29
+ "lstrip": false,
30
+ "normalized": true,
31
+ "rstrip": false,
32
+ "single_word": false
33
+ }
34
+ }
models/StableDiffusion/stable-diffusion-v1-5/tokenizer/vocab.json ADDED
The diff for this file is too large to render. See raw diff
 
models/StableDiffusion/stable-diffusion-v1-5/unet/config.json ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "UNet2DConditionModel",
3
+ "_diffusers_version": "0.6.0",
4
+ "act_fn": "silu",
5
+ "attention_head_dim": 8,
6
+ "block_out_channels": [
7
+ 320,
8
+ 640,
9
+ 1280,
10
+ 1280
11
+ ],
12
+ "center_input_sample": false,
13
+ "cross_attention_dim": 768,
14
+ "down_block_types": [
15
+ "CrossAttnDownBlock2D",
16
+ "CrossAttnDownBlock2D",
17
+ "CrossAttnDownBlock2D",
18
+ "DownBlock2D"
19
+ ],
20
+ "downsample_padding": 1,
21
+ "flip_sin_to_cos": true,
22
+ "freq_shift": 0,
23
+ "in_channels": 4,
24
+ "layers_per_block": 2,
25
+ "mid_block_scale_factor": 1,
26
+ "norm_eps": 1e-05,
27
+ "norm_num_groups": 32,
28
+ "out_channels": 4,
29
+ "sample_size": 64,
30
+ "up_block_types": [
31
+ "UpBlock2D",
32
+ "CrossAttnUpBlock2D",
33
+ "CrossAttnUpBlock2D",
34
+ "CrossAttnUpBlock2D"
35
+ ]
36
+ }
models/StableDiffusion/stable-diffusion-v1-5/unet/diffusion_pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c7da0e21ba7ea50637bee26e81c220844defdf01aafca02b2c42ecdadb813de4
3
+ size 3438354725
models/StableDiffusion/stable-diffusion-v1-5/v1-inference.yaml ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ base_learning_rate: 1.0e-04
3
+ target: ldm.models.diffusion.ddpm.LatentDiffusion
4
+ params:
5
+ linear_start: 0.00085
6
+ linear_end: 0.0120
7
+ num_timesteps_cond: 1
8
+ log_every_t: 200
9
+ timesteps: 1000
10
+ first_stage_key: "jpg"
11
+ cond_stage_key: "txt"
12
+ image_size: 64
13
+ channels: 4
14
+ cond_stage_trainable: false # Note: different from the one we trained before
15
+ conditioning_key: crossattn
16
+ monitor: val/loss_simple_ema
17
+ scale_factor: 0.18215
18
+ use_ema: False
19
+
20
+ scheduler_config: # 10000 warmup steps
21
+ target: ldm.lr_scheduler.LambdaLinearScheduler
22
+ params:
23
+ warm_up_steps: [ 10000 ]
24
+ cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
25
+ f_start: [ 1.e-6 ]
26
+ f_max: [ 1. ]
27
+ f_min: [ 1. ]
28
+
29
+ unet_config:
30
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
31
+ params:
32
+ image_size: 32 # unused
33
+ in_channels: 4
34
+ out_channels: 4
35
+ model_channels: 320
36
+ attention_resolutions: [ 4, 2, 1 ]
37
+ num_res_blocks: 2
38
+ channel_mult: [ 1, 2, 4, 4 ]
39
+ num_heads: 8
40
+ use_spatial_transformer: True
41
+ transformer_depth: 1
42
+ context_dim: 768
43
+ use_checkpoint: True
44
+ legacy: False
45
+
46
+ first_stage_config:
47
+ target: ldm.models.autoencoder.AutoencoderKL
48
+ params:
49
+ embed_dim: 4
50
+ monitor: val/rec_loss
51
+ ddconfig:
52
+ double_z: true
53
+ z_channels: 4
54
+ resolution: 256
55
+ in_channels: 3
56
+ out_ch: 3
57
+ ch: 128
58
+ ch_mult:
59
+ - 1
60
+ - 2
61
+ - 4
62
+ - 4
63
+ num_res_blocks: 2
64
+ attn_resolutions: []
65
+ dropout: 0.0
66
+ lossconfig:
67
+ target: torch.nn.Identity
68
+
69
+ cond_stage_config:
70
+ target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
models/StableDiffusion/stable-diffusion-v1-5/vae/config.json ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "AutoencoderKL",
3
+ "_diffusers_version": "0.6.0",
4
+ "act_fn": "silu",
5
+ "block_out_channels": [
6
+ 128,
7
+ 256,
8
+ 512,
9
+ 512
10
+ ],
11
+ "down_block_types": [
12
+ "DownEncoderBlock2D",
13
+ "DownEncoderBlock2D",
14
+ "DownEncoderBlock2D",
15
+ "DownEncoderBlock2D"
16
+ ],
17
+ "in_channels": 3,
18
+ "latent_channels": 4,
19
+ "layers_per_block": 2,
20
+ "norm_num_groups": 32,
21
+ "out_channels": 3,
22
+ "sample_size": 512,
23
+ "up_block_types": [
24
+ "UpDecoderBlock2D",
25
+ "UpDecoderBlock2D",
26
+ "UpDecoderBlock2D",
27
+ "UpDecoderBlock2D"
28
+ ]
29
+ }
models/StableDiffusion/stable-diffusion-v1-5/vae/diffusion_pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1b134cded8eb78b184aefb8805b6b572f36fa77b255c483665dda931fa0130c5
3
+ size 334707217
requirements.txt ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch==1.13.1
2
+ torchvision==0.14.1
3
+ torchaudio==0.13.1
4
+ diffusers==0.11.1
5
+ transformers==4.25.1
6
+ xformers==0.0.16
7
+ imageio==2.27.0
8
+ gradio==3.48.0
9
+ gdown
10
+ einops
11
+ omegaconf
12
+ safetensors
13
+ imageio[ffmpeg]
14
+ imageio[pyav]
15
+ accelerate