Spaces:
Sleeping
Sleeping
wtx-mmlab
commited on
Commit
·
069c5f0
1
Parent(s):
661ec85
init commit
Browse files- animatediff/data/dataset.py +98 -0
- animatediff/models/attention.py +300 -0
- animatediff/models/motion_module.py +331 -0
- animatediff/models/resnet.py +217 -0
- animatediff/models/unet.py +497 -0
- animatediff/models/unet_blocks.py +760 -0
- animatediff/pipelines/pipeline_animation.py +656 -0
- animatediff/utils/convert_from_ckpt.py +959 -0
- animatediff/utils/convert_lora_safetensor_to_diffusers.py +154 -0
- animatediff/utils/freeinit_utils.py +140 -0
- animatediff/utils/util.py +157 -0
- app.py +488 -0
- configs/inference/inference-v1.yaml +26 -0
- configs/inference/inference-v2.yaml +27 -0
- configs/prompts/1-ToonYou.yaml +23 -0
- configs/prompts/2-Lyriel.yaml +23 -0
- configs/prompts/3-RcnzCartoon.yaml +23 -0
- configs/prompts/4-MajicMix.yaml +23 -0
- configs/prompts/5-RealisticVision.yaml +23 -0
- configs/prompts/6-Tusun.yaml +21 -0
- configs/prompts/7-FilmVelvia.yaml +24 -0
- configs/prompts/8-GhibliBackground.yaml +21 -0
- configs/prompts/freeinit_examples/RcnzCartoon_v2.yaml +33 -0
- configs/prompts/freeinit_examples/RealisticVision_v1.yaml +32 -0
- configs/prompts/freeinit_examples/RealisticVision_v2.yaml +37 -0
- configs/prompts/v2/5-RealisticVision-MotionLoRA.yaml +189 -0
- configs/prompts/v2/5-RealisticVision.yaml +23 -0
- models/DreamBooth_LoRA/Put personalized T2I checkpoints here.txt +0 -0
- models/MotionLoRA/Put MotionLoRA checkpoints here.txt +0 -0
- models/Motion_Module/Put motion module checkpoints here.txt +0 -0
- models/StableDiffusion/Put diffusers stable-diffusion-v1-5 repo here.txt +0 -0
- requirements.txt +15 -0
animatediff/data/dataset.py
ADDED
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os, io, csv, math, random
|
2 |
+
import numpy as np
|
3 |
+
from einops import rearrange
|
4 |
+
from decord import VideoReader
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torchvision.transforms as transforms
|
8 |
+
from torch.utils.data.dataset import Dataset
|
9 |
+
from animatediff.utils.util import zero_rank_print
|
10 |
+
|
11 |
+
|
12 |
+
|
13 |
+
class WebVid10M(Dataset):
|
14 |
+
def __init__(
|
15 |
+
self,
|
16 |
+
csv_path, video_folder,
|
17 |
+
sample_size=256, sample_stride=4, sample_n_frames=16,
|
18 |
+
is_image=False,
|
19 |
+
):
|
20 |
+
zero_rank_print(f"loading annotations from {csv_path} ...")
|
21 |
+
with open(csv_path, 'r') as csvfile:
|
22 |
+
self.dataset = list(csv.DictReader(csvfile))
|
23 |
+
self.length = len(self.dataset)
|
24 |
+
zero_rank_print(f"data scale: {self.length}")
|
25 |
+
|
26 |
+
self.video_folder = video_folder
|
27 |
+
self.sample_stride = sample_stride
|
28 |
+
self.sample_n_frames = sample_n_frames
|
29 |
+
self.is_image = is_image
|
30 |
+
|
31 |
+
sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size)
|
32 |
+
self.pixel_transforms = transforms.Compose([
|
33 |
+
transforms.RandomHorizontalFlip(),
|
34 |
+
transforms.Resize(sample_size[0]),
|
35 |
+
transforms.CenterCrop(sample_size),
|
36 |
+
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
|
37 |
+
])
|
38 |
+
|
39 |
+
def get_batch(self, idx):
|
40 |
+
video_dict = self.dataset[idx]
|
41 |
+
videoid, name, page_dir = video_dict['videoid'], video_dict['name'], video_dict['page_dir']
|
42 |
+
|
43 |
+
video_dir = os.path.join(self.video_folder, f"{videoid}.mp4")
|
44 |
+
video_reader = VideoReader(video_dir)
|
45 |
+
video_length = len(video_reader)
|
46 |
+
|
47 |
+
if not self.is_image:
|
48 |
+
clip_length = min(video_length, (self.sample_n_frames - 1) * self.sample_stride + 1)
|
49 |
+
start_idx = random.randint(0, video_length - clip_length)
|
50 |
+
batch_index = np.linspace(start_idx, start_idx + clip_length - 1, self.sample_n_frames, dtype=int)
|
51 |
+
else:
|
52 |
+
batch_index = [random.randint(0, video_length - 1)]
|
53 |
+
|
54 |
+
pixel_values = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous()
|
55 |
+
pixel_values = pixel_values / 255.
|
56 |
+
del video_reader
|
57 |
+
|
58 |
+
if self.is_image:
|
59 |
+
pixel_values = pixel_values[0]
|
60 |
+
|
61 |
+
return pixel_values, name
|
62 |
+
|
63 |
+
def __len__(self):
|
64 |
+
return self.length
|
65 |
+
|
66 |
+
def __getitem__(self, idx):
|
67 |
+
while True:
|
68 |
+
try:
|
69 |
+
pixel_values, name = self.get_batch(idx)
|
70 |
+
break
|
71 |
+
|
72 |
+
except Exception as e:
|
73 |
+
idx = random.randint(0, self.length-1)
|
74 |
+
|
75 |
+
pixel_values = self.pixel_transforms(pixel_values)
|
76 |
+
sample = dict(pixel_values=pixel_values, text=name)
|
77 |
+
return sample
|
78 |
+
|
79 |
+
|
80 |
+
|
81 |
+
if __name__ == "__main__":
|
82 |
+
from animatediff.utils.util import save_videos_grid
|
83 |
+
|
84 |
+
dataset = WebVid10M(
|
85 |
+
csv_path="/mnt/petrelfs/guoyuwei/projects/datasets/webvid/results_2M_val.csv",
|
86 |
+
video_folder="/mnt/petrelfs/guoyuwei/projects/datasets/webvid/2M_val",
|
87 |
+
sample_size=256,
|
88 |
+
sample_stride=4, sample_n_frames=16,
|
89 |
+
is_image=True,
|
90 |
+
)
|
91 |
+
import pdb
|
92 |
+
pdb.set_trace()
|
93 |
+
|
94 |
+
dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, num_workers=16,)
|
95 |
+
for idx, batch in enumerate(dataloader):
|
96 |
+
print(batch["pixel_values"].shape, len(batch["text"]))
|
97 |
+
# for i in range(batch["pixel_values"].shape[0]):
|
98 |
+
# save_videos_grid(batch["pixel_values"][i:i+1].permute(0,2,1,3,4), os.path.join(".", f"{idx}-{i}.mp4"), rescale=True)
|
animatediff/models/attention.py
ADDED
@@ -0,0 +1,300 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py
|
2 |
+
|
3 |
+
from dataclasses import dataclass
|
4 |
+
from typing import Optional
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn.functional as F
|
8 |
+
from torch import nn
|
9 |
+
|
10 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
11 |
+
from diffusers.modeling_utils import ModelMixin
|
12 |
+
from diffusers.utils import BaseOutput
|
13 |
+
from diffusers.utils.import_utils import is_xformers_available
|
14 |
+
from diffusers.models.attention import CrossAttention, FeedForward, AdaLayerNorm
|
15 |
+
|
16 |
+
from einops import rearrange, repeat
|
17 |
+
import pdb
|
18 |
+
|
19 |
+
@dataclass
|
20 |
+
class Transformer3DModelOutput(BaseOutput):
|
21 |
+
sample: torch.FloatTensor
|
22 |
+
|
23 |
+
|
24 |
+
if is_xformers_available():
|
25 |
+
import xformers
|
26 |
+
import xformers.ops
|
27 |
+
else:
|
28 |
+
xformers = None
|
29 |
+
|
30 |
+
|
31 |
+
class Transformer3DModel(ModelMixin, ConfigMixin):
|
32 |
+
@register_to_config
|
33 |
+
def __init__(
|
34 |
+
self,
|
35 |
+
num_attention_heads: int = 16,
|
36 |
+
attention_head_dim: int = 88,
|
37 |
+
in_channels: Optional[int] = None,
|
38 |
+
num_layers: int = 1,
|
39 |
+
dropout: float = 0.0,
|
40 |
+
norm_num_groups: int = 32,
|
41 |
+
cross_attention_dim: Optional[int] = None,
|
42 |
+
attention_bias: bool = False,
|
43 |
+
activation_fn: str = "geglu",
|
44 |
+
num_embeds_ada_norm: Optional[int] = None,
|
45 |
+
use_linear_projection: bool = False,
|
46 |
+
only_cross_attention: bool = False,
|
47 |
+
upcast_attention: bool = False,
|
48 |
+
|
49 |
+
unet_use_cross_frame_attention=None,
|
50 |
+
unet_use_temporal_attention=None,
|
51 |
+
):
|
52 |
+
super().__init__()
|
53 |
+
self.use_linear_projection = use_linear_projection
|
54 |
+
self.num_attention_heads = num_attention_heads
|
55 |
+
self.attention_head_dim = attention_head_dim
|
56 |
+
inner_dim = num_attention_heads * attention_head_dim
|
57 |
+
|
58 |
+
# Define input layers
|
59 |
+
self.in_channels = in_channels
|
60 |
+
|
61 |
+
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
62 |
+
if use_linear_projection:
|
63 |
+
self.proj_in = nn.Linear(in_channels, inner_dim)
|
64 |
+
else:
|
65 |
+
self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
|
66 |
+
|
67 |
+
# Define transformers blocks
|
68 |
+
self.transformer_blocks = nn.ModuleList(
|
69 |
+
[
|
70 |
+
BasicTransformerBlock(
|
71 |
+
inner_dim,
|
72 |
+
num_attention_heads,
|
73 |
+
attention_head_dim,
|
74 |
+
dropout=dropout,
|
75 |
+
cross_attention_dim=cross_attention_dim,
|
76 |
+
activation_fn=activation_fn,
|
77 |
+
num_embeds_ada_norm=num_embeds_ada_norm,
|
78 |
+
attention_bias=attention_bias,
|
79 |
+
only_cross_attention=only_cross_attention,
|
80 |
+
upcast_attention=upcast_attention,
|
81 |
+
|
82 |
+
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
|
83 |
+
unet_use_temporal_attention=unet_use_temporal_attention,
|
84 |
+
)
|
85 |
+
for d in range(num_layers)
|
86 |
+
]
|
87 |
+
)
|
88 |
+
|
89 |
+
# 4. Define output layers
|
90 |
+
if use_linear_projection:
|
91 |
+
self.proj_out = nn.Linear(in_channels, inner_dim)
|
92 |
+
else:
|
93 |
+
self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
|
94 |
+
|
95 |
+
def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True):
|
96 |
+
# Input
|
97 |
+
assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
|
98 |
+
video_length = hidden_states.shape[2]
|
99 |
+
hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
|
100 |
+
encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b f) n c', f=video_length)
|
101 |
+
|
102 |
+
batch, channel, height, weight = hidden_states.shape
|
103 |
+
residual = hidden_states
|
104 |
+
|
105 |
+
hidden_states = self.norm(hidden_states)
|
106 |
+
if not self.use_linear_projection:
|
107 |
+
hidden_states = self.proj_in(hidden_states)
|
108 |
+
inner_dim = hidden_states.shape[1]
|
109 |
+
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
|
110 |
+
else:
|
111 |
+
inner_dim = hidden_states.shape[1]
|
112 |
+
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
|
113 |
+
hidden_states = self.proj_in(hidden_states)
|
114 |
+
|
115 |
+
# Blocks
|
116 |
+
for block in self.transformer_blocks:
|
117 |
+
hidden_states = block(
|
118 |
+
hidden_states,
|
119 |
+
encoder_hidden_states=encoder_hidden_states,
|
120 |
+
timestep=timestep,
|
121 |
+
video_length=video_length
|
122 |
+
)
|
123 |
+
|
124 |
+
# Output
|
125 |
+
if not self.use_linear_projection:
|
126 |
+
hidden_states = (
|
127 |
+
hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
|
128 |
+
)
|
129 |
+
hidden_states = self.proj_out(hidden_states)
|
130 |
+
else:
|
131 |
+
hidden_states = self.proj_out(hidden_states)
|
132 |
+
hidden_states = (
|
133 |
+
hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
|
134 |
+
)
|
135 |
+
|
136 |
+
output = hidden_states + residual
|
137 |
+
|
138 |
+
output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
|
139 |
+
if not return_dict:
|
140 |
+
return (output,)
|
141 |
+
|
142 |
+
return Transformer3DModelOutput(sample=output)
|
143 |
+
|
144 |
+
|
145 |
+
class BasicTransformerBlock(nn.Module):
|
146 |
+
def __init__(
|
147 |
+
self,
|
148 |
+
dim: int,
|
149 |
+
num_attention_heads: int,
|
150 |
+
attention_head_dim: int,
|
151 |
+
dropout=0.0,
|
152 |
+
cross_attention_dim: Optional[int] = None,
|
153 |
+
activation_fn: str = "geglu",
|
154 |
+
num_embeds_ada_norm: Optional[int] = None,
|
155 |
+
attention_bias: bool = False,
|
156 |
+
only_cross_attention: bool = False,
|
157 |
+
upcast_attention: bool = False,
|
158 |
+
|
159 |
+
unet_use_cross_frame_attention = None,
|
160 |
+
unet_use_temporal_attention = None,
|
161 |
+
):
|
162 |
+
super().__init__()
|
163 |
+
self.only_cross_attention = only_cross_attention
|
164 |
+
self.use_ada_layer_norm = num_embeds_ada_norm is not None
|
165 |
+
self.unet_use_cross_frame_attention = unet_use_cross_frame_attention
|
166 |
+
self.unet_use_temporal_attention = unet_use_temporal_attention
|
167 |
+
|
168 |
+
# SC-Attn
|
169 |
+
assert unet_use_cross_frame_attention is not None
|
170 |
+
if unet_use_cross_frame_attention:
|
171 |
+
self.attn1 = SparseCausalAttention2D(
|
172 |
+
query_dim=dim,
|
173 |
+
heads=num_attention_heads,
|
174 |
+
dim_head=attention_head_dim,
|
175 |
+
dropout=dropout,
|
176 |
+
bias=attention_bias,
|
177 |
+
cross_attention_dim=cross_attention_dim if only_cross_attention else None,
|
178 |
+
upcast_attention=upcast_attention,
|
179 |
+
)
|
180 |
+
else:
|
181 |
+
self.attn1 = CrossAttention(
|
182 |
+
query_dim=dim,
|
183 |
+
heads=num_attention_heads,
|
184 |
+
dim_head=attention_head_dim,
|
185 |
+
dropout=dropout,
|
186 |
+
bias=attention_bias,
|
187 |
+
upcast_attention=upcast_attention,
|
188 |
+
)
|
189 |
+
self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
|
190 |
+
|
191 |
+
# Cross-Attn
|
192 |
+
if cross_attention_dim is not None:
|
193 |
+
self.attn2 = CrossAttention(
|
194 |
+
query_dim=dim,
|
195 |
+
cross_attention_dim=cross_attention_dim,
|
196 |
+
heads=num_attention_heads,
|
197 |
+
dim_head=attention_head_dim,
|
198 |
+
dropout=dropout,
|
199 |
+
bias=attention_bias,
|
200 |
+
upcast_attention=upcast_attention,
|
201 |
+
)
|
202 |
+
else:
|
203 |
+
self.attn2 = None
|
204 |
+
|
205 |
+
if cross_attention_dim is not None:
|
206 |
+
self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
|
207 |
+
else:
|
208 |
+
self.norm2 = None
|
209 |
+
|
210 |
+
# Feed-forward
|
211 |
+
self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
|
212 |
+
self.norm3 = nn.LayerNorm(dim)
|
213 |
+
|
214 |
+
# Temp-Attn
|
215 |
+
assert unet_use_temporal_attention is not None
|
216 |
+
if unet_use_temporal_attention:
|
217 |
+
self.attn_temp = CrossAttention(
|
218 |
+
query_dim=dim,
|
219 |
+
heads=num_attention_heads,
|
220 |
+
dim_head=attention_head_dim,
|
221 |
+
dropout=dropout,
|
222 |
+
bias=attention_bias,
|
223 |
+
upcast_attention=upcast_attention,
|
224 |
+
)
|
225 |
+
nn.init.zeros_(self.attn_temp.to_out[0].weight.data)
|
226 |
+
self.norm_temp = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
|
227 |
+
|
228 |
+
def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
|
229 |
+
if not is_xformers_available():
|
230 |
+
print("Here is how to install it")
|
231 |
+
raise ModuleNotFoundError(
|
232 |
+
"Refer to https://github.com/facebookresearch/xformers for more information on how to install"
|
233 |
+
" xformers",
|
234 |
+
name="xformers",
|
235 |
+
)
|
236 |
+
elif not torch.cuda.is_available():
|
237 |
+
raise ValueError(
|
238 |
+
"torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
|
239 |
+
" available for GPU "
|
240 |
+
)
|
241 |
+
else:
|
242 |
+
try:
|
243 |
+
# Make sure we can run the memory efficient attention
|
244 |
+
_ = xformers.ops.memory_efficient_attention(
|
245 |
+
torch.randn((1, 2, 40), device="cuda"),
|
246 |
+
torch.randn((1, 2, 40), device="cuda"),
|
247 |
+
torch.randn((1, 2, 40), device="cuda"),
|
248 |
+
)
|
249 |
+
except Exception as e:
|
250 |
+
raise e
|
251 |
+
self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
|
252 |
+
if self.attn2 is not None:
|
253 |
+
self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
|
254 |
+
# self.attn_temp._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
|
255 |
+
|
256 |
+
def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None, video_length=None):
|
257 |
+
# SparseCausal-Attention
|
258 |
+
norm_hidden_states = (
|
259 |
+
self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
|
260 |
+
)
|
261 |
+
|
262 |
+
# if self.only_cross_attention:
|
263 |
+
# hidden_states = (
|
264 |
+
# self.attn1(norm_hidden_states, encoder_hidden_states, attention_mask=attention_mask) + hidden_states
|
265 |
+
# )
|
266 |
+
# else:
|
267 |
+
# hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states
|
268 |
+
|
269 |
+
# pdb.set_trace()
|
270 |
+
if self.unet_use_cross_frame_attention:
|
271 |
+
hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states
|
272 |
+
else:
|
273 |
+
hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask) + hidden_states
|
274 |
+
|
275 |
+
if self.attn2 is not None:
|
276 |
+
# Cross-Attention
|
277 |
+
norm_hidden_states = (
|
278 |
+
self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
|
279 |
+
)
|
280 |
+
hidden_states = (
|
281 |
+
self.attn2(
|
282 |
+
norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
|
283 |
+
)
|
284 |
+
+ hidden_states
|
285 |
+
)
|
286 |
+
|
287 |
+
# Feed-forward
|
288 |
+
hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
|
289 |
+
|
290 |
+
# Temporal-Attention
|
291 |
+
if self.unet_use_temporal_attention:
|
292 |
+
d = hidden_states.shape[1]
|
293 |
+
hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)
|
294 |
+
norm_hidden_states = (
|
295 |
+
self.norm_temp(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_temp(hidden_states)
|
296 |
+
)
|
297 |
+
hidden_states = self.attn_temp(norm_hidden_states) + hidden_states
|
298 |
+
hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
|
299 |
+
|
300 |
+
return hidden_states
|
animatediff/models/motion_module.py
ADDED
@@ -0,0 +1,331 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
from typing import List, Optional, Tuple, Union
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import numpy as np
|
6 |
+
import torch.nn.functional as F
|
7 |
+
from torch import nn
|
8 |
+
import torchvision
|
9 |
+
|
10 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
11 |
+
from diffusers.modeling_utils import ModelMixin
|
12 |
+
from diffusers.utils import BaseOutput
|
13 |
+
from diffusers.utils.import_utils import is_xformers_available
|
14 |
+
from diffusers.models.attention import CrossAttention, FeedForward
|
15 |
+
|
16 |
+
from einops import rearrange, repeat
|
17 |
+
import math
|
18 |
+
|
19 |
+
|
20 |
+
def zero_module(module):
|
21 |
+
# Zero out the parameters of a module and return it.
|
22 |
+
for p in module.parameters():
|
23 |
+
p.detach().zero_()
|
24 |
+
return module
|
25 |
+
|
26 |
+
|
27 |
+
@dataclass
|
28 |
+
class TemporalTransformer3DModelOutput(BaseOutput):
|
29 |
+
sample: torch.FloatTensor
|
30 |
+
|
31 |
+
|
32 |
+
if is_xformers_available():
|
33 |
+
import xformers
|
34 |
+
import xformers.ops
|
35 |
+
else:
|
36 |
+
xformers = None
|
37 |
+
|
38 |
+
|
39 |
+
def get_motion_module(
|
40 |
+
in_channels,
|
41 |
+
motion_module_type: str,
|
42 |
+
motion_module_kwargs: dict
|
43 |
+
):
|
44 |
+
if motion_module_type == "Vanilla":
|
45 |
+
return VanillaTemporalModule(in_channels=in_channels, **motion_module_kwargs,)
|
46 |
+
else:
|
47 |
+
raise ValueError
|
48 |
+
|
49 |
+
|
50 |
+
class VanillaTemporalModule(nn.Module):
|
51 |
+
def __init__(
|
52 |
+
self,
|
53 |
+
in_channels,
|
54 |
+
num_attention_heads = 8,
|
55 |
+
num_transformer_block = 2,
|
56 |
+
attention_block_types =( "Temporal_Self", "Temporal_Self" ),
|
57 |
+
cross_frame_attention_mode = None,
|
58 |
+
temporal_position_encoding = False,
|
59 |
+
temporal_position_encoding_max_len = 24,
|
60 |
+
temporal_attention_dim_div = 1,
|
61 |
+
zero_initialize = True,
|
62 |
+
):
|
63 |
+
super().__init__()
|
64 |
+
|
65 |
+
self.temporal_transformer = TemporalTransformer3DModel(
|
66 |
+
in_channels=in_channels,
|
67 |
+
num_attention_heads=num_attention_heads,
|
68 |
+
attention_head_dim=in_channels // num_attention_heads // temporal_attention_dim_div,
|
69 |
+
num_layers=num_transformer_block,
|
70 |
+
attention_block_types=attention_block_types,
|
71 |
+
cross_frame_attention_mode=cross_frame_attention_mode,
|
72 |
+
temporal_position_encoding=temporal_position_encoding,
|
73 |
+
temporal_position_encoding_max_len=temporal_position_encoding_max_len,
|
74 |
+
)
|
75 |
+
|
76 |
+
if zero_initialize:
|
77 |
+
self.temporal_transformer.proj_out = zero_module(self.temporal_transformer.proj_out)
|
78 |
+
|
79 |
+
def forward(self, input_tensor, temb, encoder_hidden_states, attention_mask=None, anchor_frame_idx=None):
|
80 |
+
hidden_states = input_tensor
|
81 |
+
hidden_states = self.temporal_transformer(hidden_states, encoder_hidden_states, attention_mask)
|
82 |
+
|
83 |
+
output = hidden_states
|
84 |
+
return output
|
85 |
+
|
86 |
+
|
87 |
+
class TemporalTransformer3DModel(nn.Module):
|
88 |
+
def __init__(
|
89 |
+
self,
|
90 |
+
in_channels,
|
91 |
+
num_attention_heads,
|
92 |
+
attention_head_dim,
|
93 |
+
|
94 |
+
num_layers,
|
95 |
+
attention_block_types = ( "Temporal_Self", "Temporal_Self", ),
|
96 |
+
dropout = 0.0,
|
97 |
+
norm_num_groups = 32,
|
98 |
+
cross_attention_dim = 768,
|
99 |
+
activation_fn = "geglu",
|
100 |
+
attention_bias = False,
|
101 |
+
upcast_attention = False,
|
102 |
+
|
103 |
+
cross_frame_attention_mode = None,
|
104 |
+
temporal_position_encoding = False,
|
105 |
+
temporal_position_encoding_max_len = 24,
|
106 |
+
):
|
107 |
+
super().__init__()
|
108 |
+
|
109 |
+
inner_dim = num_attention_heads * attention_head_dim
|
110 |
+
|
111 |
+
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
112 |
+
self.proj_in = nn.Linear(in_channels, inner_dim)
|
113 |
+
|
114 |
+
self.transformer_blocks = nn.ModuleList(
|
115 |
+
[
|
116 |
+
TemporalTransformerBlock(
|
117 |
+
dim=inner_dim,
|
118 |
+
num_attention_heads=num_attention_heads,
|
119 |
+
attention_head_dim=attention_head_dim,
|
120 |
+
attention_block_types=attention_block_types,
|
121 |
+
dropout=dropout,
|
122 |
+
norm_num_groups=norm_num_groups,
|
123 |
+
cross_attention_dim=cross_attention_dim,
|
124 |
+
activation_fn=activation_fn,
|
125 |
+
attention_bias=attention_bias,
|
126 |
+
upcast_attention=upcast_attention,
|
127 |
+
cross_frame_attention_mode=cross_frame_attention_mode,
|
128 |
+
temporal_position_encoding=temporal_position_encoding,
|
129 |
+
temporal_position_encoding_max_len=temporal_position_encoding_max_len,
|
130 |
+
)
|
131 |
+
for d in range(num_layers)
|
132 |
+
]
|
133 |
+
)
|
134 |
+
self.proj_out = nn.Linear(inner_dim, in_channels)
|
135 |
+
|
136 |
+
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
|
137 |
+
assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
|
138 |
+
video_length = hidden_states.shape[2]
|
139 |
+
hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
|
140 |
+
|
141 |
+
batch, channel, height, weight = hidden_states.shape
|
142 |
+
residual = hidden_states
|
143 |
+
|
144 |
+
hidden_states = self.norm(hidden_states)
|
145 |
+
inner_dim = hidden_states.shape[1]
|
146 |
+
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
|
147 |
+
hidden_states = self.proj_in(hidden_states)
|
148 |
+
|
149 |
+
# Transformer Blocks
|
150 |
+
for block in self.transformer_blocks:
|
151 |
+
hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states, video_length=video_length)
|
152 |
+
|
153 |
+
# output
|
154 |
+
hidden_states = self.proj_out(hidden_states)
|
155 |
+
hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
|
156 |
+
|
157 |
+
output = hidden_states + residual
|
158 |
+
output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
|
159 |
+
|
160 |
+
return output
|
161 |
+
|
162 |
+
|
163 |
+
class TemporalTransformerBlock(nn.Module):
|
164 |
+
def __init__(
|
165 |
+
self,
|
166 |
+
dim,
|
167 |
+
num_attention_heads,
|
168 |
+
attention_head_dim,
|
169 |
+
attention_block_types = ( "Temporal_Self", "Temporal_Self", ),
|
170 |
+
dropout = 0.0,
|
171 |
+
norm_num_groups = 32,
|
172 |
+
cross_attention_dim = 768,
|
173 |
+
activation_fn = "geglu",
|
174 |
+
attention_bias = False,
|
175 |
+
upcast_attention = False,
|
176 |
+
cross_frame_attention_mode = None,
|
177 |
+
temporal_position_encoding = False,
|
178 |
+
temporal_position_encoding_max_len = 24,
|
179 |
+
):
|
180 |
+
super().__init__()
|
181 |
+
|
182 |
+
attention_blocks = []
|
183 |
+
norms = []
|
184 |
+
|
185 |
+
for block_name in attention_block_types:
|
186 |
+
attention_blocks.append(
|
187 |
+
VersatileAttention(
|
188 |
+
attention_mode=block_name.split("_")[0],
|
189 |
+
cross_attention_dim=cross_attention_dim if block_name.endswith("_Cross") else None,
|
190 |
+
|
191 |
+
query_dim=dim,
|
192 |
+
heads=num_attention_heads,
|
193 |
+
dim_head=attention_head_dim,
|
194 |
+
dropout=dropout,
|
195 |
+
bias=attention_bias,
|
196 |
+
upcast_attention=upcast_attention,
|
197 |
+
|
198 |
+
cross_frame_attention_mode=cross_frame_attention_mode,
|
199 |
+
temporal_position_encoding=temporal_position_encoding,
|
200 |
+
temporal_position_encoding_max_len=temporal_position_encoding_max_len,
|
201 |
+
)
|
202 |
+
)
|
203 |
+
norms.append(nn.LayerNorm(dim))
|
204 |
+
|
205 |
+
self.attention_blocks = nn.ModuleList(attention_blocks)
|
206 |
+
self.norms = nn.ModuleList(norms)
|
207 |
+
|
208 |
+
self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
|
209 |
+
self.ff_norm = nn.LayerNorm(dim)
|
210 |
+
|
211 |
+
|
212 |
+
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):
|
213 |
+
for attention_block, norm in zip(self.attention_blocks, self.norms):
|
214 |
+
norm_hidden_states = norm(hidden_states)
|
215 |
+
hidden_states = attention_block(
|
216 |
+
norm_hidden_states,
|
217 |
+
encoder_hidden_states=encoder_hidden_states if attention_block.is_cross_attention else None,
|
218 |
+
video_length=video_length,
|
219 |
+
) + hidden_states
|
220 |
+
|
221 |
+
hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states
|
222 |
+
|
223 |
+
output = hidden_states
|
224 |
+
return output
|
225 |
+
|
226 |
+
|
227 |
+
class PositionalEncoding(nn.Module):
|
228 |
+
def __init__(
|
229 |
+
self,
|
230 |
+
d_model,
|
231 |
+
dropout = 0.,
|
232 |
+
max_len = 24
|
233 |
+
):
|
234 |
+
super().__init__()
|
235 |
+
self.dropout = nn.Dropout(p=dropout)
|
236 |
+
position = torch.arange(max_len).unsqueeze(1)
|
237 |
+
div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
|
238 |
+
pe = torch.zeros(1, max_len, d_model)
|
239 |
+
pe[0, :, 0::2] = torch.sin(position * div_term)
|
240 |
+
pe[0, :, 1::2] = torch.cos(position * div_term)
|
241 |
+
self.register_buffer('pe', pe)
|
242 |
+
|
243 |
+
def forward(self, x):
|
244 |
+
x = x + self.pe[:, :x.size(1)]
|
245 |
+
return self.dropout(x)
|
246 |
+
|
247 |
+
|
248 |
+
class VersatileAttention(CrossAttention):
|
249 |
+
def __init__(
|
250 |
+
self,
|
251 |
+
attention_mode = None,
|
252 |
+
cross_frame_attention_mode = None,
|
253 |
+
temporal_position_encoding = False,
|
254 |
+
temporal_position_encoding_max_len = 24,
|
255 |
+
*args, **kwargs
|
256 |
+
):
|
257 |
+
super().__init__(*args, **kwargs)
|
258 |
+
assert attention_mode == "Temporal"
|
259 |
+
|
260 |
+
self.attention_mode = attention_mode
|
261 |
+
self.is_cross_attention = kwargs["cross_attention_dim"] is not None
|
262 |
+
|
263 |
+
self.pos_encoder = PositionalEncoding(
|
264 |
+
kwargs["query_dim"],
|
265 |
+
dropout=0.,
|
266 |
+
max_len=temporal_position_encoding_max_len
|
267 |
+
) if (temporal_position_encoding and attention_mode == "Temporal") else None
|
268 |
+
|
269 |
+
def extra_repr(self):
|
270 |
+
return f"(Module Info) Attention_Mode: {self.attention_mode}, Is_Cross_Attention: {self.is_cross_attention}"
|
271 |
+
|
272 |
+
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):
|
273 |
+
batch_size, sequence_length, _ = hidden_states.shape
|
274 |
+
|
275 |
+
if self.attention_mode == "Temporal":
|
276 |
+
d = hidden_states.shape[1]
|
277 |
+
hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)
|
278 |
+
|
279 |
+
if self.pos_encoder is not None:
|
280 |
+
hidden_states = self.pos_encoder(hidden_states)
|
281 |
+
|
282 |
+
encoder_hidden_states = repeat(encoder_hidden_states, "b n c -> (b d) n c", d=d) if encoder_hidden_states is not None else encoder_hidden_states
|
283 |
+
else:
|
284 |
+
raise NotImplementedError
|
285 |
+
|
286 |
+
encoder_hidden_states = encoder_hidden_states
|
287 |
+
|
288 |
+
if self.group_norm is not None:
|
289 |
+
hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
290 |
+
|
291 |
+
query = self.to_q(hidden_states)
|
292 |
+
dim = query.shape[-1]
|
293 |
+
query = self.reshape_heads_to_batch_dim(query)
|
294 |
+
|
295 |
+
if self.added_kv_proj_dim is not None:
|
296 |
+
raise NotImplementedError
|
297 |
+
|
298 |
+
encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
|
299 |
+
key = self.to_k(encoder_hidden_states)
|
300 |
+
value = self.to_v(encoder_hidden_states)
|
301 |
+
|
302 |
+
key = self.reshape_heads_to_batch_dim(key)
|
303 |
+
value = self.reshape_heads_to_batch_dim(value)
|
304 |
+
|
305 |
+
if attention_mask is not None:
|
306 |
+
if attention_mask.shape[-1] != query.shape[1]:
|
307 |
+
target_length = query.shape[1]
|
308 |
+
attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
|
309 |
+
attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
|
310 |
+
|
311 |
+
# attention, what we cannot get enough of
|
312 |
+
if self._use_memory_efficient_attention_xformers:
|
313 |
+
hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
|
314 |
+
# Some versions of xformers return output in fp32, cast it back to the dtype of the input
|
315 |
+
hidden_states = hidden_states.to(query.dtype)
|
316 |
+
else:
|
317 |
+
if self._slice_size is None or query.shape[0] // self._slice_size == 1:
|
318 |
+
hidden_states = self._attention(query, key, value, attention_mask)
|
319 |
+
else:
|
320 |
+
hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
|
321 |
+
|
322 |
+
# linear proj
|
323 |
+
hidden_states = self.to_out[0](hidden_states)
|
324 |
+
|
325 |
+
# dropout
|
326 |
+
hidden_states = self.to_out[1](hidden_states)
|
327 |
+
|
328 |
+
if self.attention_mode == "Temporal":
|
329 |
+
hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
|
330 |
+
|
331 |
+
return hidden_states
|
animatediff/models/resnet.py
ADDED
@@ -0,0 +1,217 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
|
7 |
+
from einops import rearrange
|
8 |
+
|
9 |
+
|
10 |
+
class InflatedConv3d(nn.Conv2d):
|
11 |
+
def forward(self, x):
|
12 |
+
video_length = x.shape[2]
|
13 |
+
|
14 |
+
x = rearrange(x, "b c f h w -> (b f) c h w")
|
15 |
+
x = super().forward(x)
|
16 |
+
x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
|
17 |
+
|
18 |
+
return x
|
19 |
+
|
20 |
+
|
21 |
+
class InflatedGroupNorm(nn.GroupNorm):
|
22 |
+
def forward(self, x):
|
23 |
+
video_length = x.shape[2]
|
24 |
+
|
25 |
+
x = rearrange(x, "b c f h w -> (b f) c h w")
|
26 |
+
x = super().forward(x)
|
27 |
+
x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
|
28 |
+
|
29 |
+
return x
|
30 |
+
|
31 |
+
|
32 |
+
class Upsample3D(nn.Module):
|
33 |
+
def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
|
34 |
+
super().__init__()
|
35 |
+
self.channels = channels
|
36 |
+
self.out_channels = out_channels or channels
|
37 |
+
self.use_conv = use_conv
|
38 |
+
self.use_conv_transpose = use_conv_transpose
|
39 |
+
self.name = name
|
40 |
+
|
41 |
+
conv = None
|
42 |
+
if use_conv_transpose:
|
43 |
+
raise NotImplementedError
|
44 |
+
elif use_conv:
|
45 |
+
self.conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1)
|
46 |
+
|
47 |
+
def forward(self, hidden_states, output_size=None):
|
48 |
+
assert hidden_states.shape[1] == self.channels
|
49 |
+
|
50 |
+
if self.use_conv_transpose:
|
51 |
+
raise NotImplementedError
|
52 |
+
|
53 |
+
# Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
|
54 |
+
dtype = hidden_states.dtype
|
55 |
+
if dtype == torch.bfloat16:
|
56 |
+
hidden_states = hidden_states.to(torch.float32)
|
57 |
+
|
58 |
+
# upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
|
59 |
+
if hidden_states.shape[0] >= 64:
|
60 |
+
hidden_states = hidden_states.contiguous()
|
61 |
+
|
62 |
+
# if `output_size` is passed we force the interpolation output
|
63 |
+
# size and do not make use of `scale_factor=2`
|
64 |
+
if output_size is None:
|
65 |
+
hidden_states = F.interpolate(hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest")
|
66 |
+
else:
|
67 |
+
hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
|
68 |
+
|
69 |
+
# If the input is bfloat16, we cast back to bfloat16
|
70 |
+
if dtype == torch.bfloat16:
|
71 |
+
hidden_states = hidden_states.to(dtype)
|
72 |
+
|
73 |
+
# if self.use_conv:
|
74 |
+
# if self.name == "conv":
|
75 |
+
# hidden_states = self.conv(hidden_states)
|
76 |
+
# else:
|
77 |
+
# hidden_states = self.Conv2d_0(hidden_states)
|
78 |
+
hidden_states = self.conv(hidden_states)
|
79 |
+
|
80 |
+
return hidden_states
|
81 |
+
|
82 |
+
|
83 |
+
class Downsample3D(nn.Module):
|
84 |
+
def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
|
85 |
+
super().__init__()
|
86 |
+
self.channels = channels
|
87 |
+
self.out_channels = out_channels or channels
|
88 |
+
self.use_conv = use_conv
|
89 |
+
self.padding = padding
|
90 |
+
stride = 2
|
91 |
+
self.name = name
|
92 |
+
|
93 |
+
if use_conv:
|
94 |
+
self.conv = InflatedConv3d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
|
95 |
+
else:
|
96 |
+
raise NotImplementedError
|
97 |
+
|
98 |
+
def forward(self, hidden_states):
|
99 |
+
assert hidden_states.shape[1] == self.channels
|
100 |
+
if self.use_conv and self.padding == 0:
|
101 |
+
raise NotImplementedError
|
102 |
+
|
103 |
+
assert hidden_states.shape[1] == self.channels
|
104 |
+
hidden_states = self.conv(hidden_states)
|
105 |
+
|
106 |
+
return hidden_states
|
107 |
+
|
108 |
+
|
109 |
+
class ResnetBlock3D(nn.Module):
|
110 |
+
def __init__(
|
111 |
+
self,
|
112 |
+
*,
|
113 |
+
in_channels,
|
114 |
+
out_channels=None,
|
115 |
+
conv_shortcut=False,
|
116 |
+
dropout=0.0,
|
117 |
+
temb_channels=512,
|
118 |
+
groups=32,
|
119 |
+
groups_out=None,
|
120 |
+
pre_norm=True,
|
121 |
+
eps=1e-6,
|
122 |
+
non_linearity="swish",
|
123 |
+
time_embedding_norm="default",
|
124 |
+
output_scale_factor=1.0,
|
125 |
+
use_in_shortcut=None,
|
126 |
+
use_inflated_groupnorm=None,
|
127 |
+
):
|
128 |
+
super().__init__()
|
129 |
+
self.pre_norm = pre_norm
|
130 |
+
self.pre_norm = True
|
131 |
+
self.in_channels = in_channels
|
132 |
+
out_channels = in_channels if out_channels is None else out_channels
|
133 |
+
self.out_channels = out_channels
|
134 |
+
self.use_conv_shortcut = conv_shortcut
|
135 |
+
self.time_embedding_norm = time_embedding_norm
|
136 |
+
self.output_scale_factor = output_scale_factor
|
137 |
+
|
138 |
+
if groups_out is None:
|
139 |
+
groups_out = groups
|
140 |
+
|
141 |
+
assert use_inflated_groupnorm != None
|
142 |
+
if use_inflated_groupnorm:
|
143 |
+
self.norm1 = InflatedGroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
|
144 |
+
else:
|
145 |
+
self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
|
146 |
+
|
147 |
+
self.conv1 = InflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
148 |
+
|
149 |
+
if temb_channels is not None:
|
150 |
+
if self.time_embedding_norm == "default":
|
151 |
+
time_emb_proj_out_channels = out_channels
|
152 |
+
elif self.time_embedding_norm == "scale_shift":
|
153 |
+
time_emb_proj_out_channels = out_channels * 2
|
154 |
+
else:
|
155 |
+
raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
|
156 |
+
|
157 |
+
self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels)
|
158 |
+
else:
|
159 |
+
self.time_emb_proj = None
|
160 |
+
|
161 |
+
if use_inflated_groupnorm:
|
162 |
+
self.norm2 = InflatedGroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
|
163 |
+
else:
|
164 |
+
self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
|
165 |
+
|
166 |
+
self.dropout = torch.nn.Dropout(dropout)
|
167 |
+
self.conv2 = InflatedConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
168 |
+
|
169 |
+
if non_linearity == "swish":
|
170 |
+
self.nonlinearity = lambda x: F.silu(x)
|
171 |
+
elif non_linearity == "mish":
|
172 |
+
self.nonlinearity = Mish()
|
173 |
+
elif non_linearity == "silu":
|
174 |
+
self.nonlinearity = nn.SiLU()
|
175 |
+
|
176 |
+
self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut
|
177 |
+
|
178 |
+
self.conv_shortcut = None
|
179 |
+
if self.use_in_shortcut:
|
180 |
+
self.conv_shortcut = InflatedConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
181 |
+
|
182 |
+
def forward(self, input_tensor, temb):
|
183 |
+
hidden_states = input_tensor
|
184 |
+
|
185 |
+
hidden_states = self.norm1(hidden_states)
|
186 |
+
hidden_states = self.nonlinearity(hidden_states)
|
187 |
+
|
188 |
+
hidden_states = self.conv1(hidden_states)
|
189 |
+
|
190 |
+
if temb is not None:
|
191 |
+
temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None]
|
192 |
+
|
193 |
+
if temb is not None and self.time_embedding_norm == "default":
|
194 |
+
hidden_states = hidden_states + temb
|
195 |
+
|
196 |
+
hidden_states = self.norm2(hidden_states)
|
197 |
+
|
198 |
+
if temb is not None and self.time_embedding_norm == "scale_shift":
|
199 |
+
scale, shift = torch.chunk(temb, 2, dim=1)
|
200 |
+
hidden_states = hidden_states * (1 + scale) + shift
|
201 |
+
|
202 |
+
hidden_states = self.nonlinearity(hidden_states)
|
203 |
+
|
204 |
+
hidden_states = self.dropout(hidden_states)
|
205 |
+
hidden_states = self.conv2(hidden_states)
|
206 |
+
|
207 |
+
if self.conv_shortcut is not None:
|
208 |
+
input_tensor = self.conv_shortcut(input_tensor)
|
209 |
+
|
210 |
+
output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
|
211 |
+
|
212 |
+
return output_tensor
|
213 |
+
|
214 |
+
|
215 |
+
class Mish(torch.nn.Module):
|
216 |
+
def forward(self, hidden_states):
|
217 |
+
return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))
|
animatediff/models/unet.py
ADDED
@@ -0,0 +1,497 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py
|
2 |
+
|
3 |
+
from dataclasses import dataclass
|
4 |
+
from typing import List, Optional, Tuple, Union
|
5 |
+
|
6 |
+
import os
|
7 |
+
import json
|
8 |
+
import pdb
|
9 |
+
|
10 |
+
import torch
|
11 |
+
import torch.nn as nn
|
12 |
+
import torch.utils.checkpoint
|
13 |
+
|
14 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
15 |
+
from diffusers.modeling_utils import ModelMixin
|
16 |
+
from diffusers.utils import BaseOutput, logging
|
17 |
+
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
|
18 |
+
from .unet_blocks import (
|
19 |
+
CrossAttnDownBlock3D,
|
20 |
+
CrossAttnUpBlock3D,
|
21 |
+
DownBlock3D,
|
22 |
+
UNetMidBlock3DCrossAttn,
|
23 |
+
UpBlock3D,
|
24 |
+
get_down_block,
|
25 |
+
get_up_block,
|
26 |
+
)
|
27 |
+
from .resnet import InflatedConv3d, InflatedGroupNorm
|
28 |
+
|
29 |
+
|
30 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
31 |
+
|
32 |
+
|
33 |
+
@dataclass
|
34 |
+
class UNet3DConditionOutput(BaseOutput):
|
35 |
+
sample: torch.FloatTensor
|
36 |
+
|
37 |
+
|
38 |
+
class UNet3DConditionModel(ModelMixin, ConfigMixin):
|
39 |
+
_supports_gradient_checkpointing = True
|
40 |
+
|
41 |
+
@register_to_config
|
42 |
+
def __init__(
|
43 |
+
self,
|
44 |
+
sample_size: Optional[int] = None,
|
45 |
+
in_channels: int = 4,
|
46 |
+
out_channels: int = 4,
|
47 |
+
center_input_sample: bool = False,
|
48 |
+
flip_sin_to_cos: bool = True,
|
49 |
+
freq_shift: int = 0,
|
50 |
+
down_block_types: Tuple[str] = (
|
51 |
+
"CrossAttnDownBlock3D",
|
52 |
+
"CrossAttnDownBlock3D",
|
53 |
+
"CrossAttnDownBlock3D",
|
54 |
+
"DownBlock3D",
|
55 |
+
),
|
56 |
+
mid_block_type: str = "UNetMidBlock3DCrossAttn",
|
57 |
+
up_block_types: Tuple[str] = (
|
58 |
+
"UpBlock3D",
|
59 |
+
"CrossAttnUpBlock3D",
|
60 |
+
"CrossAttnUpBlock3D",
|
61 |
+
"CrossAttnUpBlock3D"
|
62 |
+
),
|
63 |
+
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
64 |
+
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
|
65 |
+
layers_per_block: int = 2,
|
66 |
+
downsample_padding: int = 1,
|
67 |
+
mid_block_scale_factor: float = 1,
|
68 |
+
act_fn: str = "silu",
|
69 |
+
norm_num_groups: int = 32,
|
70 |
+
norm_eps: float = 1e-5,
|
71 |
+
cross_attention_dim: int = 1280,
|
72 |
+
attention_head_dim: Union[int, Tuple[int]] = 8,
|
73 |
+
dual_cross_attention: bool = False,
|
74 |
+
use_linear_projection: bool = False,
|
75 |
+
class_embed_type: Optional[str] = None,
|
76 |
+
num_class_embeds: Optional[int] = None,
|
77 |
+
upcast_attention: bool = False,
|
78 |
+
resnet_time_scale_shift: str = "default",
|
79 |
+
|
80 |
+
use_inflated_groupnorm=False,
|
81 |
+
|
82 |
+
# Additional
|
83 |
+
use_motion_module = False,
|
84 |
+
motion_module_resolutions = ( 1,2,4,8 ),
|
85 |
+
motion_module_mid_block = False,
|
86 |
+
motion_module_decoder_only = False,
|
87 |
+
motion_module_type = None,
|
88 |
+
motion_module_kwargs = {},
|
89 |
+
unet_use_cross_frame_attention = None,
|
90 |
+
unet_use_temporal_attention = None,
|
91 |
+
):
|
92 |
+
super().__init__()
|
93 |
+
|
94 |
+
self.sample_size = sample_size
|
95 |
+
time_embed_dim = block_out_channels[0] * 4
|
96 |
+
|
97 |
+
# input
|
98 |
+
self.conv_in = InflatedConv3d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
|
99 |
+
|
100 |
+
# time
|
101 |
+
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
|
102 |
+
timestep_input_dim = block_out_channels[0]
|
103 |
+
|
104 |
+
self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
|
105 |
+
|
106 |
+
# class embedding
|
107 |
+
if class_embed_type is None and num_class_embeds is not None:
|
108 |
+
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
|
109 |
+
elif class_embed_type == "timestep":
|
110 |
+
self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
|
111 |
+
elif class_embed_type == "identity":
|
112 |
+
self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
|
113 |
+
else:
|
114 |
+
self.class_embedding = None
|
115 |
+
|
116 |
+
self.down_blocks = nn.ModuleList([])
|
117 |
+
self.mid_block = None
|
118 |
+
self.up_blocks = nn.ModuleList([])
|
119 |
+
|
120 |
+
if isinstance(only_cross_attention, bool):
|
121 |
+
only_cross_attention = [only_cross_attention] * len(down_block_types)
|
122 |
+
|
123 |
+
if isinstance(attention_head_dim, int):
|
124 |
+
attention_head_dim = (attention_head_dim,) * len(down_block_types)
|
125 |
+
|
126 |
+
# down
|
127 |
+
output_channel = block_out_channels[0]
|
128 |
+
for i, down_block_type in enumerate(down_block_types):
|
129 |
+
res = 2 ** i
|
130 |
+
input_channel = output_channel
|
131 |
+
output_channel = block_out_channels[i]
|
132 |
+
is_final_block = i == len(block_out_channels) - 1
|
133 |
+
|
134 |
+
down_block = get_down_block(
|
135 |
+
down_block_type,
|
136 |
+
num_layers=layers_per_block,
|
137 |
+
in_channels=input_channel,
|
138 |
+
out_channels=output_channel,
|
139 |
+
temb_channels=time_embed_dim,
|
140 |
+
add_downsample=not is_final_block,
|
141 |
+
resnet_eps=norm_eps,
|
142 |
+
resnet_act_fn=act_fn,
|
143 |
+
resnet_groups=norm_num_groups,
|
144 |
+
cross_attention_dim=cross_attention_dim,
|
145 |
+
attn_num_head_channels=attention_head_dim[i],
|
146 |
+
downsample_padding=downsample_padding,
|
147 |
+
dual_cross_attention=dual_cross_attention,
|
148 |
+
use_linear_projection=use_linear_projection,
|
149 |
+
only_cross_attention=only_cross_attention[i],
|
150 |
+
upcast_attention=upcast_attention,
|
151 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
152 |
+
|
153 |
+
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
|
154 |
+
unet_use_temporal_attention=unet_use_temporal_attention,
|
155 |
+
use_inflated_groupnorm=use_inflated_groupnorm,
|
156 |
+
|
157 |
+
use_motion_module=use_motion_module and (res in motion_module_resolutions) and (not motion_module_decoder_only),
|
158 |
+
motion_module_type=motion_module_type,
|
159 |
+
motion_module_kwargs=motion_module_kwargs,
|
160 |
+
)
|
161 |
+
self.down_blocks.append(down_block)
|
162 |
+
|
163 |
+
# mid
|
164 |
+
if mid_block_type == "UNetMidBlock3DCrossAttn":
|
165 |
+
self.mid_block = UNetMidBlock3DCrossAttn(
|
166 |
+
in_channels=block_out_channels[-1],
|
167 |
+
temb_channels=time_embed_dim,
|
168 |
+
resnet_eps=norm_eps,
|
169 |
+
resnet_act_fn=act_fn,
|
170 |
+
output_scale_factor=mid_block_scale_factor,
|
171 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
172 |
+
cross_attention_dim=cross_attention_dim,
|
173 |
+
attn_num_head_channels=attention_head_dim[-1],
|
174 |
+
resnet_groups=norm_num_groups,
|
175 |
+
dual_cross_attention=dual_cross_attention,
|
176 |
+
use_linear_projection=use_linear_projection,
|
177 |
+
upcast_attention=upcast_attention,
|
178 |
+
|
179 |
+
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
|
180 |
+
unet_use_temporal_attention=unet_use_temporal_attention,
|
181 |
+
use_inflated_groupnorm=use_inflated_groupnorm,
|
182 |
+
|
183 |
+
use_motion_module=use_motion_module and motion_module_mid_block,
|
184 |
+
motion_module_type=motion_module_type,
|
185 |
+
motion_module_kwargs=motion_module_kwargs,
|
186 |
+
)
|
187 |
+
else:
|
188 |
+
raise ValueError(f"unknown mid_block_type : {mid_block_type}")
|
189 |
+
|
190 |
+
# count how many layers upsample the videos
|
191 |
+
self.num_upsamplers = 0
|
192 |
+
|
193 |
+
# up
|
194 |
+
reversed_block_out_channels = list(reversed(block_out_channels))
|
195 |
+
reversed_attention_head_dim = list(reversed(attention_head_dim))
|
196 |
+
only_cross_attention = list(reversed(only_cross_attention))
|
197 |
+
output_channel = reversed_block_out_channels[0]
|
198 |
+
for i, up_block_type in enumerate(up_block_types):
|
199 |
+
res = 2 ** (3 - i)
|
200 |
+
is_final_block = i == len(block_out_channels) - 1
|
201 |
+
|
202 |
+
prev_output_channel = output_channel
|
203 |
+
output_channel = reversed_block_out_channels[i]
|
204 |
+
input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
|
205 |
+
|
206 |
+
# add upsample block for all BUT final layer
|
207 |
+
if not is_final_block:
|
208 |
+
add_upsample = True
|
209 |
+
self.num_upsamplers += 1
|
210 |
+
else:
|
211 |
+
add_upsample = False
|
212 |
+
|
213 |
+
up_block = get_up_block(
|
214 |
+
up_block_type,
|
215 |
+
num_layers=layers_per_block + 1,
|
216 |
+
in_channels=input_channel,
|
217 |
+
out_channels=output_channel,
|
218 |
+
prev_output_channel=prev_output_channel,
|
219 |
+
temb_channels=time_embed_dim,
|
220 |
+
add_upsample=add_upsample,
|
221 |
+
resnet_eps=norm_eps,
|
222 |
+
resnet_act_fn=act_fn,
|
223 |
+
resnet_groups=norm_num_groups,
|
224 |
+
cross_attention_dim=cross_attention_dim,
|
225 |
+
attn_num_head_channels=reversed_attention_head_dim[i],
|
226 |
+
dual_cross_attention=dual_cross_attention,
|
227 |
+
use_linear_projection=use_linear_projection,
|
228 |
+
only_cross_attention=only_cross_attention[i],
|
229 |
+
upcast_attention=upcast_attention,
|
230 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
231 |
+
|
232 |
+
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
|
233 |
+
unet_use_temporal_attention=unet_use_temporal_attention,
|
234 |
+
use_inflated_groupnorm=use_inflated_groupnorm,
|
235 |
+
|
236 |
+
use_motion_module=use_motion_module and (res in motion_module_resolutions),
|
237 |
+
motion_module_type=motion_module_type,
|
238 |
+
motion_module_kwargs=motion_module_kwargs,
|
239 |
+
)
|
240 |
+
self.up_blocks.append(up_block)
|
241 |
+
prev_output_channel = output_channel
|
242 |
+
|
243 |
+
# out
|
244 |
+
if use_inflated_groupnorm:
|
245 |
+
self.conv_norm_out = InflatedGroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
|
246 |
+
else:
|
247 |
+
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
|
248 |
+
self.conv_act = nn.SiLU()
|
249 |
+
self.conv_out = InflatedConv3d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
|
250 |
+
|
251 |
+
def set_attention_slice(self, slice_size):
|
252 |
+
r"""
|
253 |
+
Enable sliced attention computation.
|
254 |
+
|
255 |
+
When this option is enabled, the attention module will split the input tensor in slices, to compute attention
|
256 |
+
in several steps. This is useful to save some memory in exchange for a small speed decrease.
|
257 |
+
|
258 |
+
Args:
|
259 |
+
slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
|
260 |
+
When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
|
261 |
+
`"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is
|
262 |
+
provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
|
263 |
+
must be a multiple of `slice_size`.
|
264 |
+
"""
|
265 |
+
sliceable_head_dims = []
|
266 |
+
|
267 |
+
def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module):
|
268 |
+
if hasattr(module, "set_attention_slice"):
|
269 |
+
sliceable_head_dims.append(module.sliceable_head_dim)
|
270 |
+
|
271 |
+
for child in module.children():
|
272 |
+
fn_recursive_retrieve_slicable_dims(child)
|
273 |
+
|
274 |
+
# retrieve number of attention layers
|
275 |
+
for module in self.children():
|
276 |
+
fn_recursive_retrieve_slicable_dims(module)
|
277 |
+
|
278 |
+
num_slicable_layers = len(sliceable_head_dims)
|
279 |
+
|
280 |
+
if slice_size == "auto":
|
281 |
+
# half the attention head size is usually a good trade-off between
|
282 |
+
# speed and memory
|
283 |
+
slice_size = [dim // 2 for dim in sliceable_head_dims]
|
284 |
+
elif slice_size == "max":
|
285 |
+
# make smallest slice possible
|
286 |
+
slice_size = num_slicable_layers * [1]
|
287 |
+
|
288 |
+
slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
|
289 |
+
|
290 |
+
if len(slice_size) != len(sliceable_head_dims):
|
291 |
+
raise ValueError(
|
292 |
+
f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
|
293 |
+
f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
|
294 |
+
)
|
295 |
+
|
296 |
+
for i in range(len(slice_size)):
|
297 |
+
size = slice_size[i]
|
298 |
+
dim = sliceable_head_dims[i]
|
299 |
+
if size is not None and size > dim:
|
300 |
+
raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
|
301 |
+
|
302 |
+
# Recursively walk through all the children.
|
303 |
+
# Any children which exposes the set_attention_slice method
|
304 |
+
# gets the message
|
305 |
+
def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
|
306 |
+
if hasattr(module, "set_attention_slice"):
|
307 |
+
module.set_attention_slice(slice_size.pop())
|
308 |
+
|
309 |
+
for child in module.children():
|
310 |
+
fn_recursive_set_attention_slice(child, slice_size)
|
311 |
+
|
312 |
+
reversed_slice_size = list(reversed(slice_size))
|
313 |
+
for module in self.children():
|
314 |
+
fn_recursive_set_attention_slice(module, reversed_slice_size)
|
315 |
+
|
316 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
317 |
+
if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
|
318 |
+
module.gradient_checkpointing = value
|
319 |
+
|
320 |
+
def forward(
|
321 |
+
self,
|
322 |
+
sample: torch.FloatTensor,
|
323 |
+
timestep: Union[torch.Tensor, float, int],
|
324 |
+
encoder_hidden_states: torch.Tensor,
|
325 |
+
class_labels: Optional[torch.Tensor] = None,
|
326 |
+
attention_mask: Optional[torch.Tensor] = None,
|
327 |
+
return_dict: bool = True,
|
328 |
+
) -> Union[UNet3DConditionOutput, Tuple]:
|
329 |
+
r"""
|
330 |
+
Args:
|
331 |
+
sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
|
332 |
+
timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
|
333 |
+
encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
|
334 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
335 |
+
Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
|
336 |
+
|
337 |
+
Returns:
|
338 |
+
[`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
|
339 |
+
[`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
|
340 |
+
returning a tuple, the first element is the sample tensor.
|
341 |
+
"""
|
342 |
+
# By default samples have to be AT least a multiple of the overall upsampling factor.
|
343 |
+
# The overall upsampling factor is equal to 2 ** (# num of upsampling layears).
|
344 |
+
# However, the upsampling interpolation output size can be forced to fit any upsampling size
|
345 |
+
# on the fly if necessary.
|
346 |
+
default_overall_up_factor = 2**self.num_upsamplers
|
347 |
+
|
348 |
+
# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
|
349 |
+
forward_upsample_size = False
|
350 |
+
upsample_size = None
|
351 |
+
|
352 |
+
if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
|
353 |
+
logger.info("Forward upsample size to force interpolation output size.")
|
354 |
+
forward_upsample_size = True
|
355 |
+
|
356 |
+
# prepare attention_mask
|
357 |
+
if attention_mask is not None:
|
358 |
+
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
|
359 |
+
attention_mask = attention_mask.unsqueeze(1)
|
360 |
+
|
361 |
+
# center input if necessary
|
362 |
+
if self.config.center_input_sample:
|
363 |
+
sample = 2 * sample - 1.0
|
364 |
+
|
365 |
+
# time
|
366 |
+
timesteps = timestep
|
367 |
+
if not torch.is_tensor(timesteps):
|
368 |
+
# This would be a good case for the `match` statement (Python 3.10+)
|
369 |
+
is_mps = sample.device.type == "mps"
|
370 |
+
if isinstance(timestep, float):
|
371 |
+
dtype = torch.float32 if is_mps else torch.float64
|
372 |
+
else:
|
373 |
+
dtype = torch.int32 if is_mps else torch.int64
|
374 |
+
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
|
375 |
+
elif len(timesteps.shape) == 0:
|
376 |
+
timesteps = timesteps[None].to(sample.device)
|
377 |
+
|
378 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
379 |
+
timesteps = timesteps.expand(sample.shape[0])
|
380 |
+
|
381 |
+
t_emb = self.time_proj(timesteps)
|
382 |
+
|
383 |
+
# timesteps does not contain any weights and will always return f32 tensors
|
384 |
+
# but time_embedding might actually be running in fp16. so we need to cast here.
|
385 |
+
# there might be better ways to encapsulate this.
|
386 |
+
t_emb = t_emb.to(dtype=self.dtype)
|
387 |
+
emb = self.time_embedding(t_emb)
|
388 |
+
|
389 |
+
if self.class_embedding is not None:
|
390 |
+
if class_labels is None:
|
391 |
+
raise ValueError("class_labels should be provided when num_class_embeds > 0")
|
392 |
+
|
393 |
+
if self.config.class_embed_type == "timestep":
|
394 |
+
class_labels = self.time_proj(class_labels)
|
395 |
+
|
396 |
+
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
|
397 |
+
emb = emb + class_emb
|
398 |
+
|
399 |
+
# pre-process
|
400 |
+
sample = self.conv_in(sample)
|
401 |
+
|
402 |
+
# down
|
403 |
+
down_block_res_samples = (sample,)
|
404 |
+
for downsample_block in self.down_blocks:
|
405 |
+
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
|
406 |
+
sample, res_samples = downsample_block(
|
407 |
+
hidden_states=sample,
|
408 |
+
temb=emb,
|
409 |
+
encoder_hidden_states=encoder_hidden_states,
|
410 |
+
attention_mask=attention_mask,
|
411 |
+
)
|
412 |
+
else:
|
413 |
+
sample, res_samples = downsample_block(hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states)
|
414 |
+
|
415 |
+
down_block_res_samples += res_samples
|
416 |
+
|
417 |
+
# mid
|
418 |
+
sample = self.mid_block(
|
419 |
+
sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
|
420 |
+
)
|
421 |
+
|
422 |
+
# up
|
423 |
+
for i, upsample_block in enumerate(self.up_blocks):
|
424 |
+
is_final_block = i == len(self.up_blocks) - 1
|
425 |
+
|
426 |
+
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
|
427 |
+
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
|
428 |
+
|
429 |
+
# if we have not reached the final block and need to forward the
|
430 |
+
# upsample size, we do it here
|
431 |
+
if not is_final_block and forward_upsample_size:
|
432 |
+
upsample_size = down_block_res_samples[-1].shape[2:]
|
433 |
+
|
434 |
+
if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
|
435 |
+
sample = upsample_block(
|
436 |
+
hidden_states=sample,
|
437 |
+
temb=emb,
|
438 |
+
res_hidden_states_tuple=res_samples,
|
439 |
+
encoder_hidden_states=encoder_hidden_states,
|
440 |
+
upsample_size=upsample_size,
|
441 |
+
attention_mask=attention_mask,
|
442 |
+
)
|
443 |
+
else:
|
444 |
+
sample = upsample_block(
|
445 |
+
hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size, encoder_hidden_states=encoder_hidden_states,
|
446 |
+
)
|
447 |
+
|
448 |
+
# post-process
|
449 |
+
sample = self.conv_norm_out(sample)
|
450 |
+
sample = self.conv_act(sample)
|
451 |
+
sample = self.conv_out(sample)
|
452 |
+
|
453 |
+
if not return_dict:
|
454 |
+
return (sample,)
|
455 |
+
|
456 |
+
return UNet3DConditionOutput(sample=sample)
|
457 |
+
|
458 |
+
@classmethod
|
459 |
+
def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, unet_additional_kwargs=None):
|
460 |
+
if subfolder is not None:
|
461 |
+
pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
|
462 |
+
print(f"loaded temporal unet's pretrained weights from {pretrained_model_path} ...")
|
463 |
+
|
464 |
+
config_file = os.path.join(pretrained_model_path, 'config.json')
|
465 |
+
if not os.path.isfile(config_file):
|
466 |
+
raise RuntimeError(f"{config_file} does not exist")
|
467 |
+
with open(config_file, "r") as f:
|
468 |
+
config = json.load(f)
|
469 |
+
config["_class_name"] = cls.__name__
|
470 |
+
config["down_block_types"] = [
|
471 |
+
"CrossAttnDownBlock3D",
|
472 |
+
"CrossAttnDownBlock3D",
|
473 |
+
"CrossAttnDownBlock3D",
|
474 |
+
"DownBlock3D"
|
475 |
+
]
|
476 |
+
config["up_block_types"] = [
|
477 |
+
"UpBlock3D",
|
478 |
+
"CrossAttnUpBlock3D",
|
479 |
+
"CrossAttnUpBlock3D",
|
480 |
+
"CrossAttnUpBlock3D"
|
481 |
+
]
|
482 |
+
|
483 |
+
from diffusers.utils import WEIGHTS_NAME
|
484 |
+
model = cls.from_config(config, **unet_additional_kwargs)
|
485 |
+
model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
|
486 |
+
if not os.path.isfile(model_file):
|
487 |
+
raise RuntimeError(f"{model_file} does not exist")
|
488 |
+
state_dict = torch.load(model_file, map_location="cpu")
|
489 |
+
|
490 |
+
m, u = model.load_state_dict(state_dict, strict=False)
|
491 |
+
print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
|
492 |
+
# print(f"### missing keys:\n{m}\n### unexpected keys:\n{u}\n")
|
493 |
+
|
494 |
+
params = [p.numel() if "temporal" in n else 0 for n, p in model.named_parameters()]
|
495 |
+
print(f"### Temporal Module Parameters: {sum(params) / 1e6} M")
|
496 |
+
|
497 |
+
return model
|
animatediff/models/unet_blocks.py
ADDED
@@ -0,0 +1,760 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_blocks.py
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from torch import nn
|
5 |
+
|
6 |
+
from .attention import Transformer3DModel
|
7 |
+
from .resnet import Downsample3D, ResnetBlock3D, Upsample3D
|
8 |
+
from .motion_module import get_motion_module
|
9 |
+
|
10 |
+
import pdb
|
11 |
+
|
12 |
+
def get_down_block(
|
13 |
+
down_block_type,
|
14 |
+
num_layers,
|
15 |
+
in_channels,
|
16 |
+
out_channels,
|
17 |
+
temb_channels,
|
18 |
+
add_downsample,
|
19 |
+
resnet_eps,
|
20 |
+
resnet_act_fn,
|
21 |
+
attn_num_head_channels,
|
22 |
+
resnet_groups=None,
|
23 |
+
cross_attention_dim=None,
|
24 |
+
downsample_padding=None,
|
25 |
+
dual_cross_attention=False,
|
26 |
+
use_linear_projection=False,
|
27 |
+
only_cross_attention=False,
|
28 |
+
upcast_attention=False,
|
29 |
+
resnet_time_scale_shift="default",
|
30 |
+
|
31 |
+
unet_use_cross_frame_attention=None,
|
32 |
+
unet_use_temporal_attention=None,
|
33 |
+
use_inflated_groupnorm=None,
|
34 |
+
|
35 |
+
use_motion_module=None,
|
36 |
+
|
37 |
+
motion_module_type=None,
|
38 |
+
motion_module_kwargs=None,
|
39 |
+
):
|
40 |
+
down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
|
41 |
+
if down_block_type == "DownBlock3D":
|
42 |
+
return DownBlock3D(
|
43 |
+
num_layers=num_layers,
|
44 |
+
in_channels=in_channels,
|
45 |
+
out_channels=out_channels,
|
46 |
+
temb_channels=temb_channels,
|
47 |
+
add_downsample=add_downsample,
|
48 |
+
resnet_eps=resnet_eps,
|
49 |
+
resnet_act_fn=resnet_act_fn,
|
50 |
+
resnet_groups=resnet_groups,
|
51 |
+
downsample_padding=downsample_padding,
|
52 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
53 |
+
|
54 |
+
use_inflated_groupnorm=use_inflated_groupnorm,
|
55 |
+
|
56 |
+
use_motion_module=use_motion_module,
|
57 |
+
motion_module_type=motion_module_type,
|
58 |
+
motion_module_kwargs=motion_module_kwargs,
|
59 |
+
)
|
60 |
+
elif down_block_type == "CrossAttnDownBlock3D":
|
61 |
+
if cross_attention_dim is None:
|
62 |
+
raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D")
|
63 |
+
return CrossAttnDownBlock3D(
|
64 |
+
num_layers=num_layers,
|
65 |
+
in_channels=in_channels,
|
66 |
+
out_channels=out_channels,
|
67 |
+
temb_channels=temb_channels,
|
68 |
+
add_downsample=add_downsample,
|
69 |
+
resnet_eps=resnet_eps,
|
70 |
+
resnet_act_fn=resnet_act_fn,
|
71 |
+
resnet_groups=resnet_groups,
|
72 |
+
downsample_padding=downsample_padding,
|
73 |
+
cross_attention_dim=cross_attention_dim,
|
74 |
+
attn_num_head_channels=attn_num_head_channels,
|
75 |
+
dual_cross_attention=dual_cross_attention,
|
76 |
+
use_linear_projection=use_linear_projection,
|
77 |
+
only_cross_attention=only_cross_attention,
|
78 |
+
upcast_attention=upcast_attention,
|
79 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
80 |
+
|
81 |
+
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
|
82 |
+
unet_use_temporal_attention=unet_use_temporal_attention,
|
83 |
+
use_inflated_groupnorm=use_inflated_groupnorm,
|
84 |
+
|
85 |
+
use_motion_module=use_motion_module,
|
86 |
+
motion_module_type=motion_module_type,
|
87 |
+
motion_module_kwargs=motion_module_kwargs,
|
88 |
+
)
|
89 |
+
raise ValueError(f"{down_block_type} does not exist.")
|
90 |
+
|
91 |
+
|
92 |
+
def get_up_block(
|
93 |
+
up_block_type,
|
94 |
+
num_layers,
|
95 |
+
in_channels,
|
96 |
+
out_channels,
|
97 |
+
prev_output_channel,
|
98 |
+
temb_channels,
|
99 |
+
add_upsample,
|
100 |
+
resnet_eps,
|
101 |
+
resnet_act_fn,
|
102 |
+
attn_num_head_channels,
|
103 |
+
resnet_groups=None,
|
104 |
+
cross_attention_dim=None,
|
105 |
+
dual_cross_attention=False,
|
106 |
+
use_linear_projection=False,
|
107 |
+
only_cross_attention=False,
|
108 |
+
upcast_attention=False,
|
109 |
+
resnet_time_scale_shift="default",
|
110 |
+
|
111 |
+
unet_use_cross_frame_attention=None,
|
112 |
+
unet_use_temporal_attention=None,
|
113 |
+
use_inflated_groupnorm=None,
|
114 |
+
|
115 |
+
use_motion_module=None,
|
116 |
+
motion_module_type=None,
|
117 |
+
motion_module_kwargs=None,
|
118 |
+
):
|
119 |
+
up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
|
120 |
+
if up_block_type == "UpBlock3D":
|
121 |
+
return UpBlock3D(
|
122 |
+
num_layers=num_layers,
|
123 |
+
in_channels=in_channels,
|
124 |
+
out_channels=out_channels,
|
125 |
+
prev_output_channel=prev_output_channel,
|
126 |
+
temb_channels=temb_channels,
|
127 |
+
add_upsample=add_upsample,
|
128 |
+
resnet_eps=resnet_eps,
|
129 |
+
resnet_act_fn=resnet_act_fn,
|
130 |
+
resnet_groups=resnet_groups,
|
131 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
132 |
+
|
133 |
+
use_inflated_groupnorm=use_inflated_groupnorm,
|
134 |
+
|
135 |
+
use_motion_module=use_motion_module,
|
136 |
+
motion_module_type=motion_module_type,
|
137 |
+
motion_module_kwargs=motion_module_kwargs,
|
138 |
+
)
|
139 |
+
elif up_block_type == "CrossAttnUpBlock3D":
|
140 |
+
if cross_attention_dim is None:
|
141 |
+
raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D")
|
142 |
+
return CrossAttnUpBlock3D(
|
143 |
+
num_layers=num_layers,
|
144 |
+
in_channels=in_channels,
|
145 |
+
out_channels=out_channels,
|
146 |
+
prev_output_channel=prev_output_channel,
|
147 |
+
temb_channels=temb_channels,
|
148 |
+
add_upsample=add_upsample,
|
149 |
+
resnet_eps=resnet_eps,
|
150 |
+
resnet_act_fn=resnet_act_fn,
|
151 |
+
resnet_groups=resnet_groups,
|
152 |
+
cross_attention_dim=cross_attention_dim,
|
153 |
+
attn_num_head_channels=attn_num_head_channels,
|
154 |
+
dual_cross_attention=dual_cross_attention,
|
155 |
+
use_linear_projection=use_linear_projection,
|
156 |
+
only_cross_attention=only_cross_attention,
|
157 |
+
upcast_attention=upcast_attention,
|
158 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
159 |
+
|
160 |
+
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
|
161 |
+
unet_use_temporal_attention=unet_use_temporal_attention,
|
162 |
+
use_inflated_groupnorm=use_inflated_groupnorm,
|
163 |
+
|
164 |
+
use_motion_module=use_motion_module,
|
165 |
+
motion_module_type=motion_module_type,
|
166 |
+
motion_module_kwargs=motion_module_kwargs,
|
167 |
+
)
|
168 |
+
raise ValueError(f"{up_block_type} does not exist.")
|
169 |
+
|
170 |
+
|
171 |
+
class UNetMidBlock3DCrossAttn(nn.Module):
|
172 |
+
def __init__(
|
173 |
+
self,
|
174 |
+
in_channels: int,
|
175 |
+
temb_channels: int,
|
176 |
+
dropout: float = 0.0,
|
177 |
+
num_layers: int = 1,
|
178 |
+
resnet_eps: float = 1e-6,
|
179 |
+
resnet_time_scale_shift: str = "default",
|
180 |
+
resnet_act_fn: str = "swish",
|
181 |
+
resnet_groups: int = 32,
|
182 |
+
resnet_pre_norm: bool = True,
|
183 |
+
attn_num_head_channels=1,
|
184 |
+
output_scale_factor=1.0,
|
185 |
+
cross_attention_dim=1280,
|
186 |
+
dual_cross_attention=False,
|
187 |
+
use_linear_projection=False,
|
188 |
+
upcast_attention=False,
|
189 |
+
|
190 |
+
unet_use_cross_frame_attention=None,
|
191 |
+
unet_use_temporal_attention=None,
|
192 |
+
use_inflated_groupnorm=None,
|
193 |
+
|
194 |
+
use_motion_module=None,
|
195 |
+
|
196 |
+
motion_module_type=None,
|
197 |
+
motion_module_kwargs=None,
|
198 |
+
):
|
199 |
+
super().__init__()
|
200 |
+
|
201 |
+
self.has_cross_attention = True
|
202 |
+
self.attn_num_head_channels = attn_num_head_channels
|
203 |
+
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
|
204 |
+
|
205 |
+
# there is always at least one resnet
|
206 |
+
resnets = [
|
207 |
+
ResnetBlock3D(
|
208 |
+
in_channels=in_channels,
|
209 |
+
out_channels=in_channels,
|
210 |
+
temb_channels=temb_channels,
|
211 |
+
eps=resnet_eps,
|
212 |
+
groups=resnet_groups,
|
213 |
+
dropout=dropout,
|
214 |
+
time_embedding_norm=resnet_time_scale_shift,
|
215 |
+
non_linearity=resnet_act_fn,
|
216 |
+
output_scale_factor=output_scale_factor,
|
217 |
+
pre_norm=resnet_pre_norm,
|
218 |
+
|
219 |
+
use_inflated_groupnorm=use_inflated_groupnorm,
|
220 |
+
)
|
221 |
+
]
|
222 |
+
attentions = []
|
223 |
+
motion_modules = []
|
224 |
+
|
225 |
+
for _ in range(num_layers):
|
226 |
+
if dual_cross_attention:
|
227 |
+
raise NotImplementedError
|
228 |
+
attentions.append(
|
229 |
+
Transformer3DModel(
|
230 |
+
attn_num_head_channels,
|
231 |
+
in_channels // attn_num_head_channels,
|
232 |
+
in_channels=in_channels,
|
233 |
+
num_layers=1,
|
234 |
+
cross_attention_dim=cross_attention_dim,
|
235 |
+
norm_num_groups=resnet_groups,
|
236 |
+
use_linear_projection=use_linear_projection,
|
237 |
+
upcast_attention=upcast_attention,
|
238 |
+
|
239 |
+
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
|
240 |
+
unet_use_temporal_attention=unet_use_temporal_attention,
|
241 |
+
)
|
242 |
+
)
|
243 |
+
motion_modules.append(
|
244 |
+
get_motion_module(
|
245 |
+
in_channels=in_channels,
|
246 |
+
motion_module_type=motion_module_type,
|
247 |
+
motion_module_kwargs=motion_module_kwargs,
|
248 |
+
) if use_motion_module else None
|
249 |
+
)
|
250 |
+
resnets.append(
|
251 |
+
ResnetBlock3D(
|
252 |
+
in_channels=in_channels,
|
253 |
+
out_channels=in_channels,
|
254 |
+
temb_channels=temb_channels,
|
255 |
+
eps=resnet_eps,
|
256 |
+
groups=resnet_groups,
|
257 |
+
dropout=dropout,
|
258 |
+
time_embedding_norm=resnet_time_scale_shift,
|
259 |
+
non_linearity=resnet_act_fn,
|
260 |
+
output_scale_factor=output_scale_factor,
|
261 |
+
pre_norm=resnet_pre_norm,
|
262 |
+
|
263 |
+
use_inflated_groupnorm=use_inflated_groupnorm,
|
264 |
+
)
|
265 |
+
)
|
266 |
+
|
267 |
+
self.attentions = nn.ModuleList(attentions)
|
268 |
+
self.resnets = nn.ModuleList(resnets)
|
269 |
+
self.motion_modules = nn.ModuleList(motion_modules)
|
270 |
+
|
271 |
+
def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
|
272 |
+
hidden_states = self.resnets[0](hidden_states, temb)
|
273 |
+
for attn, resnet, motion_module in zip(self.attentions, self.resnets[1:], self.motion_modules):
|
274 |
+
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
|
275 |
+
hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states
|
276 |
+
hidden_states = resnet(hidden_states, temb)
|
277 |
+
|
278 |
+
return hidden_states
|
279 |
+
|
280 |
+
|
281 |
+
class CrossAttnDownBlock3D(nn.Module):
|
282 |
+
def __init__(
|
283 |
+
self,
|
284 |
+
in_channels: int,
|
285 |
+
out_channels: int,
|
286 |
+
temb_channels: int,
|
287 |
+
dropout: float = 0.0,
|
288 |
+
num_layers: int = 1,
|
289 |
+
resnet_eps: float = 1e-6,
|
290 |
+
resnet_time_scale_shift: str = "default",
|
291 |
+
resnet_act_fn: str = "swish",
|
292 |
+
resnet_groups: int = 32,
|
293 |
+
resnet_pre_norm: bool = True,
|
294 |
+
attn_num_head_channels=1,
|
295 |
+
cross_attention_dim=1280,
|
296 |
+
output_scale_factor=1.0,
|
297 |
+
downsample_padding=1,
|
298 |
+
add_downsample=True,
|
299 |
+
dual_cross_attention=False,
|
300 |
+
use_linear_projection=False,
|
301 |
+
only_cross_attention=False,
|
302 |
+
upcast_attention=False,
|
303 |
+
|
304 |
+
unet_use_cross_frame_attention=None,
|
305 |
+
unet_use_temporal_attention=None,
|
306 |
+
use_inflated_groupnorm=None,
|
307 |
+
|
308 |
+
use_motion_module=None,
|
309 |
+
|
310 |
+
motion_module_type=None,
|
311 |
+
motion_module_kwargs=None,
|
312 |
+
):
|
313 |
+
super().__init__()
|
314 |
+
resnets = []
|
315 |
+
attentions = []
|
316 |
+
motion_modules = []
|
317 |
+
|
318 |
+
self.has_cross_attention = True
|
319 |
+
self.attn_num_head_channels = attn_num_head_channels
|
320 |
+
|
321 |
+
for i in range(num_layers):
|
322 |
+
in_channels = in_channels if i == 0 else out_channels
|
323 |
+
resnets.append(
|
324 |
+
ResnetBlock3D(
|
325 |
+
in_channels=in_channels,
|
326 |
+
out_channels=out_channels,
|
327 |
+
temb_channels=temb_channels,
|
328 |
+
eps=resnet_eps,
|
329 |
+
groups=resnet_groups,
|
330 |
+
dropout=dropout,
|
331 |
+
time_embedding_norm=resnet_time_scale_shift,
|
332 |
+
non_linearity=resnet_act_fn,
|
333 |
+
output_scale_factor=output_scale_factor,
|
334 |
+
pre_norm=resnet_pre_norm,
|
335 |
+
|
336 |
+
use_inflated_groupnorm=use_inflated_groupnorm,
|
337 |
+
)
|
338 |
+
)
|
339 |
+
if dual_cross_attention:
|
340 |
+
raise NotImplementedError
|
341 |
+
attentions.append(
|
342 |
+
Transformer3DModel(
|
343 |
+
attn_num_head_channels,
|
344 |
+
out_channels // attn_num_head_channels,
|
345 |
+
in_channels=out_channels,
|
346 |
+
num_layers=1,
|
347 |
+
cross_attention_dim=cross_attention_dim,
|
348 |
+
norm_num_groups=resnet_groups,
|
349 |
+
use_linear_projection=use_linear_projection,
|
350 |
+
only_cross_attention=only_cross_attention,
|
351 |
+
upcast_attention=upcast_attention,
|
352 |
+
|
353 |
+
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
|
354 |
+
unet_use_temporal_attention=unet_use_temporal_attention,
|
355 |
+
)
|
356 |
+
)
|
357 |
+
motion_modules.append(
|
358 |
+
get_motion_module(
|
359 |
+
in_channels=out_channels,
|
360 |
+
motion_module_type=motion_module_type,
|
361 |
+
motion_module_kwargs=motion_module_kwargs,
|
362 |
+
) if use_motion_module else None
|
363 |
+
)
|
364 |
+
|
365 |
+
self.attentions = nn.ModuleList(attentions)
|
366 |
+
self.resnets = nn.ModuleList(resnets)
|
367 |
+
self.motion_modules = nn.ModuleList(motion_modules)
|
368 |
+
|
369 |
+
if add_downsample:
|
370 |
+
self.downsamplers = nn.ModuleList(
|
371 |
+
[
|
372 |
+
Downsample3D(
|
373 |
+
out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
|
374 |
+
)
|
375 |
+
]
|
376 |
+
)
|
377 |
+
else:
|
378 |
+
self.downsamplers = None
|
379 |
+
|
380 |
+
self.gradient_checkpointing = False
|
381 |
+
|
382 |
+
def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
|
383 |
+
output_states = ()
|
384 |
+
|
385 |
+
for resnet, attn, motion_module in zip(self.resnets, self.attentions, self.motion_modules):
|
386 |
+
if self.training and self.gradient_checkpointing:
|
387 |
+
|
388 |
+
def create_custom_forward(module, return_dict=None):
|
389 |
+
def custom_forward(*inputs):
|
390 |
+
if return_dict is not None:
|
391 |
+
return module(*inputs, return_dict=return_dict)
|
392 |
+
else:
|
393 |
+
return module(*inputs)
|
394 |
+
|
395 |
+
return custom_forward
|
396 |
+
|
397 |
+
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
|
398 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
399 |
+
create_custom_forward(attn, return_dict=False),
|
400 |
+
hidden_states,
|
401 |
+
encoder_hidden_states,
|
402 |
+
)[0]
|
403 |
+
if motion_module is not None:
|
404 |
+
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states)
|
405 |
+
|
406 |
+
else:
|
407 |
+
hidden_states = resnet(hidden_states, temb)
|
408 |
+
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
|
409 |
+
|
410 |
+
# add motion module
|
411 |
+
hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states
|
412 |
+
|
413 |
+
output_states += (hidden_states,)
|
414 |
+
|
415 |
+
if self.downsamplers is not None:
|
416 |
+
for downsampler in self.downsamplers:
|
417 |
+
hidden_states = downsampler(hidden_states)
|
418 |
+
|
419 |
+
output_states += (hidden_states,)
|
420 |
+
|
421 |
+
return hidden_states, output_states
|
422 |
+
|
423 |
+
|
424 |
+
class DownBlock3D(nn.Module):
|
425 |
+
def __init__(
|
426 |
+
self,
|
427 |
+
in_channels: int,
|
428 |
+
out_channels: int,
|
429 |
+
temb_channels: int,
|
430 |
+
dropout: float = 0.0,
|
431 |
+
num_layers: int = 1,
|
432 |
+
resnet_eps: float = 1e-6,
|
433 |
+
resnet_time_scale_shift: str = "default",
|
434 |
+
resnet_act_fn: str = "swish",
|
435 |
+
resnet_groups: int = 32,
|
436 |
+
resnet_pre_norm: bool = True,
|
437 |
+
output_scale_factor=1.0,
|
438 |
+
add_downsample=True,
|
439 |
+
downsample_padding=1,
|
440 |
+
|
441 |
+
use_inflated_groupnorm=None,
|
442 |
+
|
443 |
+
use_motion_module=None,
|
444 |
+
motion_module_type=None,
|
445 |
+
motion_module_kwargs=None,
|
446 |
+
):
|
447 |
+
super().__init__()
|
448 |
+
resnets = []
|
449 |
+
motion_modules = []
|
450 |
+
|
451 |
+
for i in range(num_layers):
|
452 |
+
in_channels = in_channels if i == 0 else out_channels
|
453 |
+
resnets.append(
|
454 |
+
ResnetBlock3D(
|
455 |
+
in_channels=in_channels,
|
456 |
+
out_channels=out_channels,
|
457 |
+
temb_channels=temb_channels,
|
458 |
+
eps=resnet_eps,
|
459 |
+
groups=resnet_groups,
|
460 |
+
dropout=dropout,
|
461 |
+
time_embedding_norm=resnet_time_scale_shift,
|
462 |
+
non_linearity=resnet_act_fn,
|
463 |
+
output_scale_factor=output_scale_factor,
|
464 |
+
pre_norm=resnet_pre_norm,
|
465 |
+
|
466 |
+
use_inflated_groupnorm=use_inflated_groupnorm,
|
467 |
+
)
|
468 |
+
)
|
469 |
+
motion_modules.append(
|
470 |
+
get_motion_module(
|
471 |
+
in_channels=out_channels,
|
472 |
+
motion_module_type=motion_module_type,
|
473 |
+
motion_module_kwargs=motion_module_kwargs,
|
474 |
+
) if use_motion_module else None
|
475 |
+
)
|
476 |
+
|
477 |
+
self.resnets = nn.ModuleList(resnets)
|
478 |
+
self.motion_modules = nn.ModuleList(motion_modules)
|
479 |
+
|
480 |
+
if add_downsample:
|
481 |
+
self.downsamplers = nn.ModuleList(
|
482 |
+
[
|
483 |
+
Downsample3D(
|
484 |
+
out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
|
485 |
+
)
|
486 |
+
]
|
487 |
+
)
|
488 |
+
else:
|
489 |
+
self.downsamplers = None
|
490 |
+
|
491 |
+
self.gradient_checkpointing = False
|
492 |
+
|
493 |
+
def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
|
494 |
+
output_states = ()
|
495 |
+
|
496 |
+
for resnet, motion_module in zip(self.resnets, self.motion_modules):
|
497 |
+
if self.training and self.gradient_checkpointing:
|
498 |
+
def create_custom_forward(module):
|
499 |
+
def custom_forward(*inputs):
|
500 |
+
return module(*inputs)
|
501 |
+
|
502 |
+
return custom_forward
|
503 |
+
|
504 |
+
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
|
505 |
+
if motion_module is not None:
|
506 |
+
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states)
|
507 |
+
else:
|
508 |
+
hidden_states = resnet(hidden_states, temb)
|
509 |
+
|
510 |
+
# add motion module
|
511 |
+
hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states
|
512 |
+
|
513 |
+
output_states += (hidden_states,)
|
514 |
+
|
515 |
+
if self.downsamplers is not None:
|
516 |
+
for downsampler in self.downsamplers:
|
517 |
+
hidden_states = downsampler(hidden_states)
|
518 |
+
|
519 |
+
output_states += (hidden_states,)
|
520 |
+
|
521 |
+
return hidden_states, output_states
|
522 |
+
|
523 |
+
|
524 |
+
class CrossAttnUpBlock3D(nn.Module):
|
525 |
+
def __init__(
|
526 |
+
self,
|
527 |
+
in_channels: int,
|
528 |
+
out_channels: int,
|
529 |
+
prev_output_channel: int,
|
530 |
+
temb_channels: int,
|
531 |
+
dropout: float = 0.0,
|
532 |
+
num_layers: int = 1,
|
533 |
+
resnet_eps: float = 1e-6,
|
534 |
+
resnet_time_scale_shift: str = "default",
|
535 |
+
resnet_act_fn: str = "swish",
|
536 |
+
resnet_groups: int = 32,
|
537 |
+
resnet_pre_norm: bool = True,
|
538 |
+
attn_num_head_channels=1,
|
539 |
+
cross_attention_dim=1280,
|
540 |
+
output_scale_factor=1.0,
|
541 |
+
add_upsample=True,
|
542 |
+
dual_cross_attention=False,
|
543 |
+
use_linear_projection=False,
|
544 |
+
only_cross_attention=False,
|
545 |
+
upcast_attention=False,
|
546 |
+
|
547 |
+
unet_use_cross_frame_attention=None,
|
548 |
+
unet_use_temporal_attention=None,
|
549 |
+
use_inflated_groupnorm=None,
|
550 |
+
|
551 |
+
use_motion_module=None,
|
552 |
+
|
553 |
+
motion_module_type=None,
|
554 |
+
motion_module_kwargs=None,
|
555 |
+
):
|
556 |
+
super().__init__()
|
557 |
+
resnets = []
|
558 |
+
attentions = []
|
559 |
+
motion_modules = []
|
560 |
+
|
561 |
+
self.has_cross_attention = True
|
562 |
+
self.attn_num_head_channels = attn_num_head_channels
|
563 |
+
|
564 |
+
for i in range(num_layers):
|
565 |
+
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
|
566 |
+
resnet_in_channels = prev_output_channel if i == 0 else out_channels
|
567 |
+
|
568 |
+
resnets.append(
|
569 |
+
ResnetBlock3D(
|
570 |
+
in_channels=resnet_in_channels + res_skip_channels,
|
571 |
+
out_channels=out_channels,
|
572 |
+
temb_channels=temb_channels,
|
573 |
+
eps=resnet_eps,
|
574 |
+
groups=resnet_groups,
|
575 |
+
dropout=dropout,
|
576 |
+
time_embedding_norm=resnet_time_scale_shift,
|
577 |
+
non_linearity=resnet_act_fn,
|
578 |
+
output_scale_factor=output_scale_factor,
|
579 |
+
pre_norm=resnet_pre_norm,
|
580 |
+
|
581 |
+
use_inflated_groupnorm=use_inflated_groupnorm,
|
582 |
+
)
|
583 |
+
)
|
584 |
+
if dual_cross_attention:
|
585 |
+
raise NotImplementedError
|
586 |
+
attentions.append(
|
587 |
+
Transformer3DModel(
|
588 |
+
attn_num_head_channels,
|
589 |
+
out_channels // attn_num_head_channels,
|
590 |
+
in_channels=out_channels,
|
591 |
+
num_layers=1,
|
592 |
+
cross_attention_dim=cross_attention_dim,
|
593 |
+
norm_num_groups=resnet_groups,
|
594 |
+
use_linear_projection=use_linear_projection,
|
595 |
+
only_cross_attention=only_cross_attention,
|
596 |
+
upcast_attention=upcast_attention,
|
597 |
+
|
598 |
+
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
|
599 |
+
unet_use_temporal_attention=unet_use_temporal_attention,
|
600 |
+
)
|
601 |
+
)
|
602 |
+
motion_modules.append(
|
603 |
+
get_motion_module(
|
604 |
+
in_channels=out_channels,
|
605 |
+
motion_module_type=motion_module_type,
|
606 |
+
motion_module_kwargs=motion_module_kwargs,
|
607 |
+
) if use_motion_module else None
|
608 |
+
)
|
609 |
+
|
610 |
+
self.attentions = nn.ModuleList(attentions)
|
611 |
+
self.resnets = nn.ModuleList(resnets)
|
612 |
+
self.motion_modules = nn.ModuleList(motion_modules)
|
613 |
+
|
614 |
+
if add_upsample:
|
615 |
+
self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)])
|
616 |
+
else:
|
617 |
+
self.upsamplers = None
|
618 |
+
|
619 |
+
self.gradient_checkpointing = False
|
620 |
+
|
621 |
+
def forward(
|
622 |
+
self,
|
623 |
+
hidden_states,
|
624 |
+
res_hidden_states_tuple,
|
625 |
+
temb=None,
|
626 |
+
encoder_hidden_states=None,
|
627 |
+
upsample_size=None,
|
628 |
+
attention_mask=None,
|
629 |
+
):
|
630 |
+
for resnet, attn, motion_module in zip(self.resnets, self.attentions, self.motion_modules):
|
631 |
+
# pop res hidden states
|
632 |
+
res_hidden_states = res_hidden_states_tuple[-1]
|
633 |
+
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
634 |
+
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
635 |
+
|
636 |
+
if self.training and self.gradient_checkpointing:
|
637 |
+
|
638 |
+
def create_custom_forward(module, return_dict=None):
|
639 |
+
def custom_forward(*inputs):
|
640 |
+
if return_dict is not None:
|
641 |
+
return module(*inputs, return_dict=return_dict)
|
642 |
+
else:
|
643 |
+
return module(*inputs)
|
644 |
+
|
645 |
+
return custom_forward
|
646 |
+
|
647 |
+
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
|
648 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
649 |
+
create_custom_forward(attn, return_dict=False),
|
650 |
+
hidden_states,
|
651 |
+
encoder_hidden_states,
|
652 |
+
)[0]
|
653 |
+
if motion_module is not None:
|
654 |
+
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states)
|
655 |
+
|
656 |
+
else:
|
657 |
+
hidden_states = resnet(hidden_states, temb)
|
658 |
+
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
|
659 |
+
|
660 |
+
# add motion module
|
661 |
+
hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states
|
662 |
+
|
663 |
+
if self.upsamplers is not None:
|
664 |
+
for upsampler in self.upsamplers:
|
665 |
+
hidden_states = upsampler(hidden_states, upsample_size)
|
666 |
+
|
667 |
+
return hidden_states
|
668 |
+
|
669 |
+
|
670 |
+
class UpBlock3D(nn.Module):
|
671 |
+
def __init__(
|
672 |
+
self,
|
673 |
+
in_channels: int,
|
674 |
+
prev_output_channel: int,
|
675 |
+
out_channels: int,
|
676 |
+
temb_channels: int,
|
677 |
+
dropout: float = 0.0,
|
678 |
+
num_layers: int = 1,
|
679 |
+
resnet_eps: float = 1e-6,
|
680 |
+
resnet_time_scale_shift: str = "default",
|
681 |
+
resnet_act_fn: str = "swish",
|
682 |
+
resnet_groups: int = 32,
|
683 |
+
resnet_pre_norm: bool = True,
|
684 |
+
output_scale_factor=1.0,
|
685 |
+
add_upsample=True,
|
686 |
+
|
687 |
+
use_inflated_groupnorm=None,
|
688 |
+
|
689 |
+
use_motion_module=None,
|
690 |
+
motion_module_type=None,
|
691 |
+
motion_module_kwargs=None,
|
692 |
+
):
|
693 |
+
super().__init__()
|
694 |
+
resnets = []
|
695 |
+
motion_modules = []
|
696 |
+
|
697 |
+
for i in range(num_layers):
|
698 |
+
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
|
699 |
+
resnet_in_channels = prev_output_channel if i == 0 else out_channels
|
700 |
+
|
701 |
+
resnets.append(
|
702 |
+
ResnetBlock3D(
|
703 |
+
in_channels=resnet_in_channels + res_skip_channels,
|
704 |
+
out_channels=out_channels,
|
705 |
+
temb_channels=temb_channels,
|
706 |
+
eps=resnet_eps,
|
707 |
+
groups=resnet_groups,
|
708 |
+
dropout=dropout,
|
709 |
+
time_embedding_norm=resnet_time_scale_shift,
|
710 |
+
non_linearity=resnet_act_fn,
|
711 |
+
output_scale_factor=output_scale_factor,
|
712 |
+
pre_norm=resnet_pre_norm,
|
713 |
+
|
714 |
+
use_inflated_groupnorm=use_inflated_groupnorm,
|
715 |
+
)
|
716 |
+
)
|
717 |
+
motion_modules.append(
|
718 |
+
get_motion_module(
|
719 |
+
in_channels=out_channels,
|
720 |
+
motion_module_type=motion_module_type,
|
721 |
+
motion_module_kwargs=motion_module_kwargs,
|
722 |
+
) if use_motion_module else None
|
723 |
+
)
|
724 |
+
|
725 |
+
self.resnets = nn.ModuleList(resnets)
|
726 |
+
self.motion_modules = nn.ModuleList(motion_modules)
|
727 |
+
|
728 |
+
if add_upsample:
|
729 |
+
self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)])
|
730 |
+
else:
|
731 |
+
self.upsamplers = None
|
732 |
+
|
733 |
+
self.gradient_checkpointing = False
|
734 |
+
|
735 |
+
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, encoder_hidden_states=None,):
|
736 |
+
for resnet, motion_module in zip(self.resnets, self.motion_modules):
|
737 |
+
# pop res hidden states
|
738 |
+
res_hidden_states = res_hidden_states_tuple[-1]
|
739 |
+
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
740 |
+
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
741 |
+
|
742 |
+
if self.training and self.gradient_checkpointing:
|
743 |
+
def create_custom_forward(module):
|
744 |
+
def custom_forward(*inputs):
|
745 |
+
return module(*inputs)
|
746 |
+
|
747 |
+
return custom_forward
|
748 |
+
|
749 |
+
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
|
750 |
+
if motion_module is not None:
|
751 |
+
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states)
|
752 |
+
else:
|
753 |
+
hidden_states = resnet(hidden_states, temb)
|
754 |
+
hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states
|
755 |
+
|
756 |
+
if self.upsamplers is not None:
|
757 |
+
for upsampler in self.upsamplers:
|
758 |
+
hidden_states = upsampler(hidden_states, upsample_size)
|
759 |
+
|
760 |
+
return hidden_states
|
animatediff/pipelines/pipeline_animation.py
ADDED
@@ -0,0 +1,656 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Adapted from https://github.com/showlab/Tune-A-Video/blob/main/tuneavideo/pipelines/pipeline_tuneavideo.py
|
2 |
+
|
3 |
+
import inspect
|
4 |
+
from typing import Callable, List, Optional, Union
|
5 |
+
from dataclasses import dataclass
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
import torch
|
9 |
+
from tqdm import tqdm
|
10 |
+
|
11 |
+
from diffusers.utils import is_accelerate_available
|
12 |
+
from packaging import version
|
13 |
+
from transformers import CLIPTextModel, CLIPTokenizer
|
14 |
+
|
15 |
+
from diffusers.configuration_utils import FrozenDict
|
16 |
+
from diffusers.models import AutoencoderKL
|
17 |
+
from diffusers.pipeline_utils import DiffusionPipeline
|
18 |
+
from diffusers.schedulers import (
|
19 |
+
DDIMScheduler,
|
20 |
+
DPMSolverMultistepScheduler,
|
21 |
+
EulerAncestralDiscreteScheduler,
|
22 |
+
EulerDiscreteScheduler,
|
23 |
+
LMSDiscreteScheduler,
|
24 |
+
PNDMScheduler,
|
25 |
+
)
|
26 |
+
from diffusers.utils import deprecate, logging, BaseOutput
|
27 |
+
|
28 |
+
from einops import rearrange
|
29 |
+
|
30 |
+
from ..models.unet import UNet3DConditionModel
|
31 |
+
|
32 |
+
from ..utils.freeinit_utils import (
|
33 |
+
get_freq_filter,
|
34 |
+
freq_mix_3d,
|
35 |
+
)
|
36 |
+
import os
|
37 |
+
|
38 |
+
from ..utils.util import save_videos_grid
|
39 |
+
|
40 |
+
|
41 |
+
|
42 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
43 |
+
|
44 |
+
|
45 |
+
@dataclass
|
46 |
+
class AnimationPipelineOutput(BaseOutput):
|
47 |
+
videos: Union[torch.Tensor, np.ndarray]
|
48 |
+
|
49 |
+
@dataclass
|
50 |
+
class AnimationFreeInitPipelineOutput(BaseOutput):
|
51 |
+
videos: Union[torch.Tensor, np.ndarray]
|
52 |
+
orig_videos: Union[torch.Tensor, np.ndarray]
|
53 |
+
|
54 |
+
|
55 |
+
class AnimationPipeline(DiffusionPipeline):
|
56 |
+
_optional_components = []
|
57 |
+
|
58 |
+
def __init__(
|
59 |
+
self,
|
60 |
+
vae: AutoencoderKL,
|
61 |
+
text_encoder: CLIPTextModel,
|
62 |
+
tokenizer: CLIPTokenizer,
|
63 |
+
unet: UNet3DConditionModel,
|
64 |
+
scheduler: Union[
|
65 |
+
DDIMScheduler,
|
66 |
+
PNDMScheduler,
|
67 |
+
LMSDiscreteScheduler,
|
68 |
+
EulerDiscreteScheduler,
|
69 |
+
EulerAncestralDiscreteScheduler,
|
70 |
+
DPMSolverMultistepScheduler,
|
71 |
+
],
|
72 |
+
):
|
73 |
+
super().__init__()
|
74 |
+
|
75 |
+
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
|
76 |
+
deprecation_message = (
|
77 |
+
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
|
78 |
+
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
|
79 |
+
"to update the config accordingly as leaving `steps_offset` might led to incorrect results"
|
80 |
+
" in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
|
81 |
+
" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
|
82 |
+
" file"
|
83 |
+
)
|
84 |
+
deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
|
85 |
+
new_config = dict(scheduler.config)
|
86 |
+
new_config["steps_offset"] = 1
|
87 |
+
scheduler._internal_dict = FrozenDict(new_config)
|
88 |
+
|
89 |
+
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
|
90 |
+
deprecation_message = (
|
91 |
+
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
|
92 |
+
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
|
93 |
+
" config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
|
94 |
+
" future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
|
95 |
+
" nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
|
96 |
+
)
|
97 |
+
deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
|
98 |
+
new_config = dict(scheduler.config)
|
99 |
+
new_config["clip_sample"] = False
|
100 |
+
scheduler._internal_dict = FrozenDict(new_config)
|
101 |
+
|
102 |
+
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
|
103 |
+
version.parse(unet.config._diffusers_version).base_version
|
104 |
+
) < version.parse("0.9.0.dev0")
|
105 |
+
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
|
106 |
+
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
|
107 |
+
deprecation_message = (
|
108 |
+
"The configuration file of the unet has set the default `sample_size` to smaller than"
|
109 |
+
" 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
|
110 |
+
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
|
111 |
+
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
|
112 |
+
" \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
|
113 |
+
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
|
114 |
+
" in the config might lead to incorrect results in future versions. If you have downloaded this"
|
115 |
+
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
|
116 |
+
" the `unet/config.json` file"
|
117 |
+
)
|
118 |
+
deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
|
119 |
+
new_config = dict(unet.config)
|
120 |
+
new_config["sample_size"] = 64
|
121 |
+
unet._internal_dict = FrozenDict(new_config)
|
122 |
+
|
123 |
+
self.register_modules(
|
124 |
+
vae=vae,
|
125 |
+
text_encoder=text_encoder,
|
126 |
+
tokenizer=tokenizer,
|
127 |
+
unet=unet,
|
128 |
+
scheduler=scheduler,
|
129 |
+
)
|
130 |
+
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
131 |
+
|
132 |
+
def enable_vae_slicing(self):
|
133 |
+
self.vae.enable_slicing()
|
134 |
+
|
135 |
+
def disable_vae_slicing(self):
|
136 |
+
self.vae.disable_slicing()
|
137 |
+
|
138 |
+
def enable_sequential_cpu_offload(self, gpu_id=0):
|
139 |
+
if is_accelerate_available():
|
140 |
+
from accelerate import cpu_offload
|
141 |
+
else:
|
142 |
+
raise ImportError("Please install accelerate via `pip install accelerate`")
|
143 |
+
|
144 |
+
device = torch.device(f"cuda:{gpu_id}")
|
145 |
+
|
146 |
+
for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
|
147 |
+
if cpu_offloaded_model is not None:
|
148 |
+
cpu_offload(cpu_offloaded_model, device)
|
149 |
+
|
150 |
+
|
151 |
+
@property
|
152 |
+
def _execution_device(self):
|
153 |
+
if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
|
154 |
+
return self.device
|
155 |
+
for module in self.unet.modules():
|
156 |
+
if (
|
157 |
+
hasattr(module, "_hf_hook")
|
158 |
+
and hasattr(module._hf_hook, "execution_device")
|
159 |
+
and module._hf_hook.execution_device is not None
|
160 |
+
):
|
161 |
+
return torch.device(module._hf_hook.execution_device)
|
162 |
+
return self.device
|
163 |
+
|
164 |
+
def _encode_prompt(self, prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt):
|
165 |
+
batch_size = len(prompt) if isinstance(prompt, list) else 1
|
166 |
+
|
167 |
+
text_inputs = self.tokenizer(
|
168 |
+
prompt,
|
169 |
+
padding="max_length",
|
170 |
+
max_length=self.tokenizer.model_max_length,
|
171 |
+
truncation=True,
|
172 |
+
return_tensors="pt",
|
173 |
+
)
|
174 |
+
text_input_ids = text_inputs.input_ids
|
175 |
+
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
176 |
+
|
177 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
178 |
+
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
|
179 |
+
logger.warning(
|
180 |
+
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
181 |
+
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
|
182 |
+
)
|
183 |
+
|
184 |
+
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
|
185 |
+
attention_mask = text_inputs.attention_mask.to(device)
|
186 |
+
else:
|
187 |
+
attention_mask = None
|
188 |
+
|
189 |
+
text_embeddings = self.text_encoder(
|
190 |
+
text_input_ids.to(device),
|
191 |
+
attention_mask=attention_mask,
|
192 |
+
)
|
193 |
+
text_embeddings = text_embeddings[0]
|
194 |
+
|
195 |
+
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
196 |
+
bs_embed, seq_len, _ = text_embeddings.shape
|
197 |
+
text_embeddings = text_embeddings.repeat(1, num_videos_per_prompt, 1)
|
198 |
+
text_embeddings = text_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1)
|
199 |
+
|
200 |
+
# get unconditional embeddings for classifier free guidance
|
201 |
+
if do_classifier_free_guidance:
|
202 |
+
uncond_tokens: List[str]
|
203 |
+
if negative_prompt is None:
|
204 |
+
uncond_tokens = [""] * batch_size
|
205 |
+
elif type(prompt) is not type(negative_prompt):
|
206 |
+
raise TypeError(
|
207 |
+
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
208 |
+
f" {type(prompt)}."
|
209 |
+
)
|
210 |
+
elif isinstance(negative_prompt, str):
|
211 |
+
uncond_tokens = [negative_prompt]
|
212 |
+
elif batch_size != len(negative_prompt):
|
213 |
+
raise ValueError(
|
214 |
+
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
215 |
+
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
216 |
+
" the batch size of `prompt`."
|
217 |
+
)
|
218 |
+
else:
|
219 |
+
uncond_tokens = negative_prompt
|
220 |
+
|
221 |
+
max_length = text_input_ids.shape[-1]
|
222 |
+
uncond_input = self.tokenizer(
|
223 |
+
uncond_tokens,
|
224 |
+
padding="max_length",
|
225 |
+
max_length=max_length,
|
226 |
+
truncation=True,
|
227 |
+
return_tensors="pt",
|
228 |
+
)
|
229 |
+
|
230 |
+
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
|
231 |
+
attention_mask = uncond_input.attention_mask.to(device)
|
232 |
+
else:
|
233 |
+
attention_mask = None
|
234 |
+
|
235 |
+
uncond_embeddings = self.text_encoder(
|
236 |
+
uncond_input.input_ids.to(device),
|
237 |
+
attention_mask=attention_mask,
|
238 |
+
)
|
239 |
+
uncond_embeddings = uncond_embeddings[0]
|
240 |
+
|
241 |
+
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
242 |
+
seq_len = uncond_embeddings.shape[1]
|
243 |
+
uncond_embeddings = uncond_embeddings.repeat(1, num_videos_per_prompt, 1)
|
244 |
+
uncond_embeddings = uncond_embeddings.view(batch_size * num_videos_per_prompt, seq_len, -1)
|
245 |
+
|
246 |
+
# For classifier free guidance, we need to do two forward passes.
|
247 |
+
# Here we concatenate the unconditional and text embeddings into a single batch
|
248 |
+
# to avoid doing two forward passes
|
249 |
+
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
|
250 |
+
|
251 |
+
return text_embeddings
|
252 |
+
|
253 |
+
def decode_latents(self, latents):
|
254 |
+
video_length = latents.shape[2]
|
255 |
+
latents = 1 / 0.18215 * latents
|
256 |
+
latents = rearrange(latents, "b c f h w -> (b f) c h w")
|
257 |
+
# video = self.vae.decode(latents).sample
|
258 |
+
video = []
|
259 |
+
for frame_idx in tqdm(range(latents.shape[0])):
|
260 |
+
video.append(self.vae.decode(latents[frame_idx:frame_idx+1]).sample)
|
261 |
+
video = torch.cat(video)
|
262 |
+
video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
|
263 |
+
video = (video / 2 + 0.5).clamp(0, 1)
|
264 |
+
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
|
265 |
+
video = video.cpu().float().numpy()
|
266 |
+
return video
|
267 |
+
|
268 |
+
def prepare_extra_step_kwargs(self, generator, eta):
|
269 |
+
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
270 |
+
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
271 |
+
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
272 |
+
# and should be between [0, 1]
|
273 |
+
|
274 |
+
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
275 |
+
extra_step_kwargs = {}
|
276 |
+
if accepts_eta:
|
277 |
+
extra_step_kwargs["eta"] = eta
|
278 |
+
|
279 |
+
# check if the scheduler accepts generator
|
280 |
+
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
281 |
+
if accepts_generator:
|
282 |
+
extra_step_kwargs["generator"] = generator
|
283 |
+
return extra_step_kwargs
|
284 |
+
|
285 |
+
def check_inputs(self, prompt, height, width, callback_steps):
|
286 |
+
if not isinstance(prompt, str) and not isinstance(prompt, list):
|
287 |
+
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
288 |
+
|
289 |
+
if height % 8 != 0 or width % 8 != 0:
|
290 |
+
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
291 |
+
|
292 |
+
if (callback_steps is None) or (
|
293 |
+
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
|
294 |
+
):
|
295 |
+
raise ValueError(
|
296 |
+
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
297 |
+
f" {type(callback_steps)}."
|
298 |
+
)
|
299 |
+
|
300 |
+
def prepare_latents(self, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator, latents=None):
|
301 |
+
shape = (batch_size, num_channels_latents, video_length, height // self.vae_scale_factor, width // self.vae_scale_factor)
|
302 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
303 |
+
raise ValueError(
|
304 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
305 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
306 |
+
)
|
307 |
+
if latents is None:
|
308 |
+
rand_device = "cpu" if device.type == "mps" else device
|
309 |
+
|
310 |
+
if isinstance(generator, list):
|
311 |
+
shape = shape
|
312 |
+
# shape = (1,) + shape[1:]
|
313 |
+
latents = [
|
314 |
+
torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
|
315 |
+
for i in range(batch_size)
|
316 |
+
]
|
317 |
+
latents = torch.cat(latents, dim=0).to(device)
|
318 |
+
else:
|
319 |
+
latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
|
320 |
+
else:
|
321 |
+
if latents.shape != shape:
|
322 |
+
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
|
323 |
+
latents = latents.to(device)
|
324 |
+
|
325 |
+
# scale the initial noise by the standard deviation required by the scheduler
|
326 |
+
latents = latents * self.scheduler.init_noise_sigma
|
327 |
+
return latents
|
328 |
+
|
329 |
+
@torch.no_grad()
|
330 |
+
def __call__(
|
331 |
+
self,
|
332 |
+
prompt: Union[str, List[str]],
|
333 |
+
video_length: Optional[int],
|
334 |
+
height: Optional[int] = None,
|
335 |
+
width: Optional[int] = None,
|
336 |
+
num_inference_steps: int = 50,
|
337 |
+
guidance_scale: float = 7.5,
|
338 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
339 |
+
num_videos_per_prompt: Optional[int] = 1,
|
340 |
+
eta: float = 0.0,
|
341 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
342 |
+
latents: Optional[torch.FloatTensor] = None,
|
343 |
+
output_type: Optional[str] = "tensor",
|
344 |
+
return_dict: bool = True,
|
345 |
+
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
346 |
+
callback_steps: Optional[int] = 1,
|
347 |
+
**kwargs,
|
348 |
+
):
|
349 |
+
# Default height and width to unet
|
350 |
+
height = height or self.unet.config.sample_size * self.vae_scale_factor
|
351 |
+
width = width or self.unet.config.sample_size * self.vae_scale_factor
|
352 |
+
|
353 |
+
# Check inputs. Raise error if not correct
|
354 |
+
self.check_inputs(prompt, height, width, callback_steps)
|
355 |
+
|
356 |
+
# Define call parameters
|
357 |
+
# batch_size = 1 if isinstance(prompt, str) else len(prompt)
|
358 |
+
batch_size = 1
|
359 |
+
if latents is not None:
|
360 |
+
batch_size = latents.shape[0]
|
361 |
+
if isinstance(prompt, list):
|
362 |
+
batch_size = len(prompt)
|
363 |
+
|
364 |
+
device = self._execution_device
|
365 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
366 |
+
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
367 |
+
# corresponds to doing no classifier free guidance.
|
368 |
+
do_classifier_free_guidance = guidance_scale > 1.0
|
369 |
+
|
370 |
+
# Encode input prompt
|
371 |
+
prompt = prompt if isinstance(prompt, list) else [prompt] * batch_size
|
372 |
+
if negative_prompt is not None:
|
373 |
+
negative_prompt = negative_prompt if isinstance(negative_prompt, list) else [negative_prompt] * batch_size
|
374 |
+
text_embeddings = self._encode_prompt(
|
375 |
+
prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt
|
376 |
+
)
|
377 |
+
|
378 |
+
# Prepare timesteps
|
379 |
+
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
380 |
+
timesteps = self.scheduler.timesteps
|
381 |
+
|
382 |
+
# Prepare latent variables
|
383 |
+
num_channels_latents = self.unet.in_channels
|
384 |
+
latents = self.prepare_latents(
|
385 |
+
batch_size * num_videos_per_prompt,
|
386 |
+
num_channels_latents,
|
387 |
+
video_length,
|
388 |
+
height,
|
389 |
+
width,
|
390 |
+
text_embeddings.dtype,
|
391 |
+
device,
|
392 |
+
generator,
|
393 |
+
latents,
|
394 |
+
)
|
395 |
+
latents_dtype = latents.dtype
|
396 |
+
|
397 |
+
# Prepare extra step kwargs.
|
398 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
399 |
+
|
400 |
+
# Denoising loop
|
401 |
+
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
402 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
403 |
+
for i, t in enumerate(timesteps):
|
404 |
+
# expand the latents if we are doing classifier free guidance
|
405 |
+
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
406 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
407 |
+
|
408 |
+
# predict the noise residual
|
409 |
+
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample.to(dtype=latents_dtype)
|
410 |
+
|
411 |
+
# perform guidance
|
412 |
+
if do_classifier_free_guidance:
|
413 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
414 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
415 |
+
|
416 |
+
# compute the previous noisy sample x_t -> x_t-1
|
417 |
+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
|
418 |
+
|
419 |
+
# call the callback, if provided
|
420 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
421 |
+
progress_bar.update()
|
422 |
+
if callback is not None and i % callback_steps == 0:
|
423 |
+
callback(i, t, latents)
|
424 |
+
|
425 |
+
# Post-processing
|
426 |
+
video = self.decode_latents(latents)
|
427 |
+
|
428 |
+
# Convert to tensor
|
429 |
+
if output_type == "tensor":
|
430 |
+
video = torch.from_numpy(video)
|
431 |
+
|
432 |
+
if not return_dict:
|
433 |
+
return video
|
434 |
+
|
435 |
+
return AnimationPipelineOutput(videos=video)
|
436 |
+
|
437 |
+
|
438 |
+
class AnimationFreeInitPipeline(AnimationPipeline):
|
439 |
+
_optional_components = []
|
440 |
+
|
441 |
+
def __init__(
|
442 |
+
self,
|
443 |
+
vae: AutoencoderKL,
|
444 |
+
text_encoder: CLIPTextModel,
|
445 |
+
tokenizer: CLIPTokenizer,
|
446 |
+
unet: UNet3DConditionModel,
|
447 |
+
scheduler: Union[
|
448 |
+
DDIMScheduler,
|
449 |
+
PNDMScheduler,
|
450 |
+
LMSDiscreteScheduler,
|
451 |
+
EulerDiscreteScheduler,
|
452 |
+
EulerAncestralDiscreteScheduler,
|
453 |
+
DPMSolverMultistepScheduler,
|
454 |
+
],
|
455 |
+
):
|
456 |
+
super().__init__(vae, text_encoder, tokenizer, unet, scheduler)
|
457 |
+
self.freq_filter = None
|
458 |
+
|
459 |
+
|
460 |
+
@torch.no_grad()
|
461 |
+
def init_filter(self, video_length, height, width, filter_params):
|
462 |
+
# initialize frequency filter for noise reinitialization
|
463 |
+
batch_size = 1
|
464 |
+
num_channels_latents = self.unet.in_channels
|
465 |
+
filter_shape = [
|
466 |
+
batch_size,
|
467 |
+
num_channels_latents,
|
468 |
+
video_length,
|
469 |
+
height // self.vae_scale_factor,
|
470 |
+
width // self.vae_scale_factor
|
471 |
+
]
|
472 |
+
# self.freq_filter = get_freq_filter(filter_shape, device=self._execution_device, params=filter_params)
|
473 |
+
self.freq_filter = get_freq_filter(
|
474 |
+
filter_shape,
|
475 |
+
device=self._execution_device,
|
476 |
+
filter_type=filter_params.method,
|
477 |
+
n=filter_params.n,
|
478 |
+
d_s=filter_params.d_s,
|
479 |
+
d_t=filter_params.d_t
|
480 |
+
)
|
481 |
+
|
482 |
+
@torch.no_grad()
|
483 |
+
def __call__(
|
484 |
+
self,
|
485 |
+
prompt: Union[str, List[str]],
|
486 |
+
video_length: Optional[int],
|
487 |
+
height: Optional[int] = None,
|
488 |
+
width: Optional[int] = None,
|
489 |
+
num_inference_steps: int = 50,
|
490 |
+
guidance_scale: float = 7.5,
|
491 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
492 |
+
num_videos_per_prompt: Optional[int] = 1,
|
493 |
+
eta: float = 0.0,
|
494 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
495 |
+
latents: Optional[torch.FloatTensor] = None,
|
496 |
+
output_type: Optional[str] = "tensor",
|
497 |
+
return_dict: bool = True,
|
498 |
+
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
499 |
+
callback_steps: Optional[int] = 1,
|
500 |
+
# freeinit args
|
501 |
+
num_iters: int = 5,
|
502 |
+
use_fast_sampling: bool = False,
|
503 |
+
save_intermediate: bool = False,
|
504 |
+
return_orig: bool = False,
|
505 |
+
save_dir: str = None,
|
506 |
+
save_name: str = None,
|
507 |
+
use_fp16: bool = False,
|
508 |
+
**kwargs
|
509 |
+
):
|
510 |
+
if use_fp16:
|
511 |
+
print('Warning: using half percision for inferencing!')
|
512 |
+
self.vae.to(dtype=torch.float16)
|
513 |
+
self.unet.to(dtype=torch.float16)
|
514 |
+
self.text_encoder.to(dtype=torch.float16)
|
515 |
+
# Default height and width to unet
|
516 |
+
height = height or self.unet.config.sample_size * self.vae_scale_factor
|
517 |
+
width = width or self.unet.config.sample_size * self.vae_scale_factor
|
518 |
+
|
519 |
+
# Check inputs. Raise error if not correct
|
520 |
+
# import pdb
|
521 |
+
# pdb.set_trace()
|
522 |
+
self.check_inputs(prompt, height, width, callback_steps)
|
523 |
+
|
524 |
+
# Define call parameters
|
525 |
+
# batch_size = 1 if isinstance(prompt, str) else len(prompt)
|
526 |
+
batch_size = 1
|
527 |
+
if latents is not None:
|
528 |
+
batch_size = latents.shape[0]
|
529 |
+
if isinstance(prompt, list):
|
530 |
+
batch_size = len(prompt)
|
531 |
+
|
532 |
+
device = self._execution_device
|
533 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
534 |
+
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
535 |
+
# corresponds to doing no classifier free guidance.
|
536 |
+
do_classifier_free_guidance = guidance_scale > 1.0
|
537 |
+
|
538 |
+
# Encode input prompt
|
539 |
+
prompt = prompt if isinstance(prompt, list) else [prompt] * batch_size
|
540 |
+
if negative_prompt is not None:
|
541 |
+
negative_prompt = negative_prompt if isinstance(negative_prompt, list) else [negative_prompt] * batch_size
|
542 |
+
text_embeddings = self._encode_prompt(
|
543 |
+
prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt
|
544 |
+
)
|
545 |
+
|
546 |
+
# Prepare timesteps
|
547 |
+
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
548 |
+
timesteps = self.scheduler.timesteps
|
549 |
+
|
550 |
+
# Prepare latent variables
|
551 |
+
num_channels_latents = self.unet.in_channels
|
552 |
+
latents = self.prepare_latents(
|
553 |
+
batch_size * num_videos_per_prompt,
|
554 |
+
num_channels_latents,
|
555 |
+
video_length,
|
556 |
+
height,
|
557 |
+
width,
|
558 |
+
text_embeddings.dtype,
|
559 |
+
device,
|
560 |
+
generator,
|
561 |
+
latents,
|
562 |
+
)
|
563 |
+
latents_dtype = latents.dtype
|
564 |
+
|
565 |
+
# Prepare extra step kwargs.
|
566 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
567 |
+
|
568 |
+
# Sampling with FreeInit.
|
569 |
+
for iter in range(num_iters):
|
570 |
+
# FreeInit ------------------------------------------------------------------
|
571 |
+
if iter == 0:
|
572 |
+
initial_noise = latents.detach().clone()
|
573 |
+
else:
|
574 |
+
# 1. DDPM Forward with initial noise, get noisy latents z_T
|
575 |
+
# if use_fast_sampling:
|
576 |
+
# current_diffuse_timestep = self.scheduler.config.num_train_timesteps / num_iters * (iter + 1) - 1
|
577 |
+
# else:
|
578 |
+
# current_diffuse_timestep = self.scheduler.config.num_train_timesteps - 1
|
579 |
+
current_diffuse_timestep = self.scheduler.config.num_train_timesteps - 1 # diffuse to t=999 noise level
|
580 |
+
diffuse_timesteps = torch.full((batch_size,),int(current_diffuse_timestep))
|
581 |
+
diffuse_timesteps = diffuse_timesteps.long()
|
582 |
+
z_T = self.scheduler.add_noise(
|
583 |
+
original_samples=latents.to(device),
|
584 |
+
noise=initial_noise.to(device),
|
585 |
+
timesteps=diffuse_timesteps.to(device)
|
586 |
+
)
|
587 |
+
# 2. create random noise z_rand for high-frequency
|
588 |
+
z_rand = torch.randn((batch_size * num_videos_per_prompt, num_channels_latents, video_length, height // self.vae_scale_factor, width // self.vae_scale_factor), device=device)
|
589 |
+
# 3. Roise Reinitialization
|
590 |
+
latents = freq_mix_3d(z_T.to(dtype=torch.float32), z_rand, LPF=self.freq_filter)
|
591 |
+
latents = latents.to(latents_dtype)
|
592 |
+
|
593 |
+
# Coarse-to-Fine Sampling for Fast Inference (can lead to sub-optimal results)
|
594 |
+
if use_fast_sampling:
|
595 |
+
current_num_inference_steps= int(num_inference_steps / num_iters * (iter + 1))
|
596 |
+
self.scheduler.set_timesteps(current_num_inference_steps, device=device)
|
597 |
+
timesteps = self.scheduler.timesteps
|
598 |
+
# --------------------------------------------------------------------------
|
599 |
+
|
600 |
+
# Denoising loop
|
601 |
+
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
602 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
603 |
+
# if use_fast_sampling:
|
604 |
+
# # Coarse-to-Fine Sampling for Fast Inference
|
605 |
+
# current_num_inference_steps= int(num_inference_steps / num_iters * (iter + 1))
|
606 |
+
# current_timesteps = timesteps[:current_num_inference_steps]
|
607 |
+
# else:
|
608 |
+
current_timesteps = timesteps
|
609 |
+
for i, t in enumerate(current_timesteps):
|
610 |
+
# expand the latents if we are doing classifier free guidance
|
611 |
+
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
612 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
613 |
+
|
614 |
+
# predict the noise residual
|
615 |
+
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample.to(dtype=latents_dtype)
|
616 |
+
|
617 |
+
# perform guidance
|
618 |
+
if do_classifier_free_guidance:
|
619 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
620 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
621 |
+
|
622 |
+
# compute the previous noisy sample x_t -> x_t-1
|
623 |
+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
|
624 |
+
|
625 |
+
# call the callback, if provided
|
626 |
+
if i == len(current_timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
627 |
+
progress_bar.update()
|
628 |
+
if callback is not None and i % callback_steps == 0:
|
629 |
+
callback(i, t, latents)
|
630 |
+
|
631 |
+
# save intermediate results
|
632 |
+
if save_intermediate:
|
633 |
+
# Post-processing
|
634 |
+
video = self.decode_latents(latents)
|
635 |
+
video = torch.from_numpy(video)
|
636 |
+
os.makedirs(save_dir, exist_ok=True)
|
637 |
+
save_videos_grid(video, f"{save_dir}/{save_name}_iter{iter}.gif")
|
638 |
+
|
639 |
+
if return_orig and iter==0:
|
640 |
+
orig_video = self.decode_latents(latents)
|
641 |
+
orig_video = torch.from_numpy(orig_video)
|
642 |
+
|
643 |
+
# Post-processing
|
644 |
+
video = self.decode_latents(latents)
|
645 |
+
|
646 |
+
# Convert to tensor
|
647 |
+
if output_type == "tensor":
|
648 |
+
video = torch.from_numpy(video)
|
649 |
+
|
650 |
+
if not return_dict:
|
651 |
+
return video
|
652 |
+
|
653 |
+
if return_orig:
|
654 |
+
return AnimationFreeInitPipelineOutput(videos=video, orig_videos=orig_video)
|
655 |
+
|
656 |
+
return AnimationFreeInitPipelineOutput(videos=video)
|
animatediff/utils/convert_from_ckpt.py
ADDED
@@ -0,0 +1,959 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2023 The HuggingFace Inc. team.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
""" Conversion script for the Stable Diffusion checkpoints."""
|
16 |
+
|
17 |
+
import re
|
18 |
+
from io import BytesIO
|
19 |
+
from typing import Optional
|
20 |
+
|
21 |
+
import requests
|
22 |
+
import torch
|
23 |
+
from transformers import (
|
24 |
+
AutoFeatureExtractor,
|
25 |
+
BertTokenizerFast,
|
26 |
+
CLIPImageProcessor,
|
27 |
+
CLIPTextModel,
|
28 |
+
CLIPTextModelWithProjection,
|
29 |
+
CLIPTokenizer,
|
30 |
+
CLIPVisionConfig,
|
31 |
+
CLIPVisionModelWithProjection,
|
32 |
+
)
|
33 |
+
|
34 |
+
from diffusers.models import (
|
35 |
+
AutoencoderKL,
|
36 |
+
PriorTransformer,
|
37 |
+
UNet2DConditionModel,
|
38 |
+
)
|
39 |
+
from diffusers.schedulers import (
|
40 |
+
DDIMScheduler,
|
41 |
+
DDPMScheduler,
|
42 |
+
DPMSolverMultistepScheduler,
|
43 |
+
EulerAncestralDiscreteScheduler,
|
44 |
+
EulerDiscreteScheduler,
|
45 |
+
HeunDiscreteScheduler,
|
46 |
+
LMSDiscreteScheduler,
|
47 |
+
PNDMScheduler,
|
48 |
+
UnCLIPScheduler,
|
49 |
+
)
|
50 |
+
from diffusers.utils.import_utils import BACKENDS_MAPPING
|
51 |
+
|
52 |
+
|
53 |
+
def shave_segments(path, n_shave_prefix_segments=1):
|
54 |
+
"""
|
55 |
+
Removes segments. Positive values shave the first segments, negative shave the last segments.
|
56 |
+
"""
|
57 |
+
if n_shave_prefix_segments >= 0:
|
58 |
+
return ".".join(path.split(".")[n_shave_prefix_segments:])
|
59 |
+
else:
|
60 |
+
return ".".join(path.split(".")[:n_shave_prefix_segments])
|
61 |
+
|
62 |
+
|
63 |
+
def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
|
64 |
+
"""
|
65 |
+
Updates paths inside resnets to the new naming scheme (local renaming)
|
66 |
+
"""
|
67 |
+
mapping = []
|
68 |
+
for old_item in old_list:
|
69 |
+
new_item = old_item.replace("in_layers.0", "norm1")
|
70 |
+
new_item = new_item.replace("in_layers.2", "conv1")
|
71 |
+
|
72 |
+
new_item = new_item.replace("out_layers.0", "norm2")
|
73 |
+
new_item = new_item.replace("out_layers.3", "conv2")
|
74 |
+
|
75 |
+
new_item = new_item.replace("emb_layers.1", "time_emb_proj")
|
76 |
+
new_item = new_item.replace("skip_connection", "conv_shortcut")
|
77 |
+
|
78 |
+
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
79 |
+
|
80 |
+
mapping.append({"old": old_item, "new": new_item})
|
81 |
+
|
82 |
+
return mapping
|
83 |
+
|
84 |
+
|
85 |
+
def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
|
86 |
+
"""
|
87 |
+
Updates paths inside resnets to the new naming scheme (local renaming)
|
88 |
+
"""
|
89 |
+
mapping = []
|
90 |
+
for old_item in old_list:
|
91 |
+
new_item = old_item
|
92 |
+
|
93 |
+
new_item = new_item.replace("nin_shortcut", "conv_shortcut")
|
94 |
+
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
95 |
+
|
96 |
+
mapping.append({"old": old_item, "new": new_item})
|
97 |
+
|
98 |
+
return mapping
|
99 |
+
|
100 |
+
|
101 |
+
def renew_attention_paths(old_list, n_shave_prefix_segments=0):
|
102 |
+
"""
|
103 |
+
Updates paths inside attentions to the new naming scheme (local renaming)
|
104 |
+
"""
|
105 |
+
mapping = []
|
106 |
+
for old_item in old_list:
|
107 |
+
new_item = old_item
|
108 |
+
|
109 |
+
# new_item = new_item.replace('norm.weight', 'group_norm.weight')
|
110 |
+
# new_item = new_item.replace('norm.bias', 'group_norm.bias')
|
111 |
+
|
112 |
+
# new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
|
113 |
+
# new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
|
114 |
+
|
115 |
+
# new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
116 |
+
|
117 |
+
mapping.append({"old": old_item, "new": new_item})
|
118 |
+
|
119 |
+
return mapping
|
120 |
+
|
121 |
+
|
122 |
+
def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
|
123 |
+
"""
|
124 |
+
Updates paths inside attentions to the new naming scheme (local renaming)
|
125 |
+
"""
|
126 |
+
mapping = []
|
127 |
+
for old_item in old_list:
|
128 |
+
new_item = old_item
|
129 |
+
|
130 |
+
new_item = new_item.replace("norm.weight", "group_norm.weight")
|
131 |
+
new_item = new_item.replace("norm.bias", "group_norm.bias")
|
132 |
+
|
133 |
+
new_item = new_item.replace("q.weight", "query.weight")
|
134 |
+
new_item = new_item.replace("q.bias", "query.bias")
|
135 |
+
|
136 |
+
new_item = new_item.replace("k.weight", "key.weight")
|
137 |
+
new_item = new_item.replace("k.bias", "key.bias")
|
138 |
+
|
139 |
+
new_item = new_item.replace("v.weight", "value.weight")
|
140 |
+
new_item = new_item.replace("v.bias", "value.bias")
|
141 |
+
|
142 |
+
new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
|
143 |
+
new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
|
144 |
+
|
145 |
+
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
146 |
+
|
147 |
+
mapping.append({"old": old_item, "new": new_item})
|
148 |
+
|
149 |
+
return mapping
|
150 |
+
|
151 |
+
|
152 |
+
def assign_to_checkpoint(
|
153 |
+
paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
|
154 |
+
):
|
155 |
+
"""
|
156 |
+
This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits
|
157 |
+
attention layers, and takes into account additional replacements that may arise.
|
158 |
+
|
159 |
+
Assigns the weights to the new checkpoint.
|
160 |
+
"""
|
161 |
+
assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
|
162 |
+
|
163 |
+
# Splits the attention layers into three variables.
|
164 |
+
if attention_paths_to_split is not None:
|
165 |
+
for path, path_map in attention_paths_to_split.items():
|
166 |
+
old_tensor = old_checkpoint[path]
|
167 |
+
channels = old_tensor.shape[0] // 3
|
168 |
+
|
169 |
+
target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
|
170 |
+
|
171 |
+
num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
|
172 |
+
|
173 |
+
old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
|
174 |
+
query, key, value = old_tensor.split(channels // num_heads, dim=1)
|
175 |
+
|
176 |
+
checkpoint[path_map["query"]] = query.reshape(target_shape)
|
177 |
+
checkpoint[path_map["key"]] = key.reshape(target_shape)
|
178 |
+
checkpoint[path_map["value"]] = value.reshape(target_shape)
|
179 |
+
|
180 |
+
for path in paths:
|
181 |
+
new_path = path["new"]
|
182 |
+
|
183 |
+
# These have already been assigned
|
184 |
+
if attention_paths_to_split is not None and new_path in attention_paths_to_split:
|
185 |
+
continue
|
186 |
+
|
187 |
+
# Global renaming happens here
|
188 |
+
new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
|
189 |
+
new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
|
190 |
+
new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")
|
191 |
+
|
192 |
+
if additional_replacements is not None:
|
193 |
+
for replacement in additional_replacements:
|
194 |
+
new_path = new_path.replace(replacement["old"], replacement["new"])
|
195 |
+
|
196 |
+
# proj_attn.weight has to be converted from conv 1D to linear
|
197 |
+
if "proj_attn.weight" in new_path:
|
198 |
+
checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
|
199 |
+
else:
|
200 |
+
checkpoint[new_path] = old_checkpoint[path["old"]]
|
201 |
+
|
202 |
+
|
203 |
+
def conv_attn_to_linear(checkpoint):
|
204 |
+
keys = list(checkpoint.keys())
|
205 |
+
attn_keys = ["query.weight", "key.weight", "value.weight"]
|
206 |
+
for key in keys:
|
207 |
+
if ".".join(key.split(".")[-2:]) in attn_keys:
|
208 |
+
if checkpoint[key].ndim > 2:
|
209 |
+
checkpoint[key] = checkpoint[key][:, :, 0, 0]
|
210 |
+
elif "proj_attn.weight" in key:
|
211 |
+
if checkpoint[key].ndim > 2:
|
212 |
+
checkpoint[key] = checkpoint[key][:, :, 0]
|
213 |
+
|
214 |
+
|
215 |
+
def create_unet_diffusers_config(original_config, image_size: int, controlnet=False):
|
216 |
+
"""
|
217 |
+
Creates a config for the diffusers based on the config of the LDM model.
|
218 |
+
"""
|
219 |
+
if controlnet:
|
220 |
+
unet_params = original_config.model.params.control_stage_config.params
|
221 |
+
else:
|
222 |
+
unet_params = original_config.model.params.unet_config.params
|
223 |
+
|
224 |
+
vae_params = original_config.model.params.first_stage_config.params.ddconfig
|
225 |
+
|
226 |
+
block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult]
|
227 |
+
|
228 |
+
down_block_types = []
|
229 |
+
resolution = 1
|
230 |
+
for i in range(len(block_out_channels)):
|
231 |
+
block_type = "CrossAttnDownBlock2D" if resolution in unet_params.attention_resolutions else "DownBlock2D"
|
232 |
+
down_block_types.append(block_type)
|
233 |
+
if i != len(block_out_channels) - 1:
|
234 |
+
resolution *= 2
|
235 |
+
|
236 |
+
up_block_types = []
|
237 |
+
for i in range(len(block_out_channels)):
|
238 |
+
block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D"
|
239 |
+
up_block_types.append(block_type)
|
240 |
+
resolution //= 2
|
241 |
+
|
242 |
+
vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1)
|
243 |
+
|
244 |
+
head_dim = unet_params.num_heads if "num_heads" in unet_params else None
|
245 |
+
use_linear_projection = (
|
246 |
+
unet_params.use_linear_in_transformer if "use_linear_in_transformer" in unet_params else False
|
247 |
+
)
|
248 |
+
if use_linear_projection:
|
249 |
+
# stable diffusion 2-base-512 and 2-768
|
250 |
+
if head_dim is None:
|
251 |
+
head_dim = [5, 10, 20, 20]
|
252 |
+
|
253 |
+
class_embed_type = None
|
254 |
+
projection_class_embeddings_input_dim = None
|
255 |
+
|
256 |
+
if "num_classes" in unet_params:
|
257 |
+
if unet_params.num_classes == "sequential":
|
258 |
+
class_embed_type = "projection"
|
259 |
+
assert "adm_in_channels" in unet_params
|
260 |
+
projection_class_embeddings_input_dim = unet_params.adm_in_channels
|
261 |
+
else:
|
262 |
+
raise NotImplementedError(f"Unknown conditional unet num_classes config: {unet_params.num_classes}")
|
263 |
+
|
264 |
+
config = {
|
265 |
+
"sample_size": image_size // vae_scale_factor,
|
266 |
+
"in_channels": unet_params.in_channels,
|
267 |
+
"down_block_types": tuple(down_block_types),
|
268 |
+
"block_out_channels": tuple(block_out_channels),
|
269 |
+
"layers_per_block": unet_params.num_res_blocks,
|
270 |
+
"cross_attention_dim": unet_params.context_dim,
|
271 |
+
"attention_head_dim": head_dim,
|
272 |
+
"use_linear_projection": use_linear_projection,
|
273 |
+
"class_embed_type": class_embed_type,
|
274 |
+
"projection_class_embeddings_input_dim": projection_class_embeddings_input_dim,
|
275 |
+
}
|
276 |
+
|
277 |
+
if not controlnet:
|
278 |
+
config["out_channels"] = unet_params.out_channels
|
279 |
+
config["up_block_types"] = tuple(up_block_types)
|
280 |
+
|
281 |
+
return config
|
282 |
+
|
283 |
+
|
284 |
+
def create_vae_diffusers_config(original_config, image_size: int):
|
285 |
+
"""
|
286 |
+
Creates a config for the diffusers based on the config of the LDM model.
|
287 |
+
"""
|
288 |
+
vae_params = original_config.model.params.first_stage_config.params.ddconfig
|
289 |
+
_ = original_config.model.params.first_stage_config.params.embed_dim
|
290 |
+
|
291 |
+
block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult]
|
292 |
+
down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
|
293 |
+
up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
|
294 |
+
|
295 |
+
config = {
|
296 |
+
"sample_size": image_size,
|
297 |
+
"in_channels": vae_params.in_channels,
|
298 |
+
"out_channels": vae_params.out_ch,
|
299 |
+
"down_block_types": tuple(down_block_types),
|
300 |
+
"up_block_types": tuple(up_block_types),
|
301 |
+
"block_out_channels": tuple(block_out_channels),
|
302 |
+
"latent_channels": vae_params.z_channels,
|
303 |
+
"layers_per_block": vae_params.num_res_blocks,
|
304 |
+
}
|
305 |
+
return config
|
306 |
+
|
307 |
+
|
308 |
+
def create_diffusers_schedular(original_config):
|
309 |
+
schedular = DDIMScheduler(
|
310 |
+
num_train_timesteps=original_config.model.params.timesteps,
|
311 |
+
beta_start=original_config.model.params.linear_start,
|
312 |
+
beta_end=original_config.model.params.linear_end,
|
313 |
+
beta_schedule="scaled_linear",
|
314 |
+
)
|
315 |
+
return schedular
|
316 |
+
|
317 |
+
|
318 |
+
def create_ldm_bert_config(original_config):
|
319 |
+
bert_params = original_config.model.parms.cond_stage_config.params
|
320 |
+
config = LDMBertConfig(
|
321 |
+
d_model=bert_params.n_embed,
|
322 |
+
encoder_layers=bert_params.n_layer,
|
323 |
+
encoder_ffn_dim=bert_params.n_embed * 4,
|
324 |
+
)
|
325 |
+
return config
|
326 |
+
|
327 |
+
|
328 |
+
def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False, controlnet=False):
|
329 |
+
"""
|
330 |
+
Takes a state dict and a config, and returns a converted checkpoint.
|
331 |
+
"""
|
332 |
+
|
333 |
+
# extract state_dict for UNet
|
334 |
+
unet_state_dict = {}
|
335 |
+
keys = list(checkpoint.keys())
|
336 |
+
|
337 |
+
if controlnet:
|
338 |
+
unet_key = "control_model."
|
339 |
+
else:
|
340 |
+
unet_key = "model.diffusion_model."
|
341 |
+
|
342 |
+
# at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
|
343 |
+
if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema:
|
344 |
+
print(f"Checkpoint {path} has both EMA and non-EMA weights.")
|
345 |
+
print(
|
346 |
+
"In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
|
347 |
+
" weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag."
|
348 |
+
)
|
349 |
+
for key in keys:
|
350 |
+
if key.startswith("model.diffusion_model"):
|
351 |
+
flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
|
352 |
+
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
|
353 |
+
else:
|
354 |
+
if sum(k.startswith("model_ema") for k in keys) > 100:
|
355 |
+
print(
|
356 |
+
"In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
|
357 |
+
" weights (usually better for inference), please make sure to add the `--extract_ema` flag."
|
358 |
+
)
|
359 |
+
|
360 |
+
for key in keys:
|
361 |
+
if key.startswith(unet_key):
|
362 |
+
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
|
363 |
+
|
364 |
+
new_checkpoint = {}
|
365 |
+
|
366 |
+
new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
|
367 |
+
new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
|
368 |
+
new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
|
369 |
+
new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
|
370 |
+
|
371 |
+
if config["class_embed_type"] is None:
|
372 |
+
# No parameters to port
|
373 |
+
...
|
374 |
+
elif config["class_embed_type"] == "timestep" or config["class_embed_type"] == "projection":
|
375 |
+
new_checkpoint["class_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"]
|
376 |
+
new_checkpoint["class_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"]
|
377 |
+
new_checkpoint["class_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"]
|
378 |
+
new_checkpoint["class_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"]
|
379 |
+
else:
|
380 |
+
raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}")
|
381 |
+
|
382 |
+
new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
|
383 |
+
new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
|
384 |
+
|
385 |
+
if not controlnet:
|
386 |
+
new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
|
387 |
+
new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
|
388 |
+
new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
|
389 |
+
new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
|
390 |
+
|
391 |
+
# Retrieves the keys for the input blocks only
|
392 |
+
num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
|
393 |
+
input_blocks = {
|
394 |
+
layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
|
395 |
+
for layer_id in range(num_input_blocks)
|
396 |
+
}
|
397 |
+
|
398 |
+
# Retrieves the keys for the middle blocks only
|
399 |
+
num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
|
400 |
+
middle_blocks = {
|
401 |
+
layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
|
402 |
+
for layer_id in range(num_middle_blocks)
|
403 |
+
}
|
404 |
+
|
405 |
+
# Retrieves the keys for the output blocks only
|
406 |
+
num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
|
407 |
+
output_blocks = {
|
408 |
+
layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
|
409 |
+
for layer_id in range(num_output_blocks)
|
410 |
+
}
|
411 |
+
|
412 |
+
for i in range(1, num_input_blocks):
|
413 |
+
block_id = (i - 1) // (config["layers_per_block"] + 1)
|
414 |
+
layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
|
415 |
+
|
416 |
+
resnets = [
|
417 |
+
key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
|
418 |
+
]
|
419 |
+
attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
|
420 |
+
|
421 |
+
if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
|
422 |
+
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
|
423 |
+
f"input_blocks.{i}.0.op.weight"
|
424 |
+
)
|
425 |
+
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
|
426 |
+
f"input_blocks.{i}.0.op.bias"
|
427 |
+
)
|
428 |
+
|
429 |
+
paths = renew_resnet_paths(resnets)
|
430 |
+
meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
|
431 |
+
assign_to_checkpoint(
|
432 |
+
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
433 |
+
)
|
434 |
+
|
435 |
+
if len(attentions):
|
436 |
+
paths = renew_attention_paths(attentions)
|
437 |
+
meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
|
438 |
+
assign_to_checkpoint(
|
439 |
+
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
440 |
+
)
|
441 |
+
|
442 |
+
resnet_0 = middle_blocks[0]
|
443 |
+
attentions = middle_blocks[1]
|
444 |
+
resnet_1 = middle_blocks[2]
|
445 |
+
|
446 |
+
resnet_0_paths = renew_resnet_paths(resnet_0)
|
447 |
+
assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)
|
448 |
+
|
449 |
+
resnet_1_paths = renew_resnet_paths(resnet_1)
|
450 |
+
assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)
|
451 |
+
|
452 |
+
attentions_paths = renew_attention_paths(attentions)
|
453 |
+
meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
|
454 |
+
assign_to_checkpoint(
|
455 |
+
attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
456 |
+
)
|
457 |
+
|
458 |
+
for i in range(num_output_blocks):
|
459 |
+
block_id = i // (config["layers_per_block"] + 1)
|
460 |
+
layer_in_block_id = i % (config["layers_per_block"] + 1)
|
461 |
+
output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
|
462 |
+
output_block_list = {}
|
463 |
+
|
464 |
+
for layer in output_block_layers:
|
465 |
+
layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
|
466 |
+
if layer_id in output_block_list:
|
467 |
+
output_block_list[layer_id].append(layer_name)
|
468 |
+
else:
|
469 |
+
output_block_list[layer_id] = [layer_name]
|
470 |
+
|
471 |
+
if len(output_block_list) > 1:
|
472 |
+
resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
|
473 |
+
attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
|
474 |
+
|
475 |
+
resnet_0_paths = renew_resnet_paths(resnets)
|
476 |
+
paths = renew_resnet_paths(resnets)
|
477 |
+
|
478 |
+
meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
|
479 |
+
assign_to_checkpoint(
|
480 |
+
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
481 |
+
)
|
482 |
+
|
483 |
+
output_block_list = {k: sorted(v) for k, v in output_block_list.items()}
|
484 |
+
if ["conv.bias", "conv.weight"] in output_block_list.values():
|
485 |
+
index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
|
486 |
+
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
|
487 |
+
f"output_blocks.{i}.{index}.conv.weight"
|
488 |
+
]
|
489 |
+
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
|
490 |
+
f"output_blocks.{i}.{index}.conv.bias"
|
491 |
+
]
|
492 |
+
|
493 |
+
# Clear attentions as they have been attributed above.
|
494 |
+
if len(attentions) == 2:
|
495 |
+
attentions = []
|
496 |
+
|
497 |
+
if len(attentions):
|
498 |
+
paths = renew_attention_paths(attentions)
|
499 |
+
meta_path = {
|
500 |
+
"old": f"output_blocks.{i}.1",
|
501 |
+
"new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
|
502 |
+
}
|
503 |
+
assign_to_checkpoint(
|
504 |
+
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
505 |
+
)
|
506 |
+
else:
|
507 |
+
resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
|
508 |
+
for path in resnet_0_paths:
|
509 |
+
old_path = ".".join(["output_blocks", str(i), path["old"]])
|
510 |
+
new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
|
511 |
+
|
512 |
+
new_checkpoint[new_path] = unet_state_dict[old_path]
|
513 |
+
|
514 |
+
if controlnet:
|
515 |
+
# conditioning embedding
|
516 |
+
|
517 |
+
orig_index = 0
|
518 |
+
|
519 |
+
new_checkpoint["controlnet_cond_embedding.conv_in.weight"] = unet_state_dict.pop(
|
520 |
+
f"input_hint_block.{orig_index}.weight"
|
521 |
+
)
|
522 |
+
new_checkpoint["controlnet_cond_embedding.conv_in.bias"] = unet_state_dict.pop(
|
523 |
+
f"input_hint_block.{orig_index}.bias"
|
524 |
+
)
|
525 |
+
|
526 |
+
orig_index += 2
|
527 |
+
|
528 |
+
diffusers_index = 0
|
529 |
+
|
530 |
+
while diffusers_index < 6:
|
531 |
+
new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.weight"] = unet_state_dict.pop(
|
532 |
+
f"input_hint_block.{orig_index}.weight"
|
533 |
+
)
|
534 |
+
new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.bias"] = unet_state_dict.pop(
|
535 |
+
f"input_hint_block.{orig_index}.bias"
|
536 |
+
)
|
537 |
+
diffusers_index += 1
|
538 |
+
orig_index += 2
|
539 |
+
|
540 |
+
new_checkpoint["controlnet_cond_embedding.conv_out.weight"] = unet_state_dict.pop(
|
541 |
+
f"input_hint_block.{orig_index}.weight"
|
542 |
+
)
|
543 |
+
new_checkpoint["controlnet_cond_embedding.conv_out.bias"] = unet_state_dict.pop(
|
544 |
+
f"input_hint_block.{orig_index}.bias"
|
545 |
+
)
|
546 |
+
|
547 |
+
# down blocks
|
548 |
+
for i in range(num_input_blocks):
|
549 |
+
new_checkpoint[f"controlnet_down_blocks.{i}.weight"] = unet_state_dict.pop(f"zero_convs.{i}.0.weight")
|
550 |
+
new_checkpoint[f"controlnet_down_blocks.{i}.bias"] = unet_state_dict.pop(f"zero_convs.{i}.0.bias")
|
551 |
+
|
552 |
+
# mid block
|
553 |
+
new_checkpoint["controlnet_mid_block.weight"] = unet_state_dict.pop("middle_block_out.0.weight")
|
554 |
+
new_checkpoint["controlnet_mid_block.bias"] = unet_state_dict.pop("middle_block_out.0.bias")
|
555 |
+
|
556 |
+
return new_checkpoint
|
557 |
+
|
558 |
+
|
559 |
+
def convert_ldm_vae_checkpoint(checkpoint, config):
|
560 |
+
# extract state dict for VAE
|
561 |
+
vae_state_dict = {}
|
562 |
+
vae_key = "first_stage_model."
|
563 |
+
keys = list(checkpoint.keys())
|
564 |
+
for key in keys:
|
565 |
+
if key.startswith(vae_key):
|
566 |
+
vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
|
567 |
+
|
568 |
+
new_checkpoint = {}
|
569 |
+
|
570 |
+
new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
|
571 |
+
new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
|
572 |
+
new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
|
573 |
+
new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
|
574 |
+
new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
|
575 |
+
new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]
|
576 |
+
|
577 |
+
new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
|
578 |
+
new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
|
579 |
+
new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
|
580 |
+
new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
|
581 |
+
new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
|
582 |
+
new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]
|
583 |
+
|
584 |
+
new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
|
585 |
+
new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
|
586 |
+
new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
|
587 |
+
new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
|
588 |
+
|
589 |
+
# Retrieves the keys for the encoder down blocks only
|
590 |
+
num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
|
591 |
+
down_blocks = {
|
592 |
+
layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
|
593 |
+
}
|
594 |
+
|
595 |
+
# Retrieves the keys for the decoder up blocks only
|
596 |
+
num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
|
597 |
+
up_blocks = {
|
598 |
+
layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
|
599 |
+
}
|
600 |
+
|
601 |
+
for i in range(num_down_blocks):
|
602 |
+
resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
|
603 |
+
|
604 |
+
if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
|
605 |
+
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
|
606 |
+
f"encoder.down.{i}.downsample.conv.weight"
|
607 |
+
)
|
608 |
+
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
|
609 |
+
f"encoder.down.{i}.downsample.conv.bias"
|
610 |
+
)
|
611 |
+
|
612 |
+
paths = renew_vae_resnet_paths(resnets)
|
613 |
+
meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
|
614 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
615 |
+
|
616 |
+
mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
|
617 |
+
num_mid_res_blocks = 2
|
618 |
+
for i in range(1, num_mid_res_blocks + 1):
|
619 |
+
resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
|
620 |
+
|
621 |
+
paths = renew_vae_resnet_paths(resnets)
|
622 |
+
meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
|
623 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
624 |
+
|
625 |
+
mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
|
626 |
+
paths = renew_vae_attention_paths(mid_attentions)
|
627 |
+
meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
|
628 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
629 |
+
conv_attn_to_linear(new_checkpoint)
|
630 |
+
|
631 |
+
for i in range(num_up_blocks):
|
632 |
+
block_id = num_up_blocks - 1 - i
|
633 |
+
resnets = [
|
634 |
+
key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
|
635 |
+
]
|
636 |
+
|
637 |
+
if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
|
638 |
+
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
|
639 |
+
f"decoder.up.{block_id}.upsample.conv.weight"
|
640 |
+
]
|
641 |
+
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
|
642 |
+
f"decoder.up.{block_id}.upsample.conv.bias"
|
643 |
+
]
|
644 |
+
|
645 |
+
paths = renew_vae_resnet_paths(resnets)
|
646 |
+
meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
|
647 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
648 |
+
|
649 |
+
mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
|
650 |
+
num_mid_res_blocks = 2
|
651 |
+
for i in range(1, num_mid_res_blocks + 1):
|
652 |
+
resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
|
653 |
+
|
654 |
+
paths = renew_vae_resnet_paths(resnets)
|
655 |
+
meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
|
656 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
657 |
+
|
658 |
+
mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
|
659 |
+
paths = renew_vae_attention_paths(mid_attentions)
|
660 |
+
meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
|
661 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
662 |
+
conv_attn_to_linear(new_checkpoint)
|
663 |
+
return new_checkpoint
|
664 |
+
|
665 |
+
|
666 |
+
def convert_ldm_bert_checkpoint(checkpoint, config):
|
667 |
+
def _copy_attn_layer(hf_attn_layer, pt_attn_layer):
|
668 |
+
hf_attn_layer.q_proj.weight.data = pt_attn_layer.to_q.weight
|
669 |
+
hf_attn_layer.k_proj.weight.data = pt_attn_layer.to_k.weight
|
670 |
+
hf_attn_layer.v_proj.weight.data = pt_attn_layer.to_v.weight
|
671 |
+
|
672 |
+
hf_attn_layer.out_proj.weight = pt_attn_layer.to_out.weight
|
673 |
+
hf_attn_layer.out_proj.bias = pt_attn_layer.to_out.bias
|
674 |
+
|
675 |
+
def _copy_linear(hf_linear, pt_linear):
|
676 |
+
hf_linear.weight = pt_linear.weight
|
677 |
+
hf_linear.bias = pt_linear.bias
|
678 |
+
|
679 |
+
def _copy_layer(hf_layer, pt_layer):
|
680 |
+
# copy layer norms
|
681 |
+
_copy_linear(hf_layer.self_attn_layer_norm, pt_layer[0][0])
|
682 |
+
_copy_linear(hf_layer.final_layer_norm, pt_layer[1][0])
|
683 |
+
|
684 |
+
# copy attn
|
685 |
+
_copy_attn_layer(hf_layer.self_attn, pt_layer[0][1])
|
686 |
+
|
687 |
+
# copy MLP
|
688 |
+
pt_mlp = pt_layer[1][1]
|
689 |
+
_copy_linear(hf_layer.fc1, pt_mlp.net[0][0])
|
690 |
+
_copy_linear(hf_layer.fc2, pt_mlp.net[2])
|
691 |
+
|
692 |
+
def _copy_layers(hf_layers, pt_layers):
|
693 |
+
for i, hf_layer in enumerate(hf_layers):
|
694 |
+
if i != 0:
|
695 |
+
i += i
|
696 |
+
pt_layer = pt_layers[i : i + 2]
|
697 |
+
_copy_layer(hf_layer, pt_layer)
|
698 |
+
|
699 |
+
hf_model = LDMBertModel(config).eval()
|
700 |
+
|
701 |
+
# copy embeds
|
702 |
+
hf_model.model.embed_tokens.weight = checkpoint.transformer.token_emb.weight
|
703 |
+
hf_model.model.embed_positions.weight.data = checkpoint.transformer.pos_emb.emb.weight
|
704 |
+
|
705 |
+
# copy layer norm
|
706 |
+
_copy_linear(hf_model.model.layer_norm, checkpoint.transformer.norm)
|
707 |
+
|
708 |
+
# copy hidden layers
|
709 |
+
_copy_layers(hf_model.model.layers, checkpoint.transformer.attn_layers.layers)
|
710 |
+
|
711 |
+
_copy_linear(hf_model.to_logits, checkpoint.transformer.to_logits)
|
712 |
+
|
713 |
+
return hf_model
|
714 |
+
|
715 |
+
|
716 |
+
def convert_ldm_clip_checkpoint(checkpoint):
|
717 |
+
text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
|
718 |
+
keys = list(checkpoint.keys())
|
719 |
+
|
720 |
+
text_model_dict = {}
|
721 |
+
|
722 |
+
for key in keys:
|
723 |
+
if key.startswith("cond_stage_model.transformer"):
|
724 |
+
text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]
|
725 |
+
|
726 |
+
text_model.load_state_dict(text_model_dict)
|
727 |
+
|
728 |
+
return text_model
|
729 |
+
|
730 |
+
|
731 |
+
textenc_conversion_lst = [
|
732 |
+
("cond_stage_model.model.positional_embedding", "text_model.embeddings.position_embedding.weight"),
|
733 |
+
("cond_stage_model.model.token_embedding.weight", "text_model.embeddings.token_embedding.weight"),
|
734 |
+
("cond_stage_model.model.ln_final.weight", "text_model.final_layer_norm.weight"),
|
735 |
+
("cond_stage_model.model.ln_final.bias", "text_model.final_layer_norm.bias"),
|
736 |
+
]
|
737 |
+
textenc_conversion_map = {x[0]: x[1] for x in textenc_conversion_lst}
|
738 |
+
|
739 |
+
textenc_transformer_conversion_lst = [
|
740 |
+
# (stable-diffusion, HF Diffusers)
|
741 |
+
("resblocks.", "text_model.encoder.layers."),
|
742 |
+
("ln_1", "layer_norm1"),
|
743 |
+
("ln_2", "layer_norm2"),
|
744 |
+
(".c_fc.", ".fc1."),
|
745 |
+
(".c_proj.", ".fc2."),
|
746 |
+
(".attn", ".self_attn"),
|
747 |
+
("ln_final.", "transformer.text_model.final_layer_norm."),
|
748 |
+
("token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"),
|
749 |
+
("positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"),
|
750 |
+
]
|
751 |
+
protected = {re.escape(x[0]): x[1] for x in textenc_transformer_conversion_lst}
|
752 |
+
textenc_pattern = re.compile("|".join(protected.keys()))
|
753 |
+
|
754 |
+
|
755 |
+
def convert_paint_by_example_checkpoint(checkpoint):
|
756 |
+
config = CLIPVisionConfig.from_pretrained("openai/clip-vit-large-patch14")
|
757 |
+
model = PaintByExampleImageEncoder(config)
|
758 |
+
|
759 |
+
keys = list(checkpoint.keys())
|
760 |
+
|
761 |
+
text_model_dict = {}
|
762 |
+
|
763 |
+
for key in keys:
|
764 |
+
if key.startswith("cond_stage_model.transformer"):
|
765 |
+
text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]
|
766 |
+
|
767 |
+
# load clip vision
|
768 |
+
model.model.load_state_dict(text_model_dict)
|
769 |
+
|
770 |
+
# load mapper
|
771 |
+
keys_mapper = {
|
772 |
+
k[len("cond_stage_model.mapper.res") :]: v
|
773 |
+
for k, v in checkpoint.items()
|
774 |
+
if k.startswith("cond_stage_model.mapper")
|
775 |
+
}
|
776 |
+
|
777 |
+
MAPPING = {
|
778 |
+
"attn.c_qkv": ["attn1.to_q", "attn1.to_k", "attn1.to_v"],
|
779 |
+
"attn.c_proj": ["attn1.to_out.0"],
|
780 |
+
"ln_1": ["norm1"],
|
781 |
+
"ln_2": ["norm3"],
|
782 |
+
"mlp.c_fc": ["ff.net.0.proj"],
|
783 |
+
"mlp.c_proj": ["ff.net.2"],
|
784 |
+
}
|
785 |
+
|
786 |
+
mapped_weights = {}
|
787 |
+
for key, value in keys_mapper.items():
|
788 |
+
prefix = key[: len("blocks.i")]
|
789 |
+
suffix = key.split(prefix)[-1].split(".")[-1]
|
790 |
+
name = key.split(prefix)[-1].split(suffix)[0][1:-1]
|
791 |
+
mapped_names = MAPPING[name]
|
792 |
+
|
793 |
+
num_splits = len(mapped_names)
|
794 |
+
for i, mapped_name in enumerate(mapped_names):
|
795 |
+
new_name = ".".join([prefix, mapped_name, suffix])
|
796 |
+
shape = value.shape[0] // num_splits
|
797 |
+
mapped_weights[new_name] = value[i * shape : (i + 1) * shape]
|
798 |
+
|
799 |
+
model.mapper.load_state_dict(mapped_weights)
|
800 |
+
|
801 |
+
# load final layer norm
|
802 |
+
model.final_layer_norm.load_state_dict(
|
803 |
+
{
|
804 |
+
"bias": checkpoint["cond_stage_model.final_ln.bias"],
|
805 |
+
"weight": checkpoint["cond_stage_model.final_ln.weight"],
|
806 |
+
}
|
807 |
+
)
|
808 |
+
|
809 |
+
# load final proj
|
810 |
+
model.proj_out.load_state_dict(
|
811 |
+
{
|
812 |
+
"bias": checkpoint["proj_out.bias"],
|
813 |
+
"weight": checkpoint["proj_out.weight"],
|
814 |
+
}
|
815 |
+
)
|
816 |
+
|
817 |
+
# load uncond vector
|
818 |
+
model.uncond_vector.data = torch.nn.Parameter(checkpoint["learnable_vector"])
|
819 |
+
return model
|
820 |
+
|
821 |
+
|
822 |
+
def convert_open_clip_checkpoint(checkpoint):
|
823 |
+
text_model = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2", subfolder="text_encoder")
|
824 |
+
|
825 |
+
keys = list(checkpoint.keys())
|
826 |
+
|
827 |
+
text_model_dict = {}
|
828 |
+
|
829 |
+
if "cond_stage_model.model.text_projection" in checkpoint:
|
830 |
+
d_model = int(checkpoint["cond_stage_model.model.text_projection"].shape[0])
|
831 |
+
else:
|
832 |
+
d_model = 1024
|
833 |
+
|
834 |
+
text_model_dict["text_model.embeddings.position_ids"] = text_model.text_model.embeddings.get_buffer("position_ids")
|
835 |
+
|
836 |
+
for key in keys:
|
837 |
+
if "resblocks.23" in key: # Diffusers drops the final layer and only uses the penultimate layer
|
838 |
+
continue
|
839 |
+
if key in textenc_conversion_map:
|
840 |
+
text_model_dict[textenc_conversion_map[key]] = checkpoint[key]
|
841 |
+
if key.startswith("cond_stage_model.model.transformer."):
|
842 |
+
new_key = key[len("cond_stage_model.model.transformer.") :]
|
843 |
+
if new_key.endswith(".in_proj_weight"):
|
844 |
+
new_key = new_key[: -len(".in_proj_weight")]
|
845 |
+
new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key)
|
846 |
+
text_model_dict[new_key + ".q_proj.weight"] = checkpoint[key][:d_model, :]
|
847 |
+
text_model_dict[new_key + ".k_proj.weight"] = checkpoint[key][d_model : d_model * 2, :]
|
848 |
+
text_model_dict[new_key + ".v_proj.weight"] = checkpoint[key][d_model * 2 :, :]
|
849 |
+
elif new_key.endswith(".in_proj_bias"):
|
850 |
+
new_key = new_key[: -len(".in_proj_bias")]
|
851 |
+
new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key)
|
852 |
+
text_model_dict[new_key + ".q_proj.bias"] = checkpoint[key][:d_model]
|
853 |
+
text_model_dict[new_key + ".k_proj.bias"] = checkpoint[key][d_model : d_model * 2]
|
854 |
+
text_model_dict[new_key + ".v_proj.bias"] = checkpoint[key][d_model * 2 :]
|
855 |
+
else:
|
856 |
+
new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key)
|
857 |
+
|
858 |
+
text_model_dict[new_key] = checkpoint[key]
|
859 |
+
|
860 |
+
text_model.load_state_dict(text_model_dict)
|
861 |
+
|
862 |
+
return text_model
|
863 |
+
|
864 |
+
|
865 |
+
def stable_unclip_image_encoder(original_config):
|
866 |
+
"""
|
867 |
+
Returns the image processor and clip image encoder for the img2img unclip pipeline.
|
868 |
+
|
869 |
+
We currently know of two types of stable unclip models which separately use the clip and the openclip image
|
870 |
+
encoders.
|
871 |
+
"""
|
872 |
+
|
873 |
+
image_embedder_config = original_config.model.params.embedder_config
|
874 |
+
|
875 |
+
sd_clip_image_embedder_class = image_embedder_config.target
|
876 |
+
sd_clip_image_embedder_class = sd_clip_image_embedder_class.split(".")[-1]
|
877 |
+
|
878 |
+
if sd_clip_image_embedder_class == "ClipImageEmbedder":
|
879 |
+
clip_model_name = image_embedder_config.params.model
|
880 |
+
|
881 |
+
if clip_model_name == "ViT-L/14":
|
882 |
+
feature_extractor = CLIPImageProcessor()
|
883 |
+
image_encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")
|
884 |
+
else:
|
885 |
+
raise NotImplementedError(f"Unknown CLIP checkpoint name in stable diffusion checkpoint {clip_model_name}")
|
886 |
+
|
887 |
+
elif sd_clip_image_embedder_class == "FrozenOpenCLIPImageEmbedder":
|
888 |
+
feature_extractor = CLIPImageProcessor()
|
889 |
+
image_encoder = CLIPVisionModelWithProjection.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
|
890 |
+
else:
|
891 |
+
raise NotImplementedError(
|
892 |
+
f"Unknown CLIP image embedder class in stable diffusion checkpoint {sd_clip_image_embedder_class}"
|
893 |
+
)
|
894 |
+
|
895 |
+
return feature_extractor, image_encoder
|
896 |
+
|
897 |
+
|
898 |
+
def stable_unclip_image_noising_components(
|
899 |
+
original_config, clip_stats_path: Optional[str] = None, device: Optional[str] = None
|
900 |
+
):
|
901 |
+
"""
|
902 |
+
Returns the noising components for the img2img and txt2img unclip pipelines.
|
903 |
+
|
904 |
+
Converts the stability noise augmentor into
|
905 |
+
1. a `StableUnCLIPImageNormalizer` for holding the CLIP stats
|
906 |
+
2. a `DDPMScheduler` for holding the noise schedule
|
907 |
+
|
908 |
+
If the noise augmentor config specifies a clip stats path, the `clip_stats_path` must be provided.
|
909 |
+
"""
|
910 |
+
noise_aug_config = original_config.model.params.noise_aug_config
|
911 |
+
noise_aug_class = noise_aug_config.target
|
912 |
+
noise_aug_class = noise_aug_class.split(".")[-1]
|
913 |
+
|
914 |
+
if noise_aug_class == "CLIPEmbeddingNoiseAugmentation":
|
915 |
+
noise_aug_config = noise_aug_config.params
|
916 |
+
embedding_dim = noise_aug_config.timestep_dim
|
917 |
+
max_noise_level = noise_aug_config.noise_schedule_config.timesteps
|
918 |
+
beta_schedule = noise_aug_config.noise_schedule_config.beta_schedule
|
919 |
+
|
920 |
+
image_normalizer = StableUnCLIPImageNormalizer(embedding_dim=embedding_dim)
|
921 |
+
image_noising_scheduler = DDPMScheduler(num_train_timesteps=max_noise_level, beta_schedule=beta_schedule)
|
922 |
+
|
923 |
+
if "clip_stats_path" in noise_aug_config:
|
924 |
+
if clip_stats_path is None:
|
925 |
+
raise ValueError("This stable unclip config requires a `clip_stats_path`")
|
926 |
+
|
927 |
+
clip_mean, clip_std = torch.load(clip_stats_path, map_location=device)
|
928 |
+
clip_mean = clip_mean[None, :]
|
929 |
+
clip_std = clip_std[None, :]
|
930 |
+
|
931 |
+
clip_stats_state_dict = {
|
932 |
+
"mean": clip_mean,
|
933 |
+
"std": clip_std,
|
934 |
+
}
|
935 |
+
|
936 |
+
image_normalizer.load_state_dict(clip_stats_state_dict)
|
937 |
+
else:
|
938 |
+
raise NotImplementedError(f"Unknown noise augmentor class: {noise_aug_class}")
|
939 |
+
|
940 |
+
return image_normalizer, image_noising_scheduler
|
941 |
+
|
942 |
+
|
943 |
+
def convert_controlnet_checkpoint(
|
944 |
+
checkpoint, original_config, checkpoint_path, image_size, upcast_attention, extract_ema
|
945 |
+
):
|
946 |
+
ctrlnet_config = create_unet_diffusers_config(original_config, image_size=image_size, controlnet=True)
|
947 |
+
ctrlnet_config["upcast_attention"] = upcast_attention
|
948 |
+
|
949 |
+
ctrlnet_config.pop("sample_size")
|
950 |
+
|
951 |
+
controlnet_model = ControlNetModel(**ctrlnet_config)
|
952 |
+
|
953 |
+
converted_ctrl_checkpoint = convert_ldm_unet_checkpoint(
|
954 |
+
checkpoint, ctrlnet_config, path=checkpoint_path, extract_ema=extract_ema, controlnet=True
|
955 |
+
)
|
956 |
+
|
957 |
+
controlnet_model.load_state_dict(converted_ctrl_checkpoint)
|
958 |
+
|
959 |
+
return controlnet_model
|
animatediff/utils/convert_lora_safetensor_to_diffusers.py
ADDED
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2023, Haofan Wang, Qixun Wang, All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
""" Conversion script for the LoRA's safetensors checkpoints. """
|
17 |
+
|
18 |
+
import argparse
|
19 |
+
|
20 |
+
import torch
|
21 |
+
from safetensors.torch import load_file
|
22 |
+
|
23 |
+
from diffusers import StableDiffusionPipeline
|
24 |
+
import pdb
|
25 |
+
|
26 |
+
|
27 |
+
|
28 |
+
def convert_motion_lora_ckpt_to_diffusers(pipeline, state_dict, alpha=1.0):
|
29 |
+
# directly update weight in diffusers model
|
30 |
+
for key in state_dict:
|
31 |
+
# only process lora down key
|
32 |
+
if "up." in key: continue
|
33 |
+
|
34 |
+
up_key = key.replace(".down.", ".up.")
|
35 |
+
model_key = key.replace("processor.", "").replace("_lora", "").replace("down.", "").replace("up.", "")
|
36 |
+
model_key = model_key.replace("to_out.", "to_out.0.")
|
37 |
+
layer_infos = model_key.split(".")[:-1]
|
38 |
+
|
39 |
+
curr_layer = pipeline.unet
|
40 |
+
while len(layer_infos) > 0:
|
41 |
+
temp_name = layer_infos.pop(0)
|
42 |
+
curr_layer = curr_layer.__getattr__(temp_name)
|
43 |
+
|
44 |
+
weight_down = state_dict[key]
|
45 |
+
weight_up = state_dict[up_key]
|
46 |
+
curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).to(curr_layer.weight.data.device)
|
47 |
+
|
48 |
+
return pipeline
|
49 |
+
|
50 |
+
|
51 |
+
|
52 |
+
def convert_lora(pipeline, state_dict, LORA_PREFIX_UNET="lora_unet", LORA_PREFIX_TEXT_ENCODER="lora_te", alpha=0.6):
|
53 |
+
# load base model
|
54 |
+
# pipeline = StableDiffusionPipeline.from_pretrained(base_model_path, torch_dtype=torch.float32)
|
55 |
+
|
56 |
+
# load LoRA weight from .safetensors
|
57 |
+
# state_dict = load_file(checkpoint_path)
|
58 |
+
|
59 |
+
visited = []
|
60 |
+
|
61 |
+
# directly update weight in diffusers model
|
62 |
+
for key in state_dict:
|
63 |
+
# it is suggested to print out the key, it usually will be something like below
|
64 |
+
# "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight"
|
65 |
+
|
66 |
+
# as we have set the alpha beforehand, so just skip
|
67 |
+
if ".alpha" in key or key in visited:
|
68 |
+
continue
|
69 |
+
|
70 |
+
if "text" in key:
|
71 |
+
layer_infos = key.split(".")[0].split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
|
72 |
+
curr_layer = pipeline.text_encoder
|
73 |
+
else:
|
74 |
+
layer_infos = key.split(".")[0].split(LORA_PREFIX_UNET + "_")[-1].split("_")
|
75 |
+
curr_layer = pipeline.unet
|
76 |
+
|
77 |
+
# find the target layer
|
78 |
+
temp_name = layer_infos.pop(0)
|
79 |
+
while len(layer_infos) > -1:
|
80 |
+
try:
|
81 |
+
curr_layer = curr_layer.__getattr__(temp_name)
|
82 |
+
if len(layer_infos) > 0:
|
83 |
+
temp_name = layer_infos.pop(0)
|
84 |
+
elif len(layer_infos) == 0:
|
85 |
+
break
|
86 |
+
except Exception:
|
87 |
+
if len(temp_name) > 0:
|
88 |
+
temp_name += "_" + layer_infos.pop(0)
|
89 |
+
else:
|
90 |
+
temp_name = layer_infos.pop(0)
|
91 |
+
|
92 |
+
pair_keys = []
|
93 |
+
if "lora_down" in key:
|
94 |
+
pair_keys.append(key.replace("lora_down", "lora_up"))
|
95 |
+
pair_keys.append(key)
|
96 |
+
else:
|
97 |
+
pair_keys.append(key)
|
98 |
+
pair_keys.append(key.replace("lora_up", "lora_down"))
|
99 |
+
|
100 |
+
# update weight
|
101 |
+
if len(state_dict[pair_keys[0]].shape) == 4:
|
102 |
+
weight_up = state_dict[pair_keys[0]].squeeze(3).squeeze(2).to(torch.float32)
|
103 |
+
weight_down = state_dict[pair_keys[1]].squeeze(3).squeeze(2).to(torch.float32)
|
104 |
+
curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3).to(curr_layer.weight.data.device)
|
105 |
+
else:
|
106 |
+
weight_up = state_dict[pair_keys[0]].to(torch.float32)
|
107 |
+
weight_down = state_dict[pair_keys[1]].to(torch.float32)
|
108 |
+
curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).to(curr_layer.weight.data.device)
|
109 |
+
|
110 |
+
# update visited list
|
111 |
+
for item in pair_keys:
|
112 |
+
visited.append(item)
|
113 |
+
|
114 |
+
return pipeline
|
115 |
+
|
116 |
+
|
117 |
+
if __name__ == "__main__":
|
118 |
+
parser = argparse.ArgumentParser()
|
119 |
+
|
120 |
+
parser.add_argument(
|
121 |
+
"--base_model_path", default=None, type=str, required=True, help="Path to the base model in diffusers format."
|
122 |
+
)
|
123 |
+
parser.add_argument(
|
124 |
+
"--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert."
|
125 |
+
)
|
126 |
+
parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
|
127 |
+
parser.add_argument(
|
128 |
+
"--lora_prefix_unet", default="lora_unet", type=str, help="The prefix of UNet weight in safetensors"
|
129 |
+
)
|
130 |
+
parser.add_argument(
|
131 |
+
"--lora_prefix_text_encoder",
|
132 |
+
default="lora_te",
|
133 |
+
type=str,
|
134 |
+
help="The prefix of text encoder weight in safetensors",
|
135 |
+
)
|
136 |
+
parser.add_argument("--alpha", default=0.75, type=float, help="The merging ratio in W = W0 + alpha * deltaW")
|
137 |
+
parser.add_argument(
|
138 |
+
"--to_safetensors", action="store_true", help="Whether to store pipeline in safetensors format or not."
|
139 |
+
)
|
140 |
+
parser.add_argument("--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)")
|
141 |
+
|
142 |
+
args = parser.parse_args()
|
143 |
+
|
144 |
+
base_model_path = args.base_model_path
|
145 |
+
checkpoint_path = args.checkpoint_path
|
146 |
+
dump_path = args.dump_path
|
147 |
+
lora_prefix_unet = args.lora_prefix_unet
|
148 |
+
lora_prefix_text_encoder = args.lora_prefix_text_encoder
|
149 |
+
alpha = args.alpha
|
150 |
+
|
151 |
+
pipe = convert(base_model_path, checkpoint_path, lora_prefix_unet, lora_prefix_text_encoder, alpha)
|
152 |
+
|
153 |
+
pipe = pipe.to(args.device)
|
154 |
+
pipe.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors)
|
animatediff/utils/freeinit_utils.py
ADDED
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.fft as fft
|
3 |
+
import math
|
4 |
+
|
5 |
+
|
6 |
+
def freq_mix_3d(x, noise, LPF):
|
7 |
+
"""
|
8 |
+
Noise reinitialization.
|
9 |
+
|
10 |
+
Args:
|
11 |
+
x: diffused latent
|
12 |
+
noise: randomly sampled noise
|
13 |
+
LPF: low pass filter
|
14 |
+
"""
|
15 |
+
# FFT
|
16 |
+
x_freq = fft.fftn(x, dim=(-3, -2, -1))
|
17 |
+
x_freq = fft.fftshift(x_freq, dim=(-3, -2, -1))
|
18 |
+
noise_freq = fft.fftn(noise, dim=(-3, -2, -1))
|
19 |
+
noise_freq = fft.fftshift(noise_freq, dim=(-3, -2, -1))
|
20 |
+
|
21 |
+
# frequency mix
|
22 |
+
HPF = 1 - LPF
|
23 |
+
x_freq_low = x_freq * LPF
|
24 |
+
noise_freq_high = noise_freq * HPF
|
25 |
+
x_freq_mixed = x_freq_low + noise_freq_high # mix in freq domain
|
26 |
+
|
27 |
+
# IFFT
|
28 |
+
x_freq_mixed = fft.ifftshift(x_freq_mixed, dim=(-3, -2, -1))
|
29 |
+
x_mixed = fft.ifftn(x_freq_mixed, dim=(-3, -2, -1)).real
|
30 |
+
|
31 |
+
return x_mixed
|
32 |
+
|
33 |
+
|
34 |
+
def get_freq_filter(shape, device, filter_type, n, d_s, d_t):
|
35 |
+
"""
|
36 |
+
Form the frequency filter for noise reinitialization.
|
37 |
+
|
38 |
+
Args:
|
39 |
+
shape: shape of latent (B, C, T, H, W)
|
40 |
+
filter_type: type of the freq filter
|
41 |
+
n: (only for butterworth) order of the filter, larger n ~ ideal, smaller n ~ gaussian
|
42 |
+
d_s: normalized stop frequency for spatial dimensions (0.0-1.0)
|
43 |
+
d_t: normalized stop frequency for temporal dimension (0.0-1.0)
|
44 |
+
"""
|
45 |
+
if filter_type == "gaussian":
|
46 |
+
return gaussian_low_pass_filter(shape=shape, d_s=d_s, d_t=d_t).to(device)
|
47 |
+
elif filter_type == "ideal":
|
48 |
+
return ideal_low_pass_filter(shape=shape, d_s=d_s, d_t=d_t).to(device)
|
49 |
+
elif filter_type == "box":
|
50 |
+
return box_low_pass_filter(shape=shape, d_s=d_s, d_t=d_t).to(device)
|
51 |
+
elif filter_type == "butterworth":
|
52 |
+
return butterworth_low_pass_filter(shape=shape, n=n, d_s=d_s, d_t=d_t).to(device)
|
53 |
+
else:
|
54 |
+
raise NotImplementedError
|
55 |
+
|
56 |
+
def gaussian_low_pass_filter(shape, d_s=0.25, d_t=0.25):
|
57 |
+
"""
|
58 |
+
Compute the gaussian low pass filter mask.
|
59 |
+
|
60 |
+
Args:
|
61 |
+
shape: shape of the filter (volume)
|
62 |
+
d_s: normalized stop frequency for spatial dimensions (0.0-1.0)
|
63 |
+
d_t: normalized stop frequency for temporal dimension (0.0-1.0)
|
64 |
+
"""
|
65 |
+
T, H, W = shape[-3], shape[-2], shape[-1]
|
66 |
+
mask = torch.zeros(shape)
|
67 |
+
if d_s==0 or d_t==0:
|
68 |
+
return mask
|
69 |
+
for t in range(T):
|
70 |
+
for h in range(H):
|
71 |
+
for w in range(W):
|
72 |
+
d_square = (((d_s/d_t)*(2*t/T-1))**2 + (2*h/H-1)**2 + (2*w/W-1)**2)
|
73 |
+
mask[..., t,h,w] = math.exp(-1/(2*d_s**2) * d_square)
|
74 |
+
return mask
|
75 |
+
|
76 |
+
|
77 |
+
def butterworth_low_pass_filter(shape, n=4, d_s=0.25, d_t=0.25):
|
78 |
+
"""
|
79 |
+
Compute the butterworth low pass filter mask.
|
80 |
+
|
81 |
+
Args:
|
82 |
+
shape: shape of the filter (volume)
|
83 |
+
n: order of the filter, larger n ~ ideal, smaller n ~ gaussian
|
84 |
+
d_s: normalized stop frequency for spatial dimensions (0.0-1.0)
|
85 |
+
d_t: normalized stop frequency for temporal dimension (0.0-1.0)
|
86 |
+
"""
|
87 |
+
T, H, W = shape[-3], shape[-2], shape[-1]
|
88 |
+
mask = torch.zeros(shape)
|
89 |
+
if d_s==0 or d_t==0:
|
90 |
+
return mask
|
91 |
+
for t in range(T):
|
92 |
+
for h in range(H):
|
93 |
+
for w in range(W):
|
94 |
+
d_square = (((d_s/d_t)*(2*t/T-1))**2 + (2*h/H-1)**2 + (2*w/W-1)**2)
|
95 |
+
mask[..., t,h,w] = 1 / (1 + (d_square / d_s**2)**n)
|
96 |
+
return mask
|
97 |
+
|
98 |
+
|
99 |
+
def ideal_low_pass_filter(shape, d_s=0.25, d_t=0.25):
|
100 |
+
"""
|
101 |
+
Compute the ideal low pass filter mask.
|
102 |
+
|
103 |
+
Args:
|
104 |
+
shape: shape of the filter (volume)
|
105 |
+
d_s: normalized stop frequency for spatial dimensions (0.0-1.0)
|
106 |
+
d_t: normalized stop frequency for temporal dimension (0.0-1.0)
|
107 |
+
"""
|
108 |
+
T, H, W = shape[-3], shape[-2], shape[-1]
|
109 |
+
mask = torch.zeros(shape)
|
110 |
+
if d_s==0 or d_t==0:
|
111 |
+
return mask
|
112 |
+
for t in range(T):
|
113 |
+
for h in range(H):
|
114 |
+
for w in range(W):
|
115 |
+
d_square = (((d_s/d_t)*(2*t/T-1))**2 + (2*h/H-1)**2 + (2*w/W-1)**2)
|
116 |
+
mask[..., t,h,w] = 1 if d_square <= d_s*2 else 0
|
117 |
+
return mask
|
118 |
+
|
119 |
+
|
120 |
+
def box_low_pass_filter(shape, d_s=0.25, d_t=0.25):
|
121 |
+
"""
|
122 |
+
Compute the ideal low pass filter mask (approximated version).
|
123 |
+
|
124 |
+
Args:
|
125 |
+
shape: shape of the filter (volume)
|
126 |
+
d_s: normalized stop frequency for spatial dimensions (0.0-1.0)
|
127 |
+
d_t: normalized stop frequency for temporal dimension (0.0-1.0)
|
128 |
+
"""
|
129 |
+
T, H, W = shape[-3], shape[-2], shape[-1]
|
130 |
+
mask = torch.zeros(shape)
|
131 |
+
if d_s==0 or d_t==0:
|
132 |
+
return mask
|
133 |
+
|
134 |
+
threshold_s = round(int(H // 2) * d_s)
|
135 |
+
threshold_t = round(T // 2 * d_t)
|
136 |
+
|
137 |
+
cframe, crow, ccol = T // 2, H // 2, W //2
|
138 |
+
mask[..., cframe - threshold_t:cframe + threshold_t, crow - threshold_s:crow + threshold_s, ccol - threshold_s:ccol + threshold_s] = 1.0
|
139 |
+
|
140 |
+
return mask
|
animatediff/utils/util.py
ADDED
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import imageio
|
3 |
+
import numpy as np
|
4 |
+
from typing import Union
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torchvision
|
8 |
+
import torch.distributed as dist
|
9 |
+
|
10 |
+
from safetensors import safe_open
|
11 |
+
from tqdm import tqdm
|
12 |
+
from einops import rearrange
|
13 |
+
from animatediff.utils.convert_from_ckpt import convert_ldm_unet_checkpoint, convert_ldm_clip_checkpoint, convert_ldm_vae_checkpoint
|
14 |
+
from animatediff.utils.convert_lora_safetensor_to_diffusers import convert_lora, convert_motion_lora_ckpt_to_diffusers
|
15 |
+
|
16 |
+
|
17 |
+
def zero_rank_print(s):
|
18 |
+
if (not dist.is_initialized()) and (dist.is_initialized() and dist.get_rank() == 0): print("### " + s)
|
19 |
+
|
20 |
+
|
21 |
+
def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8):
|
22 |
+
videos = rearrange(videos, "b c t h w -> t b c h w")
|
23 |
+
outputs = []
|
24 |
+
for x in videos:
|
25 |
+
x = torchvision.utils.make_grid(x, nrow=n_rows)
|
26 |
+
x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
|
27 |
+
if rescale:
|
28 |
+
x = (x + 1.0) / 2.0 # -1,1 -> 0,1
|
29 |
+
x = (x * 255).numpy().astype(np.uint8)
|
30 |
+
outputs.append(x)
|
31 |
+
|
32 |
+
os.makedirs(os.path.dirname(path), exist_ok=True)
|
33 |
+
imageio.mimsave(path, outputs, fps=fps)
|
34 |
+
|
35 |
+
|
36 |
+
# DDIM Inversion
|
37 |
+
@torch.no_grad()
|
38 |
+
def init_prompt(prompt, pipeline):
|
39 |
+
uncond_input = pipeline.tokenizer(
|
40 |
+
[""], padding="max_length", max_length=pipeline.tokenizer.model_max_length,
|
41 |
+
return_tensors="pt"
|
42 |
+
)
|
43 |
+
uncond_embeddings = pipeline.text_encoder(uncond_input.input_ids.to(pipeline.device))[0]
|
44 |
+
text_input = pipeline.tokenizer(
|
45 |
+
[prompt],
|
46 |
+
padding="max_length",
|
47 |
+
max_length=pipeline.tokenizer.model_max_length,
|
48 |
+
truncation=True,
|
49 |
+
return_tensors="pt",
|
50 |
+
)
|
51 |
+
text_embeddings = pipeline.text_encoder(text_input.input_ids.to(pipeline.device))[0]
|
52 |
+
context = torch.cat([uncond_embeddings, text_embeddings])
|
53 |
+
|
54 |
+
return context
|
55 |
+
|
56 |
+
|
57 |
+
def next_step(model_output: Union[torch.FloatTensor, np.ndarray], timestep: int,
|
58 |
+
sample: Union[torch.FloatTensor, np.ndarray], ddim_scheduler):
|
59 |
+
timestep, next_timestep = min(
|
60 |
+
timestep - ddim_scheduler.config.num_train_timesteps // ddim_scheduler.num_inference_steps, 999), timestep
|
61 |
+
alpha_prod_t = ddim_scheduler.alphas_cumprod[timestep] if timestep >= 0 else ddim_scheduler.final_alpha_cumprod
|
62 |
+
alpha_prod_t_next = ddim_scheduler.alphas_cumprod[next_timestep]
|
63 |
+
beta_prod_t = 1 - alpha_prod_t
|
64 |
+
next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
|
65 |
+
next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output
|
66 |
+
next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction
|
67 |
+
return next_sample
|
68 |
+
|
69 |
+
|
70 |
+
def get_noise_pred_single(latents, t, context, unet):
|
71 |
+
noise_pred = unet(latents, t, encoder_hidden_states=context)["sample"]
|
72 |
+
return noise_pred
|
73 |
+
|
74 |
+
|
75 |
+
@torch.no_grad()
|
76 |
+
def ddim_loop(pipeline, ddim_scheduler, latent, num_inv_steps, prompt):
|
77 |
+
context = init_prompt(prompt, pipeline)
|
78 |
+
uncond_embeddings, cond_embeddings = context.chunk(2)
|
79 |
+
all_latent = [latent]
|
80 |
+
latent = latent.clone().detach()
|
81 |
+
for i in tqdm(range(num_inv_steps)):
|
82 |
+
t = ddim_scheduler.timesteps[len(ddim_scheduler.timesteps) - i - 1]
|
83 |
+
noise_pred = get_noise_pred_single(latent, t, cond_embeddings, pipeline.unet)
|
84 |
+
latent = next_step(noise_pred, t, latent, ddim_scheduler)
|
85 |
+
all_latent.append(latent)
|
86 |
+
return all_latent
|
87 |
+
|
88 |
+
|
89 |
+
@torch.no_grad()
|
90 |
+
def ddim_inversion(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt=""):
|
91 |
+
ddim_latents = ddim_loop(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt)
|
92 |
+
return ddim_latents
|
93 |
+
|
94 |
+
def load_weights(
|
95 |
+
animation_pipeline,
|
96 |
+
# motion module
|
97 |
+
motion_module_path = "",
|
98 |
+
motion_module_lora_configs = [],
|
99 |
+
# image layers
|
100 |
+
dreambooth_model_path = "",
|
101 |
+
lora_model_path = "",
|
102 |
+
lora_alpha = 0.8,
|
103 |
+
):
|
104 |
+
# 1.1 motion module
|
105 |
+
unet_state_dict = {}
|
106 |
+
if motion_module_path != "":
|
107 |
+
print(f"load motion module from {motion_module_path}")
|
108 |
+
motion_module_state_dict = torch.load(motion_module_path, map_location="cpu")
|
109 |
+
motion_module_state_dict = motion_module_state_dict["state_dict"] if "state_dict" in motion_module_state_dict else motion_module_state_dict
|
110 |
+
unet_state_dict.update({name: param for name, param in motion_module_state_dict.items() if "motion_modules." in name})
|
111 |
+
|
112 |
+
missing, unexpected = animation_pipeline.unet.load_state_dict(unet_state_dict, strict=False)
|
113 |
+
assert len(unexpected) == 0
|
114 |
+
del unet_state_dict
|
115 |
+
|
116 |
+
if dreambooth_model_path != "":
|
117 |
+
print(f"load dreambooth model from {dreambooth_model_path}")
|
118 |
+
if dreambooth_model_path.endswith(".safetensors"):
|
119 |
+
dreambooth_state_dict = {}
|
120 |
+
with safe_open(dreambooth_model_path, framework="pt", device="cpu") as f:
|
121 |
+
for key in f.keys():
|
122 |
+
dreambooth_state_dict[key] = f.get_tensor(key)
|
123 |
+
elif dreambooth_model_path.endswith(".ckpt"):
|
124 |
+
dreambooth_state_dict = torch.load(dreambooth_model_path, map_location="cpu")
|
125 |
+
|
126 |
+
# 1. vae
|
127 |
+
converted_vae_checkpoint = convert_ldm_vae_checkpoint(dreambooth_state_dict, animation_pipeline.vae.config)
|
128 |
+
animation_pipeline.vae.load_state_dict(converted_vae_checkpoint)
|
129 |
+
# 2. unet
|
130 |
+
converted_unet_checkpoint = convert_ldm_unet_checkpoint(dreambooth_state_dict, animation_pipeline.unet.config)
|
131 |
+
animation_pipeline.unet.load_state_dict(converted_unet_checkpoint, strict=False)
|
132 |
+
# 3. text_model
|
133 |
+
animation_pipeline.text_encoder = convert_ldm_clip_checkpoint(dreambooth_state_dict)
|
134 |
+
del dreambooth_state_dict
|
135 |
+
|
136 |
+
if lora_model_path != "":
|
137 |
+
print(f"load lora model from {lora_model_path}")
|
138 |
+
assert lora_model_path.endswith(".safetensors")
|
139 |
+
lora_state_dict = {}
|
140 |
+
with safe_open(lora_model_path, framework="pt", device="cpu") as f:
|
141 |
+
for key in f.keys():
|
142 |
+
lora_state_dict[key] = f.get_tensor(key)
|
143 |
+
|
144 |
+
animation_pipeline = convert_lora(animation_pipeline, lora_state_dict, alpha=lora_alpha)
|
145 |
+
del lora_state_dict
|
146 |
+
|
147 |
+
|
148 |
+
for motion_module_lora_config in motion_module_lora_configs:
|
149 |
+
path, alpha = motion_module_lora_config["path"], motion_module_lora_config["alpha"]
|
150 |
+
print(f"load motion LoRA from {path}")
|
151 |
+
|
152 |
+
motion_lora_state_dict = torch.load(path, map_location="cpu")
|
153 |
+
motion_lora_state_dict = motion_lora_state_dict["state_dict"] if "state_dict" in motion_lora_state_dict else motion_lora_state_dict
|
154 |
+
|
155 |
+
animation_pipeline = convert_motion_lora_ckpt_to_diffusers(animation_pipeline, motion_lora_state_dict, alpha)
|
156 |
+
|
157 |
+
return animation_pipeline
|
app.py
ADDED
@@ -0,0 +1,488 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
import random
|
4 |
+
|
5 |
+
import gradio as gr
|
6 |
+
from glob import glob
|
7 |
+
from omegaconf import OmegaConf
|
8 |
+
from safetensors import safe_open
|
9 |
+
|
10 |
+
from diffusers import AutoencoderKL
|
11 |
+
from diffusers import EulerDiscreteScheduler, DDIMScheduler
|
12 |
+
from diffusers.utils.import_utils import is_xformers_available
|
13 |
+
from transformers import CLIPTextModel, CLIPTokenizer
|
14 |
+
|
15 |
+
from animatediff.models.unet import UNet3DConditionModel
|
16 |
+
from animatediff.pipelines.pipeline_animation import AnimationFreeInitPipeline
|
17 |
+
from animatediff.utils.util import save_videos_grid
|
18 |
+
from animatediff.utils.convert_from_ckpt import convert_ldm_unet_checkpoint, convert_ldm_clip_checkpoint, convert_ldm_vae_checkpoint
|
19 |
+
from diffusers.training_utils import set_seed
|
20 |
+
|
21 |
+
from animatediff.utils.freeinit_utils import get_freq_filter
|
22 |
+
from collections import namedtuple
|
23 |
+
|
24 |
+
pretrained_model_path = "models/StableDiffusion/stable-diffusion-v1-5"
|
25 |
+
inference_config_path = "configs/inference/inference-v1.yaml"
|
26 |
+
|
27 |
+
css = """
|
28 |
+
.toolbutton {
|
29 |
+
margin-buttom: 0em 0em 0em 0em;
|
30 |
+
max-width: 2.5em;
|
31 |
+
min-width: 2.5em !important;
|
32 |
+
height: 2.5em;
|
33 |
+
}
|
34 |
+
"""
|
35 |
+
|
36 |
+
examples = [
|
37 |
+
# 1-ToonYou
|
38 |
+
[
|
39 |
+
"toonyou_beta3.safetensors",
|
40 |
+
"mm_sd_v14.ckpt",
|
41 |
+
"(best quality, masterpiece), close up, 1girl, red clothes, sitting, elf, pond, in water, deep forest, waterfall, looking away, blurry background",
|
42 |
+
"worst quality, low quality, nsfw, logo",
|
43 |
+
512, 512, "1566149281915957",
|
44 |
+
"butterworth", 0.25, 0.25, 3,
|
45 |
+
["use_fp16"]
|
46 |
+
],
|
47 |
+
# 2-Lyriel
|
48 |
+
[
|
49 |
+
"lyriel_v16.safetensors",
|
50 |
+
"mm_sd_v14.ckpt",
|
51 |
+
"hypercars cyberpunk moving, muted colors, swirling color smokes, legend, cityscape, space",
|
52 |
+
"3d, cartoon, anime, sketches, worst quality, low quality, nsfw, logo",
|
53 |
+
512, 512, "4954488479039740",
|
54 |
+
"butterworth", 0.25, 0.25, 3,
|
55 |
+
["use_fp16"]
|
56 |
+
],
|
57 |
+
# 3-RCNZ
|
58 |
+
[
|
59 |
+
"rcnzCartoon3d_v10.safetensors",
|
60 |
+
"mm_sd_v14.ckpt",
|
61 |
+
"A cute raccoon playing guitar in a boat on the ocean",
|
62 |
+
"worst quality, low quality, nsfw, logo",
|
63 |
+
512, 512, "2005563494988190",
|
64 |
+
"butterworth", 0.25, 0.25, 3,
|
65 |
+
["use_fp16"]
|
66 |
+
],
|
67 |
+
# 4-MajicMix
|
68 |
+
[
|
69 |
+
"majicmixRealistic_v5Preview.safetensors",
|
70 |
+
"mm_sd_v14.ckpt",
|
71 |
+
"1girl, reading book",
|
72 |
+
"bad hand, worst quality, low quality, normal quality, lowres, bad anatomy, bad hands, watermark, moles",
|
73 |
+
512, 512, "2005563494988190",
|
74 |
+
"butterworth", 0.25, 0.25, 3,
|
75 |
+
["use_fp16"]
|
76 |
+
],
|
77 |
+
# # 5-RealisticVision
|
78 |
+
# [
|
79 |
+
# "realisticVisionV51_v20Novae.safetensors",
|
80 |
+
# "mm_sd_v14.ckpt",
|
81 |
+
# "A panda standing on a surfboard in the ocean in sunset.",
|
82 |
+
# "worst quality, low quality, nsfw, logo",
|
83 |
+
# 512, 512, "2005563494988190",
|
84 |
+
# "butterworth", 0.25, 0.25, 3,
|
85 |
+
# ["use_fp16"]
|
86 |
+
# ]
|
87 |
+
# 5-RealisticVision
|
88 |
+
[
|
89 |
+
"realisticVisionV51_v20Novae.safetensors",
|
90 |
+
"mm_sd_v14.ckpt",
|
91 |
+
"b&w photo of 42 y.o man in black clothes, bald, face, half body, body, high detailed skin, skin pores, coastline, overcast weather, wind, waves, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3",
|
92 |
+
"(semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime:1.4), text, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck",
|
93 |
+
512, 512, "1566149281915957",
|
94 |
+
"butterworth", 0.25, 0.25, 3,
|
95 |
+
["use_fp16"]
|
96 |
+
]
|
97 |
+
]
|
98 |
+
|
99 |
+
# clean unrelated ckpts
|
100 |
+
# ckpts = [
|
101 |
+
# "realisticVisionV40_v20Novae.safetensors",
|
102 |
+
# "majicmixRealistic_v5Preview.safetensors",
|
103 |
+
# "rcnzCartoon3d_v10.safetensors",
|
104 |
+
# "lyriel_v16.safetensors",
|
105 |
+
# "toonyou_beta3.safetensors"
|
106 |
+
# ]
|
107 |
+
|
108 |
+
# for path in glob(os.path.join("models", "DreamBooth_LoRA", "*.safetensors")):
|
109 |
+
# for ckpt in ckpts:
|
110 |
+
# if path.endswith(ckpt): break
|
111 |
+
# else:
|
112 |
+
# print(f"### Cleaning {path} ...")
|
113 |
+
# os.system(f"rm -rf {path}")
|
114 |
+
|
115 |
+
# os.system(f"rm -rf {os.path.join('models', 'DreamBooth_LoRA', '*.safetensors')}")
|
116 |
+
|
117 |
+
# os.system(f"bash download_bashscripts/1-ToonYou.sh")
|
118 |
+
# os.system(f"bash download_bashscripts/2-Lyriel.sh")
|
119 |
+
# os.system(f"bash download_bashscripts/3-RcnzCartoon.sh")
|
120 |
+
# os.system(f"bash download_bashscripts/4-MajicMix.sh")
|
121 |
+
# os.system(f"bash download_bashscripts/5-RealisticVision.sh")
|
122 |
+
|
123 |
+
# clean Gradio cache
|
124 |
+
print(f"### Cleaning cached examples ...")
|
125 |
+
os.system(f"rm -rf gradio_cached_examples/")
|
126 |
+
|
127 |
+
|
128 |
+
class AnimateController:
|
129 |
+
def __init__(self):
|
130 |
+
|
131 |
+
# config dirs
|
132 |
+
self.basedir = os.getcwd()
|
133 |
+
self.stable_diffusion_dir = os.path.join(self.basedir, "models", "StableDiffusion")
|
134 |
+
self.motion_module_dir = os.path.join(self.basedir, "models", "Motion_Module")
|
135 |
+
self.personalized_model_dir = os.path.join(self.basedir, "models", "DreamBooth_LoRA")
|
136 |
+
self.savedir = os.path.join(self.basedir, "samples")
|
137 |
+
os.makedirs(self.savedir, exist_ok=True)
|
138 |
+
|
139 |
+
self.base_model_list = []
|
140 |
+
self.motion_module_list = []
|
141 |
+
self.filter_type_list = [
|
142 |
+
"butterworth",
|
143 |
+
"gaussian",
|
144 |
+
"box",
|
145 |
+
"ideal"
|
146 |
+
]
|
147 |
+
|
148 |
+
self.selected_base_model = None
|
149 |
+
self.selected_motion_module = None
|
150 |
+
self.selected_filter_type = None
|
151 |
+
self.set_width = None
|
152 |
+
self.set_height = None
|
153 |
+
self.set_d_s = None
|
154 |
+
self.set_d_t = None
|
155 |
+
|
156 |
+
self.refresh_motion_module()
|
157 |
+
self.refresh_personalized_model()
|
158 |
+
|
159 |
+
# config models
|
160 |
+
self.inference_config = OmegaConf.load(inference_config_path)
|
161 |
+
|
162 |
+
self.tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer")
|
163 |
+
self.text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder").cuda()
|
164 |
+
self.vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae").cuda()
|
165 |
+
self.unet = UNet3DConditionModel.from_pretrained_2d(pretrained_model_path, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(self.inference_config.unet_additional_kwargs)).cuda()
|
166 |
+
|
167 |
+
self.freq_filter = None
|
168 |
+
|
169 |
+
self.update_base_model(self.base_model_list[-2])
|
170 |
+
self.update_motion_module(self.motion_module_list[0])
|
171 |
+
self.update_filter(512, 512, self.filter_type_list[0], 0.25, 0.25)
|
172 |
+
|
173 |
+
|
174 |
+
def refresh_motion_module(self):
|
175 |
+
motion_module_list = glob(os.path.join(self.motion_module_dir, "*.ckpt"))
|
176 |
+
self.motion_module_list = sorted([os.path.basename(p) for p in motion_module_list])
|
177 |
+
|
178 |
+
def refresh_personalized_model(self):
|
179 |
+
base_model_list = glob(os.path.join(self.personalized_model_dir, "*.safetensors"))
|
180 |
+
self.base_model_list = sorted([os.path.basename(p) for p in base_model_list])
|
181 |
+
|
182 |
+
|
183 |
+
def update_base_model(self, base_model_dropdown):
|
184 |
+
self.selected_base_model = base_model_dropdown
|
185 |
+
|
186 |
+
base_model_dropdown = os.path.join(self.personalized_model_dir, base_model_dropdown)
|
187 |
+
base_model_state_dict = {}
|
188 |
+
with safe_open(base_model_dropdown, framework="pt", device="cpu") as f:
|
189 |
+
for key in f.keys(): base_model_state_dict[key] = f.get_tensor(key)
|
190 |
+
|
191 |
+
converted_vae_checkpoint = convert_ldm_vae_checkpoint(base_model_state_dict, self.vae.config)
|
192 |
+
self.vae.load_state_dict(converted_vae_checkpoint)
|
193 |
+
|
194 |
+
converted_unet_checkpoint = convert_ldm_unet_checkpoint(base_model_state_dict, self.unet.config)
|
195 |
+
self.unet.load_state_dict(converted_unet_checkpoint, strict=False)
|
196 |
+
|
197 |
+
self.text_encoder = convert_ldm_clip_checkpoint(base_model_state_dict)
|
198 |
+
return gr.Dropdown.update()
|
199 |
+
|
200 |
+
def update_motion_module(self, motion_module_dropdown):
|
201 |
+
self.selected_motion_module = motion_module_dropdown
|
202 |
+
|
203 |
+
motion_module_dropdown = os.path.join(self.motion_module_dir, motion_module_dropdown)
|
204 |
+
motion_module_state_dict = torch.load(motion_module_dropdown, map_location="cpu")
|
205 |
+
_, unexpected = self.unet.load_state_dict(motion_module_state_dict, strict=False)
|
206 |
+
assert len(unexpected) == 0
|
207 |
+
return gr.Dropdown.update()
|
208 |
+
|
209 |
+
# def update_filter(self, shape, method, n, d_s, d_t):
|
210 |
+
def update_filter(self, width_slider, height_slider, filter_type_dropdown, d_s_slider, d_t_slider):
|
211 |
+
self.set_width = width_slider
|
212 |
+
self.set_height = height_slider
|
213 |
+
self.selected_filter_type = filter_type_dropdown
|
214 |
+
self.set_d_s = d_s_slider
|
215 |
+
self.set_d_t = d_t_slider
|
216 |
+
|
217 |
+
vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
218 |
+
|
219 |
+
shape = [1, 4, 16, self.set_width//vae_scale_factor, self.set_height//vae_scale_factor]
|
220 |
+
self.freq_filter = get_freq_filter(
|
221 |
+
shape,
|
222 |
+
device="cuda",
|
223 |
+
filter_type=self.selected_filter_type,
|
224 |
+
n=4,
|
225 |
+
d_s=self.set_d_s,
|
226 |
+
d_t=self.set_d_t
|
227 |
+
)
|
228 |
+
|
229 |
+
def animate(
|
230 |
+
self,
|
231 |
+
base_model_dropdown,
|
232 |
+
motion_module_dropdown,
|
233 |
+
prompt_textbox,
|
234 |
+
negative_prompt_textbox,
|
235 |
+
width_slider,
|
236 |
+
height_slider,
|
237 |
+
seed_textbox,
|
238 |
+
# freeinit params
|
239 |
+
filter_type_dropdown,
|
240 |
+
d_s_slider,
|
241 |
+
d_t_slider,
|
242 |
+
num_iters_slider,
|
243 |
+
# speed up
|
244 |
+
speed_up_options
|
245 |
+
):
|
246 |
+
# set global seed
|
247 |
+
set_seed(42)
|
248 |
+
|
249 |
+
d_s = float(d_s_slider)
|
250 |
+
d_t = float(d_t_slider)
|
251 |
+
num_iters = int(num_iters_slider)
|
252 |
+
|
253 |
+
|
254 |
+
if self.selected_base_model != base_model_dropdown: self.update_base_model(base_model_dropdown)
|
255 |
+
if self.selected_motion_module != motion_module_dropdown: self.update_motion_module(motion_module_dropdown)
|
256 |
+
|
257 |
+
self.set_width = width_slider
|
258 |
+
self.set_height = height_slider
|
259 |
+
self.selected_filter_type = filter_type_dropdown
|
260 |
+
self.set_d_s = d_s
|
261 |
+
self.set_d_t = d_t
|
262 |
+
if self.set_width != width_slider or self.set_height != height_slider or self.selected_filter_type != filter_type_dropdown or self.set_d_s != d_s or self.set_d_t != d_t:
|
263 |
+
self.update_filter(width_slider, height_slider, filter_type_dropdown, d_s, d_t)
|
264 |
+
|
265 |
+
if is_xformers_available(): self.unet.enable_xformers_memory_efficient_attention()
|
266 |
+
|
267 |
+
pipeline = AnimationFreeInitPipeline(
|
268 |
+
vae=self.vae, text_encoder=self.text_encoder, tokenizer=self.tokenizer, unet=self.unet,
|
269 |
+
scheduler=DDIMScheduler(**OmegaConf.to_container(self.inference_config.noise_scheduler_kwargs))
|
270 |
+
).to("cuda")
|
271 |
+
|
272 |
+
# (freeinit) initialize frequency filter for noise reinitialization -------------
|
273 |
+
pipeline.freq_filter = self.freq_filter
|
274 |
+
# -------------------------------------------------------------------------------
|
275 |
+
|
276 |
+
|
277 |
+
if int(seed_textbox) > 0: seed = int(seed_textbox)
|
278 |
+
else: seed = random.randint(1, 1e16)
|
279 |
+
torch.manual_seed(int(seed))
|
280 |
+
|
281 |
+
assert seed == torch.initial_seed()
|
282 |
+
print(f"### seed: {seed}")
|
283 |
+
|
284 |
+
generator = torch.Generator(device="cuda")
|
285 |
+
generator.manual_seed(seed)
|
286 |
+
|
287 |
+
sample_output = pipeline(
|
288 |
+
prompt_textbox,
|
289 |
+
negative_prompt = negative_prompt_textbox,
|
290 |
+
num_inference_steps = 25,
|
291 |
+
guidance_scale = 7.5,
|
292 |
+
width = width_slider,
|
293 |
+
height = height_slider,
|
294 |
+
video_length = 16,
|
295 |
+
num_iters = num_iters,
|
296 |
+
use_fast_sampling = True if "use_coarse_to_fine_sampling" in speed_up_options else False,
|
297 |
+
save_intermediate = False,
|
298 |
+
return_orig = True,
|
299 |
+
use_fp16 = True if "use_fp16" in speed_up_options else False
|
300 |
+
)
|
301 |
+
orig_sample = sample_output.orig_videos
|
302 |
+
sample = sample_output.videos
|
303 |
+
|
304 |
+
save_sample_path = os.path.join(self.savedir, f"sample.mp4")
|
305 |
+
save_videos_grid(sample, save_sample_path)
|
306 |
+
|
307 |
+
save_orig_sample_path = os.path.join(self.savedir, f"sample_orig.mp4")
|
308 |
+
save_videos_grid(orig_sample, save_orig_sample_path)
|
309 |
+
|
310 |
+
# save_compare_path = os.path.join(self.savedir, f"compare.mp4")
|
311 |
+
# save_videos_grid(torch.concat([orig_sample, sample]), save_compare_path)
|
312 |
+
|
313 |
+
json_config = {
|
314 |
+
"prompt": prompt_textbox,
|
315 |
+
"n_prompt": negative_prompt_textbox,
|
316 |
+
"width": width_slider,
|
317 |
+
"height": height_slider,
|
318 |
+
"seed": seed,
|
319 |
+
"base_model": base_model_dropdown,
|
320 |
+
"motion_module": motion_module_dropdown,
|
321 |
+
"filter_type": filter_type_dropdown,
|
322 |
+
"d_s": d_s,
|
323 |
+
"d_t": d_t,
|
324 |
+
"num_iters": num_iters,
|
325 |
+
"use_fp16": True if "use_fp16" in speed_up_options else False,
|
326 |
+
"use_coarse_to_fine_sampling": True if "use_coarse_to_fine_sampling" in speed_up_options else False
|
327 |
+
}
|
328 |
+
|
329 |
+
# return gr.Video.update(value=save_compare_path), gr.Json.update(value=json_config)
|
330 |
+
# return gr.Video.update(value=save_orig_sample_path), gr.Video.update(value=save_sample_path), gr.Video.update(value=save_compare_path), gr.Json.update(value=json_config)
|
331 |
+
return gr.Video.update(value=save_orig_sample_path), gr.Video.update(value=save_sample_path), gr.Json.update(value=json_config)
|
332 |
+
|
333 |
+
|
334 |
+
controller = AnimateController()
|
335 |
+
|
336 |
+
|
337 |
+
def ui():
|
338 |
+
with gr.Blocks(css=css) as demo:
|
339 |
+
# gr.Markdown('# FreeInit')
|
340 |
+
gr.Markdown(
|
341 |
+
"""
|
342 |
+
<div align="center">
|
343 |
+
<h1>FreeInit</h1>
|
344 |
+
</div>
|
345 |
+
"""
|
346 |
+
)
|
347 |
+
gr.Markdown(
|
348 |
+
"""
|
349 |
+
<p align="center">
|
350 |
+
<a title="Project Page" href="https://tianxingwu.github.io/pages/FreeInit/" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
|
351 |
+
<img src="https://img.shields.io/badge/Project-Website-5B7493?logo=googlechrome&logoColor=5B7493">
|
352 |
+
</a>
|
353 |
+
<a title="arXiv" href="https://arxiv.org/abs/2312.07537" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
|
354 |
+
<img src="https://img.shields.io/badge/arXiv-Paper-b31b1b?logo=arxiv&logoColor=b31b1b">
|
355 |
+
</a>
|
356 |
+
<a title="GitHub" href="https://github.com/TianxingWu/FreeInit" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
|
357 |
+
<img src="https://img.shields.io/github/stars/TianxingWu/FreeInit?label=GitHub%20%E2%98%85&&logo=github" alt="badge-github-stars">
|
358 |
+
</a>
|
359 |
+
<a title="Video" href="https://youtu.be/lS5IYbAqriI" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
|
360 |
+
<img src="https://img.shields.io/badge/YouTube-Video-red?logo=youtube&logoColor=red">
|
361 |
+
</a>
|
362 |
+
</p>
|
363 |
+
"""
|
364 |
+
# <a title="Visitor" href="https://hits.seeyoufarm.com" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
|
365 |
+
# <img src="https://hits.seeyoufarm.com/api/count/incr/badge.svg?url=https%3A%2F%2Fhuggingface.co%2Fspaces%2FTianxingWu%2FFreeInit&count_bg=%23678F74&title_bg=%23555555&icon=&icon_color=%23E7E7E7&title=hits&edge_flat=false">
|
366 |
+
# </a>
|
367 |
+
)
|
368 |
+
gr.Markdown(
|
369 |
+
"""
|
370 |
+
Official Gradio Demo for ***FreeInit: Bridging Initialization Gap in Video Diffusion Models***.<br>
|
371 |
+
FreeInit improves time consistency of diffusion-based video generation at inference time.
|
372 |
+
In this demo, we apply FreeInit on [AnimateDiff v1](https://github.com/guoyww/AnimateDiff) as an example.<br>
|
373 |
+
"""
|
374 |
+
)
|
375 |
+
|
376 |
+
with gr.Row():
|
377 |
+
with gr.Column():
|
378 |
+
# gr.Markdown(
|
379 |
+
# """
|
380 |
+
# ### Usage
|
381 |
+
# 1. Select customized model and motion module in `Model Settings`.
|
382 |
+
# 3. Set `FreeInit Settings`.
|
383 |
+
# 3. Provide `Prompt` and `Negative Prompt` for your selected model. You can refer to each model's webpage on CivitAI to learn how to write prompts for them:
|
384 |
+
# - [`toonyou_beta3.safetensors`](https://civitai.com/models/30240?modelVersionId=78775)
|
385 |
+
# - [`lyriel_v16.safetensors`](https://civitai.com/models/22922/lyriel)
|
386 |
+
# - [`rcnzCartoon3d_v10.safetensors`](https://civitai.com/models/66347?modelVersionId=71009)
|
387 |
+
# - [`majicmixRealistic_v5Preview.safetensors`](https://civitai.com/models/43331?modelVersionId=79068)
|
388 |
+
# - [`realisticVisionV20_v20.safetensors`](https://civitai.com/models/4201?modelVersionId=29460)
|
389 |
+
# 4. Click `Generate`.
|
390 |
+
# """
|
391 |
+
# )
|
392 |
+
prompt_textbox = gr.Textbox( label="Prompt", lines=3, placeholder="Enter your prompt here")
|
393 |
+
negative_prompt_textbox = gr.Textbox( label="Negative Prompt", lines=3, value="worst quality, low quality, nsfw, logo")
|
394 |
+
|
395 |
+
gr.Markdown(
|
396 |
+
"""
|
397 |
+
*Prompt Tips:*
|
398 |
+
|
399 |
+
For each personalized model in `Model Settings`, you can refer to their webpage on CivitAI to learn how to write good prompts for them:
|
400 |
+
- [`realisticVisionV20_v20.safetensors`](https://civitai.com/models/4201?modelVersionId=29460)
|
401 |
+
- [`toonyou_beta3.safetensors`](https://civitai.com/models/30240?modelVersionId=78775)
|
402 |
+
- [`lyriel_v16.safetensors`](https://civitai.com/models/22922/lyriel)
|
403 |
+
- [`rcnzCartoon3d_v10.safetensors`](https://civitai.com/models/66347?modelVersionId=71009)
|
404 |
+
- [`majicmixRealistic_v5Preview.safetensors`](https://civitai.com/models/43331?modelVersionId=79068)
|
405 |
+
"""
|
406 |
+
)
|
407 |
+
|
408 |
+
with gr.Accordion("Model Settings", open=False):
|
409 |
+
gr.Markdown(
|
410 |
+
"""
|
411 |
+
Select personalized model and motion module for AnimateDiff.
|
412 |
+
"""
|
413 |
+
)
|
414 |
+
base_model_dropdown = gr.Dropdown( label="Base DreamBooth Model", choices=controller.base_model_list, value=controller.base_model_list[-2], interactive=True,
|
415 |
+
info="Select personalized text-to-image model from community")
|
416 |
+
motion_module_dropdown = gr.Dropdown( label="Motion Module", choices=controller.motion_module_list, value=controller.motion_module_list[0], interactive=True,
|
417 |
+
info="Select motion module. Recommend mm_sd_v14.ckpt for larger movements.")
|
418 |
+
|
419 |
+
base_model_dropdown.change(fn=controller.update_base_model, inputs=[base_model_dropdown], outputs=[base_model_dropdown])
|
420 |
+
motion_module_dropdown.change(fn=controller.update_motion_module, inputs=[motion_module_dropdown], outputs=[motion_module_dropdown])
|
421 |
+
|
422 |
+
with gr.Accordion("FreeInit Params", open=False):
|
423 |
+
gr.Markdown(
|
424 |
+
"""
|
425 |
+
Adjust to control the smoothness.
|
426 |
+
"""
|
427 |
+
)
|
428 |
+
filter_type_dropdown = gr.Dropdown( label="Filter Type", choices=controller.filter_type_list, value=controller.filter_type_list[0], interactive=True,
|
429 |
+
info="Default as Butterworth. To fix large inconsistencies, consider using Gaussian.")
|
430 |
+
d_s_slider = gr.Slider( label="d_s", value=0.25, minimum=0, maximum=1, step=0.125,
|
431 |
+
info="Stop frequency for spatial dimensions (0.0-1.0)")
|
432 |
+
d_t_slider = gr.Slider( label="d_t", value=0.25, minimum=0, maximum=1, step=0.125,
|
433 |
+
info="Stop frequency for temporal dimension (0.0-1.0)")
|
434 |
+
# num_iters_textbox = gr.Textbox( label="FreeInit Iterations", value=3, info="Sould be integer >1, larger value leads to smoother results)")
|
435 |
+
num_iters_slider = gr.Slider( label="FreeInit Iterations", value=3, minimum=2, maximum=5, step=1,
|
436 |
+
info="Larger value leads to smoother results & longer inference time.")
|
437 |
+
|
438 |
+
with gr.Accordion("Advance", open=False):
|
439 |
+
with gr.Row():
|
440 |
+
width_slider = gr.Slider( label="Width", value=512, minimum=256, maximum=1024, step=64 )
|
441 |
+
height_slider = gr.Slider( label="Height", value=512, minimum=256, maximum=1024, step=64 )
|
442 |
+
with gr.Row():
|
443 |
+
seed_textbox = gr.Textbox( label="Seed", value=1566149281915957)
|
444 |
+
seed_button = gr.Button(value="\U0001F3B2", elem_classes="toolbutton")
|
445 |
+
seed_button.click(fn=lambda: gr.Textbox.update(value=random.randint(1, 1e16)), inputs=[], outputs=[seed_textbox])
|
446 |
+
with gr.Row():
|
447 |
+
speed_up_options = gr.CheckboxGroup(
|
448 |
+
["use_fp16", "use_coarse_to_fine_sampling"],
|
449 |
+
label="Speed-Up Options",
|
450 |
+
value=["use_fp16"]
|
451 |
+
)
|
452 |
+
|
453 |
+
|
454 |
+
generate_button = gr.Button( value="Generate", variant='primary' )
|
455 |
+
|
456 |
+
|
457 |
+
# with gr.Column():
|
458 |
+
# result_video = gr.Video( label="Generated Animation", interactive=False )
|
459 |
+
# json_config = gr.Json( label="Config", value=None )
|
460 |
+
with gr.Column():
|
461 |
+
with gr.Row():
|
462 |
+
orig_video = gr.Video( label="AnimateDiff", interactive=False )
|
463 |
+
freeinit_video = gr.Video( label="AnimateDiff + FreeInit", interactive=False )
|
464 |
+
# with gr.Row():
|
465 |
+
# compare_video = gr.Video( label="Compare", interactive=False )
|
466 |
+
with gr.Row():
|
467 |
+
json_config = gr.Json( label="Config", value=None )
|
468 |
+
|
469 |
+
inputs = [base_model_dropdown, motion_module_dropdown,
|
470 |
+
prompt_textbox, negative_prompt_textbox, width_slider, height_slider, seed_textbox,
|
471 |
+
filter_type_dropdown, d_s_slider, d_t_slider, num_iters_slider,
|
472 |
+
speed_up_options
|
473 |
+
]
|
474 |
+
# outputs = [result_video, json_config]
|
475 |
+
# outputs = [orig_video, freeinit_video, compare_video, json_config]
|
476 |
+
outputs = [orig_video, freeinit_video, json_config]
|
477 |
+
|
478 |
+
generate_button.click( fn=controller.animate, inputs=inputs, outputs=outputs )
|
479 |
+
|
480 |
+
gr.Examples( fn=controller.animate, examples=examples, inputs=inputs, outputs=outputs, cache_examples=True)
|
481 |
+
|
482 |
+
return demo
|
483 |
+
|
484 |
+
|
485 |
+
if __name__ == "__main__":
|
486 |
+
demo = ui()
|
487 |
+
demo.queue(max_size=20)
|
488 |
+
demo.launch(share=True)
|
configs/inference/inference-v1.yaml
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
unet_additional_kwargs:
|
2 |
+
unet_use_cross_frame_attention: false
|
3 |
+
unet_use_temporal_attention: false
|
4 |
+
use_motion_module: true
|
5 |
+
motion_module_resolutions:
|
6 |
+
- 1
|
7 |
+
- 2
|
8 |
+
- 4
|
9 |
+
- 8
|
10 |
+
motion_module_mid_block: false
|
11 |
+
motion_module_decoder_only: false
|
12 |
+
motion_module_type: Vanilla
|
13 |
+
motion_module_kwargs:
|
14 |
+
num_attention_heads: 8
|
15 |
+
num_transformer_block: 1
|
16 |
+
attention_block_types:
|
17 |
+
- Temporal_Self
|
18 |
+
- Temporal_Self
|
19 |
+
temporal_position_encoding: true
|
20 |
+
temporal_position_encoding_max_len: 24
|
21 |
+
temporal_attention_dim_div: 1
|
22 |
+
|
23 |
+
noise_scheduler_kwargs:
|
24 |
+
beta_start: 0.00085
|
25 |
+
beta_end: 0.012
|
26 |
+
beta_schedule: "linear"
|
configs/inference/inference-v2.yaml
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
unet_additional_kwargs:
|
2 |
+
use_inflated_groupnorm: true
|
3 |
+
unet_use_cross_frame_attention: false
|
4 |
+
unet_use_temporal_attention: false
|
5 |
+
use_motion_module: true
|
6 |
+
motion_module_resolutions:
|
7 |
+
- 1
|
8 |
+
- 2
|
9 |
+
- 4
|
10 |
+
- 8
|
11 |
+
motion_module_mid_block: true
|
12 |
+
motion_module_decoder_only: false
|
13 |
+
motion_module_type: Vanilla
|
14 |
+
motion_module_kwargs:
|
15 |
+
num_attention_heads: 8
|
16 |
+
num_transformer_block: 1
|
17 |
+
attention_block_types:
|
18 |
+
- Temporal_Self
|
19 |
+
- Temporal_Self
|
20 |
+
temporal_position_encoding: true
|
21 |
+
temporal_position_encoding_max_len: 32
|
22 |
+
temporal_attention_dim_div: 1
|
23 |
+
|
24 |
+
noise_scheduler_kwargs:
|
25 |
+
beta_start: 0.00085
|
26 |
+
beta_end: 0.012
|
27 |
+
beta_schedule: "linear"
|
configs/prompts/1-ToonYou.yaml
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
ToonYou:
|
2 |
+
motion_module:
|
3 |
+
- "models/Motion_Module/mm_sd_v14.ckpt"
|
4 |
+
- "models/Motion_Module/mm_sd_v15.ckpt"
|
5 |
+
|
6 |
+
dreambooth_path: "models/DreamBooth_LoRA/toonyou_beta3.safetensors"
|
7 |
+
lora_model_path: ""
|
8 |
+
|
9 |
+
seed: [10788741199826055526, 6520604954829636163, 6519455744612555650, 16372571278361863751]
|
10 |
+
steps: 25
|
11 |
+
guidance_scale: 7.5
|
12 |
+
|
13 |
+
prompt:
|
14 |
+
- "best quality, masterpiece, 1girl, looking at viewer, blurry background, upper body, contemporary, dress"
|
15 |
+
- "masterpiece, best quality, 1girl, solo, cherry blossoms, hanami, pink flower, white flower, spring season, wisteria, petals, flower, plum blossoms, outdoors, falling petals, white hair, black eyes,"
|
16 |
+
- "best quality, masterpiece, 1boy, formal, abstract, looking at viewer, masculine, marble pattern"
|
17 |
+
- "best quality, masterpiece, 1girl, cloudy sky, dandelion, contrapposto, alternate hairstyle,"
|
18 |
+
|
19 |
+
n_prompt:
|
20 |
+
- ""
|
21 |
+
- "badhandv4,easynegative,ng_deepnegative_v1_75t,verybadimagenegative_v1.3, bad-artist, bad_prompt_version2-neg, teeth"
|
22 |
+
- ""
|
23 |
+
- ""
|
configs/prompts/2-Lyriel.yaml
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Lyriel:
|
2 |
+
motion_module:
|
3 |
+
- "models/Motion_Module/mm_sd_v14.ckpt"
|
4 |
+
- "models/Motion_Module/mm_sd_v15.ckpt"
|
5 |
+
|
6 |
+
dreambooth_path: "models/DreamBooth_LoRA/lyriel_v16.safetensors"
|
7 |
+
lora_model_path: ""
|
8 |
+
|
9 |
+
seed: [10917152860782582783, 6399018107401806238, 15875751942533906793, 6653196880059936551]
|
10 |
+
steps: 25
|
11 |
+
guidance_scale: 7.5
|
12 |
+
|
13 |
+
prompt:
|
14 |
+
- "dark shot, epic realistic, portrait of halo, sunglasses, blue eyes, tartan scarf, white hair by atey ghailan, by greg rutkowski, by greg tocchini, by james gilleard, by joe fenton, by kaethe butcher, gradient yellow, black, brown and magenta color scheme, grunge aesthetic!!! graffiti tag wall background, art by greg rutkowski and artgerm, soft cinematic light, adobe lightroom, photolab, hdr, intricate, highly detailed, depth of field, faded, neutral colors, hdr, muted colors, hyperdetailed, artstation, cinematic, warm lights, dramatic light, intricate details, complex background, rutkowski, teal and orange"
|
15 |
+
- "A forbidden castle high up in the mountains, pixel art, intricate details2, hdr, intricate details, hyperdetailed5, natural skin texture, hyperrealism, soft light, sharp, game art, key visual, surreal"
|
16 |
+
- "dark theme, medieval portrait of a man sharp features, grim, cold stare, dark colors, Volumetric lighting, baroque oil painting by Greg Rutkowski, Artgerm, WLOP, Alphonse Mucha dynamic lighting hyperdetailed intricately detailed, hdr, muted colors, complex background, hyperrealism, hyperdetailed, amandine van ray"
|
17 |
+
- "As I have gone alone in there and with my treasures bold, I can keep my secret where and hint of riches new and old. Begin it where warm waters halt and take it in a canyon down, not far but too far to walk, put in below the home of brown."
|
18 |
+
|
19 |
+
n_prompt:
|
20 |
+
- "3d, cartoon, lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry, artist name, young, loli, elf, 3d, illustration"
|
21 |
+
- "3d, cartoon, anime, sketches, worst quality, low quality, normal quality, lowres, normal quality, monochrome, grayscale, skin spots, acnes, skin blemishes, bad anatomy, girl, loli, young, large breasts, red eyes, muscular"
|
22 |
+
- "dof, grayscale, black and white, bw, 3d, cartoon, anime, sketches, worst quality, low quality, normal quality, lowres, normal quality, monochrome, grayscale, skin spots, acnes, skin blemishes, bad anatomy, girl, loli, young, large breasts, red eyes, muscular,badhandsv5-neg, By bad artist -neg 1, monochrome"
|
23 |
+
- "holding an item, cowboy, hat, cartoon, 3d, disfigured, bad art, deformed,extra limbs,close up,b&w, wierd colors, blurry, duplicate, morbid, mutilated, [out of frame], extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, ugly, blurry, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, out of frame, ugly, extra limbs, bad anatomy, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, mutated hands, fused fingers, too many fingers, long neck, Photoshop, video game, ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, mutation, mutated, extra limbs, extra legs, extra arms, disfigured, deformed, cross-eye, body out of frame, blurry, bad art, bad anatomy, 3d render"
|
configs/prompts/3-RcnzCartoon.yaml
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
RcnzCartoon:
|
2 |
+
motion_module:
|
3 |
+
- "models/Motion_Module/mm_sd_v14.ckpt"
|
4 |
+
- "models/Motion_Module/mm_sd_v15.ckpt"
|
5 |
+
|
6 |
+
dreambooth_path: "models/DreamBooth_LoRA/rcnzCartoon3d_v10.safetensors"
|
7 |
+
lora_model_path: ""
|
8 |
+
|
9 |
+
seed: [16931037867122267877, 2094308009433392066, 4292543217695451092, 15572665120852309890]
|
10 |
+
steps: 25
|
11 |
+
guidance_scale: 7.5
|
12 |
+
|
13 |
+
prompt:
|
14 |
+
- "Jane Eyre with headphones, natural skin texture,4mm,k textures, soft cinematic light, adobe lightroom, photolab, hdr, intricate, elegant, highly detailed, sharp focus, cinematic look, soothing tones, insane details, intricate details, hyperdetailed, low contrast, soft cinematic light, dim colors, exposure blend, hdr, faded"
|
15 |
+
- "close up Portrait photo of muscular bearded guy in a worn mech suit, light bokeh, intricate, steel metal [rust], elegant, sharp focus, photo by greg rutkowski, soft lighting, vibrant colors, masterpiece, streets, detailed face"
|
16 |
+
- "absurdres, photorealistic, masterpiece, a 30 year old man with gold framed, aviator reading glasses and a black hooded jacket and a beard, professional photo, a character portrait, altermodern, detailed eyes, detailed lips, detailed face, grey eyes"
|
17 |
+
- "a golden labrador, warm vibrant colours, natural lighting, dappled lighting, diffused lighting, absurdres, highres,k, uhd, hdr, rtx, unreal, octane render, RAW photo, photorealistic, global illumination, subsurface scattering"
|
18 |
+
|
19 |
+
n_prompt:
|
20 |
+
- "deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, mutated hands and fingers, disconnected limbs, mutation, mutated, ugly, disgusting, blurry, amputation"
|
21 |
+
- "nude, cross eyed, tongue, open mouth, inside, 3d, cartoon, anime, sketches, worst quality, low quality, normal quality, lowres, normal quality, monochrome, grayscale, skin spots, acnes, skin blemishes, bad anatomy, red eyes, muscular"
|
22 |
+
- "easynegative, cartoon, anime, sketches, necklace, earrings worst quality, low quality, normal quality, bad anatomy, bad hands, shiny skin, error, missing fingers, extra digit, fewer digits, jpeg artifacts, signature, watermark, username, blurry, chubby, anorectic, bad eyes, old, wrinkled skin, red skin, photograph By bad artist -neg, big eyes, muscular face,"
|
23 |
+
- "beard, EasyNegative, lowres, chromatic aberration, depth of field, motion blur, blurry, bokeh, bad quality, worst quality, multiple arms, badhand"
|
configs/prompts/4-MajicMix.yaml
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MajicMix:
|
2 |
+
motion_module:
|
3 |
+
- "models/Motion_Module/mm_sd_v14.ckpt"
|
4 |
+
- "models/Motion_Module/mm_sd_v15.ckpt"
|
5 |
+
|
6 |
+
dreambooth_path: "models/DreamBooth_LoRA/majicmixRealistic_v5Preview.safetensors"
|
7 |
+
lora_model_path: ""
|
8 |
+
|
9 |
+
seed: [1572448948722921032, 1099474677988590681, 6488833139725635347, 18339859844376517918]
|
10 |
+
steps: 25
|
11 |
+
guidance_scale: 7.5
|
12 |
+
|
13 |
+
prompt:
|
14 |
+
- "1girl, offshoulder, light smile, shiny skin best quality, masterpiece, photorealistic"
|
15 |
+
- "best quality, masterpiece, photorealistic, 1boy, 50 years old beard, dramatic lighting"
|
16 |
+
- "best quality, masterpiece, photorealistic, 1girl, light smile, shirt with collars, waist up, dramatic lighting, from below"
|
17 |
+
- "male, man, beard, bodybuilder, skinhead,cold face, tough guy, cowboyshot, tattoo, french windows, luxury hotel masterpiece, best quality, photorealistic"
|
18 |
+
|
19 |
+
n_prompt:
|
20 |
+
- "ng_deepnegative_v1_75t, badhandv4, worst quality, low quality, normal quality, lowres, bad anatomy, bad hands, watermark, moles"
|
21 |
+
- "nsfw, ng_deepnegative_v1_75t,badhandv4, worst quality, low quality, normal quality, lowres,watermark, monochrome"
|
22 |
+
- "nsfw, ng_deepnegative_v1_75t,badhandv4, worst quality, low quality, normal quality, lowres,watermark, monochrome"
|
23 |
+
- "nude, nsfw, ng_deepnegative_v1_75t, badhandv4, worst quality, low quality, normal quality, lowres, bad anatomy, bad hands, monochrome, grayscale watermark, moles, people"
|
configs/prompts/5-RealisticVision.yaml
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
RealisticVision:
|
2 |
+
motion_module:
|
3 |
+
- "models/Motion_Module/mm_sd_v14.ckpt"
|
4 |
+
- "models/Motion_Module/mm_sd_v15.ckpt"
|
5 |
+
|
6 |
+
dreambooth_path: "models/DreamBooth_LoRA/realisticVisionV20_v20.safetensors"
|
7 |
+
lora_model_path: ""
|
8 |
+
|
9 |
+
seed: [5658137986800322009, 12099779162349365895, 10499524853910852697, 16768009035333711932]
|
10 |
+
steps: 25
|
11 |
+
guidance_scale: 7.5
|
12 |
+
|
13 |
+
prompt:
|
14 |
+
- "b&w photo of 42 y.o man in black clothes, bald, face, half body, body, high detailed skin, skin pores, coastline, overcast weather, wind, waves, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3"
|
15 |
+
- "close up photo of a rabbit, forest, haze, halation, bloom, dramatic atmosphere, centred, rule of thirds, 200mm 1.4f macro shot"
|
16 |
+
- "photo of coastline, rocks, storm weather, wind, waves, lightning, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3"
|
17 |
+
- "night, b&w photo of old house, post apocalypse, forest, storm weather, wind, rocks, 8k uhd, dslr, soft lighting, high quality, film grain"
|
18 |
+
|
19 |
+
n_prompt:
|
20 |
+
- "semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, text, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck"
|
21 |
+
- "semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, text, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck"
|
22 |
+
- "blur, haze, deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, mutated hands and fingers, deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation"
|
23 |
+
- "blur, haze, deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, art, mutated hands and fingers, deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation"
|
configs/prompts/6-Tusun.yaml
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Tusun:
|
2 |
+
motion_module:
|
3 |
+
- "models/Motion_Module/mm_sd_v14.ckpt"
|
4 |
+
- "models/Motion_Module/mm_sd_v15.ckpt"
|
5 |
+
|
6 |
+
dreambooth_path: "models/DreamBooth_LoRA/moonfilm_reality20.safetensors"
|
7 |
+
lora_model_path: "models/DreamBooth_LoRA/TUSUN.safetensors"
|
8 |
+
lora_alpha: 0.6
|
9 |
+
|
10 |
+
seed: [10154078483724687116, 2664393535095473805, 4231566096207622938, 1713349740448094493]
|
11 |
+
steps: 25
|
12 |
+
guidance_scale: 7.5
|
13 |
+
|
14 |
+
prompt:
|
15 |
+
- "tusuncub with its mouth open, blurry, open mouth, fangs, photo background, looking at viewer, tongue, full body, solo, cute and lovely, Beautiful and realistic eye details, perfect anatomy, Nonsense, pure background, Centered-Shot, realistic photo, photograph, 4k, hyper detailed, DSLR, 24 Megapixels, 8mm Lens, Full Frame, film grain, Global Illumination, studio Lighting, Award Winning Photography, diffuse reflection, ray tracing"
|
16 |
+
- "cute tusun with a blurry background, black background, simple background, signature, face, solo, cute and lovely, Beautiful and realistic eye details, perfect anatomy, Nonsense, pure background, Centered-Shot, realistic photo, photograph, 4k, hyper detailed, DSLR, 24 Megapixels, 8mm Lens, Full Frame, film grain, Global Illumination, studio Lighting, Award Winning Photography, diffuse reflection, ray tracing"
|
17 |
+
- "cut tusuncub walking in the snow, blurry, looking at viewer, depth of field, blurry background, full body, solo, cute and lovely, Beautiful and realistic eye details, perfect anatomy, Nonsense, pure background, Centered-Shot, realistic photo, photograph, 4k, hyper detailed, DSLR, 24 Megapixels, 8mm Lens, Full Frame, film grain, Global Illumination, studio Lighting, Award Winning Photography, diffuse reflection, ray tracing"
|
18 |
+
- "character design, cyberpunk tusun kitten wearing astronaut suit, sci-fic, realistic eye color and details, fluffy, big head, science fiction, communist ideology, Cyborg, fantasy, intense angle, soft lighting, photograph, 4k, hyper detailed, portrait wallpaper, realistic, photo-realistic, DSLR, 24 Megapixels, Full Frame, vibrant details, octane render, finely detail, best quality, incredibly absurdres, robotic parts, rim light, vibrant details, luxurious cyberpunk, hyperrealistic, cable electric wires, microchip, full body"
|
19 |
+
|
20 |
+
n_prompt:
|
21 |
+
- "worst quality, low quality, deformed, distorted, disfigured, bad eyes, bad anatomy, disconnected limbs, wrong body proportions, low quality, worst quality, text, watermark, signatre, logo, illustration, painting, cartoons, ugly, easy_negative"
|
configs/prompts/7-FilmVelvia.yaml
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
FilmVelvia:
|
2 |
+
motion_module:
|
3 |
+
- "models/Motion_Module/mm_sd_v14.ckpt"
|
4 |
+
- "models/Motion_Module/mm_sd_v15.ckpt"
|
5 |
+
|
6 |
+
dreambooth_path: "models/DreamBooth_LoRA/majicmixRealistic_v4.safetensors"
|
7 |
+
lora_model_path: "models/DreamBooth_LoRA/FilmVelvia2.safetensors"
|
8 |
+
lora_alpha: 0.6
|
9 |
+
|
10 |
+
seed: [358675358833372813, 3519455280971923743, 11684545350557985081, 8696855302100399877]
|
11 |
+
steps: 25
|
12 |
+
guidance_scale: 7.5
|
13 |
+
|
14 |
+
prompt:
|
15 |
+
- "a woman standing on the side of a road at night,girl, long hair, motor vehicle, car, looking at viewer, ground vehicle, night, hands in pockets, blurry background, coat, black hair, parted lips, bokeh, jacket, brown hair, outdoors, red lips, upper body, artist name"
|
16 |
+
- ", dark shot,0mm, portrait quality of a arab man worker,boy, wasteland that stands out vividly against the background of the desert, barren landscape, closeup, moles skin, soft light, sharp, exposure blend, medium shot, bokeh, hdr, high contrast, cinematic, teal and orange5, muted colors, dim colors, soothing tones, low saturation, hyperdetailed, noir"
|
17 |
+
- "fashion photography portrait of 1girl, offshoulder, fluffy short hair, soft light, rim light, beautiful shadow, low key, photorealistic, raw photo, natural skin texture, realistic eye and face details, hyperrealism, ultra high res, 4K, Best quality, masterpiece, necklace, cleavage, in the dark"
|
18 |
+
- "In this lighthearted portrait, a woman is dressed as a fierce warrior, armed with an arsenal of paintbrushes and palette knives. Her war paint is composed of thick, vibrant strokes of color, and her armor is made of paint tubes and paint-splattered canvases. She stands victoriously atop a mountain of conquered blank canvases, with a beautiful, colorful landscape behind her, symbolizing the power of art and creativity. bust Portrait, close-up, Bright and transparent scene lighting, "
|
19 |
+
|
20 |
+
n_prompt:
|
21 |
+
- "cartoon, anime, sketches,worst quality, low quality, deformed, distorted, disfigured, bad eyes, wrong lips, weird mouth, bad teeth, mutated hands and fingers, bad anatomy, wrong anatomy, amputation, extra limb, missing limb, floating limbs, disconnected limbs, mutation, ugly, disgusting, bad_pictures, negative_hand-neg"
|
22 |
+
- "cartoon, anime, sketches,worst quality, low quality, deformed, distorted, disfigured, bad eyes, wrong lips, weird mouth, bad teeth, mutated hands and fingers, bad anatomy, wrong anatomy, amputation, extra limb, missing limb, floating limbs, disconnected limbs, mutation, ugly, disgusting, bad_pictures, negative_hand-neg"
|
23 |
+
- "wrong white balance, dark, cartoon, anime, sketches,worst quality, low quality, deformed, distorted, disfigured, bad eyes, wrong lips, weird mouth, bad teeth, mutated hands and fingers, bad anatomy, wrong anatomy, amputation, extra limb, missing limb, floating limbs, disconnected limbs, mutation, ugly, disgusting, bad_pictures, negative_hand-neg"
|
24 |
+
- "wrong white balance, dark, cartoon, anime, sketches,worst quality, low quality, deformed, distorted, disfigured, bad eyes, wrong lips, weird mouth, bad teeth, mutated hands and fingers, bad anatomy, wrong anatomy, amputation, extra limb, missing limb, floating limbs, disconnected limbs, mutation, ugly, disgusting, bad_pictures, negative_hand-neg"
|
configs/prompts/8-GhibliBackground.yaml
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
GhibliBackground:
|
2 |
+
motion_module:
|
3 |
+
- "models/Motion_Module/mm_sd_v14.ckpt"
|
4 |
+
- "models/Motion_Module/mm_sd_v15.ckpt"
|
5 |
+
|
6 |
+
dreambooth_path: "models/DreamBooth_LoRA/CounterfeitV30_25.safetensors"
|
7 |
+
lora_model_path: "models/DreamBooth_LoRA/lora_Ghibli_n3.safetensors"
|
8 |
+
lora_alpha: 1.0
|
9 |
+
|
10 |
+
seed: [8775748474469046618, 5893874876080607656, 11911465742147695752, 12437784838692000640]
|
11 |
+
steps: 25
|
12 |
+
guidance_scale: 7.5
|
13 |
+
|
14 |
+
prompt:
|
15 |
+
- "best quality,single build,architecture, blue_sky, building,cloudy_sky, day, fantasy, fence, field, house, build,architecture,landscape, moss, outdoors, overgrown, path, river, road, rock, scenery, sky, sword, tower, tree, waterfall"
|
16 |
+
- "black_border, building, city, day, fantasy, ice, landscape, letterboxed, mountain, ocean, outdoors, planet, scenery, ship, snow, snowing, water, watercraft, waterfall, winter"
|
17 |
+
- ",mysterious sea area, fantasy,build,concept"
|
18 |
+
- "Tomb Raider,Scenography,Old building"
|
19 |
+
|
20 |
+
n_prompt:
|
21 |
+
- "easynegative,bad_construction,bad_structure,bad_wail,bad_windows,blurry,cloned_window,cropped,deformed,disfigured,error,extra_windows,extra_chimney,extra_door,extra_structure,extra_frame,fewer_digits,fused_structure,gross_proportions,jpeg_artifacts,long_roof,low_quality,structure_limbs,missing_windows,missing_doors,missing_roofs,mutated_structure,mutation,normal_quality,out_of_frame,owres,poorly_drawn_structure,poorly_drawn_house,signature,text,too_many_windows,ugly,username,uta,watermark,worst_quality"
|
configs/prompts/freeinit_examples/RcnzCartoon_v2.yaml
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
RealisticVision:
|
2 |
+
inference_config: "configs/inference/inference-v2.yaml"
|
3 |
+
motion_module:
|
4 |
+
# - "models/Motion_Module/mm_sd_v14.ckpt"
|
5 |
+
- "/mnt/petrelfs/sichenyang.p/code/diffsuion/git_code/AnimateDiff/AnimateDiff/models/Motion_Module/mm_sd_v15_v2.ckpt"
|
6 |
+
|
7 |
+
# dreambooth_path: "models/DreamBooth_LoRA/realisticVisionV20_v20.safetensors"
|
8 |
+
dreambooth_path: "/mnt/petrelfs/sichenyang.p/code/diffsuion/git_code/AnimateDiff/AnimateDiff/models/DreamBooth_LoRA/rcnzCartoon3d_v10.safetensors"
|
9 |
+
lora_model_path: ""
|
10 |
+
|
11 |
+
seed: [849, 502, 1334]
|
12 |
+
steps: 25
|
13 |
+
guidance_scale: 7.5
|
14 |
+
|
15 |
+
filter_params:
|
16 |
+
method: 'butterworth'
|
17 |
+
n: 4
|
18 |
+
d_s: 0.25
|
19 |
+
d_t: 0.25
|
20 |
+
|
21 |
+
# filter_params:
|
22 |
+
# method: 'gaussian'
|
23 |
+
# d_s: 0.25
|
24 |
+
# d_t: 0.25
|
25 |
+
|
26 |
+
prompt:
|
27 |
+
- "Gwen Stacy reading a book"
|
28 |
+
- "A cute raccoon playing guitar in a boat on the ocean"
|
29 |
+
|
30 |
+
n_prompt:
|
31 |
+
- ""
|
32 |
+
- ""
|
33 |
+
|
configs/prompts/freeinit_examples/RealisticVision_v1.yaml
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
RealisticVision:
|
2 |
+
motion_module:
|
3 |
+
# - "models/Motion_Module/mm_sd_v14.ckpt"
|
4 |
+
- "/mnt/petrelfs/sichenyang.p/code/diffsuion/git_code/AnimateDiff/AnimateDiff/models/Motion_Module/mm_sd_v14.ckpt"
|
5 |
+
|
6 |
+
# dreambooth_path: "models/DreamBooth_LoRA/realisticVisionV20_v20.safetensors"
|
7 |
+
dreambooth_path: "/mnt/petrelfs/sichenyang.p/code/diffsuion/git_code/AnimateDiff/AnimateDiff/models/DreamBooth_LoRA/realisticVisionV20_v20.safetensors"
|
8 |
+
lora_model_path: ""
|
9 |
+
|
10 |
+
seed: [502, 5206]
|
11 |
+
steps: 25
|
12 |
+
guidance_scale: 7.5
|
13 |
+
|
14 |
+
# filter_params:
|
15 |
+
# method: 'butterworth'
|
16 |
+
# n: 4
|
17 |
+
# d_s: 0.25
|
18 |
+
# d_t: 0.25
|
19 |
+
|
20 |
+
filter_params:
|
21 |
+
method: 'gaussian'
|
22 |
+
d_s: 0.25
|
23 |
+
d_t: 0.25
|
24 |
+
|
25 |
+
prompt:
|
26 |
+
- "A cute raccoon playing guitar in a boat on the ocean."
|
27 |
+
- "A panda standing on a surfboard in the ocean in sunset."
|
28 |
+
|
29 |
+
n_prompt:
|
30 |
+
- ""
|
31 |
+
- ""
|
32 |
+
|
configs/prompts/freeinit_examples/RealisticVision_v2.yaml
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
RealisticVision:
|
2 |
+
inference_config: "configs/inference/inference-v2.yaml"
|
3 |
+
motion_module:
|
4 |
+
# - "models/Motion_Module/mm_sd_v14.ckpt"
|
5 |
+
- "/mnt/petrelfs/sichenyang.p/code/diffsuion/git_code/AnimateDiff/AnimateDiff/models/Motion_Module/mm_sd_v15_v2.ckpt"
|
6 |
+
|
7 |
+
# dreambooth_path: "models/DreamBooth_LoRA/realisticVisionV20_v20.safetensors"
|
8 |
+
dreambooth_path: "/mnt/petrelfs/sichenyang.p/code/diffsuion/git_code/AnimateDiff/AnimateDiff/models/DreamBooth_LoRA/realisticVisionV20_v20.safetensors"
|
9 |
+
lora_model_path: ""
|
10 |
+
|
11 |
+
seed: [9620, 913, 6840, 1334]
|
12 |
+
steps: 25
|
13 |
+
guidance_scale: 7.5
|
14 |
+
|
15 |
+
filter_params:
|
16 |
+
method: 'butterworth'
|
17 |
+
n: 4
|
18 |
+
d_s: 0.25
|
19 |
+
d_t: 0.25
|
20 |
+
|
21 |
+
# filter_params:
|
22 |
+
# method: 'gaussian'
|
23 |
+
# d_s: 0.25
|
24 |
+
# d_t: 0.25
|
25 |
+
|
26 |
+
prompt:
|
27 |
+
- "A panda cooking in the kitchen"
|
28 |
+
- "A cat wearing sunglasses and working as a lifeguard at a pool."
|
29 |
+
- "A confused panda in calculus class"
|
30 |
+
- "A robot DJ is playing the turntable, in heavy raining futuristic tokyo rooftop cyberpunk night, sci-fi, fantasy"
|
31 |
+
|
32 |
+
n_prompt:
|
33 |
+
- ""
|
34 |
+
- ""
|
35 |
+
- ""
|
36 |
+
- ""
|
37 |
+
|
configs/prompts/v2/5-RealisticVision-MotionLoRA.yaml
ADDED
@@ -0,0 +1,189 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
ZoomIn:
|
2 |
+
inference_config: "configs/inference/inference-v2.yaml"
|
3 |
+
motion_module:
|
4 |
+
- "models/Motion_Module/mm_sd_v15_v2.ckpt"
|
5 |
+
|
6 |
+
motion_module_lora_configs:
|
7 |
+
- path: "models/MotionLoRA/v2_lora_ZoomIn.ckpt"
|
8 |
+
alpha: 1.0
|
9 |
+
|
10 |
+
dreambooth_path: "models/DreamBooth_LoRA/realisticVisionV20_v20.safetensors"
|
11 |
+
lora_model_path: ""
|
12 |
+
|
13 |
+
seed: 45987230
|
14 |
+
steps: 25
|
15 |
+
guidance_scale: 7.5
|
16 |
+
|
17 |
+
prompt:
|
18 |
+
- "photo of coastline, rocks, storm weather, wind, waves, lightning, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3"
|
19 |
+
|
20 |
+
n_prompt:
|
21 |
+
- "blur, haze, deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, mutated hands and fingers, deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation"
|
22 |
+
|
23 |
+
|
24 |
+
|
25 |
+
ZoomOut:
|
26 |
+
inference_config: "configs/inference/inference-v2.yaml"
|
27 |
+
motion_module:
|
28 |
+
- "models/Motion_Module/mm_sd_v15_v2.ckpt"
|
29 |
+
|
30 |
+
motion_module_lora_configs:
|
31 |
+
- path: "models/MotionLoRA/v2_lora_ZoomOut.ckpt"
|
32 |
+
alpha: 1.0
|
33 |
+
|
34 |
+
dreambooth_path: "models/DreamBooth_LoRA/realisticVisionV20_v20.safetensors"
|
35 |
+
lora_model_path: ""
|
36 |
+
|
37 |
+
seed: 45987230
|
38 |
+
steps: 25
|
39 |
+
guidance_scale: 7.5
|
40 |
+
|
41 |
+
prompt:
|
42 |
+
- "photo of coastline, rocks, storm weather, wind, waves, lightning, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3"
|
43 |
+
|
44 |
+
n_prompt:
|
45 |
+
- "blur, haze, deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, mutated hands and fingers, deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation"
|
46 |
+
|
47 |
+
|
48 |
+
|
49 |
+
PanLeft:
|
50 |
+
inference_config: "configs/inference/inference-v2.yaml"
|
51 |
+
motion_module:
|
52 |
+
- "models/Motion_Module/mm_sd_v15_v2.ckpt"
|
53 |
+
|
54 |
+
motion_module_lora_configs:
|
55 |
+
- path: "models/MotionLoRA/v2_lora_PanLeft.ckpt"
|
56 |
+
alpha: 1.0
|
57 |
+
|
58 |
+
dreambooth_path: "models/DreamBooth_LoRA/realisticVisionV20_v20.safetensors"
|
59 |
+
lora_model_path: ""
|
60 |
+
|
61 |
+
seed: 45987230
|
62 |
+
steps: 25
|
63 |
+
guidance_scale: 7.5
|
64 |
+
|
65 |
+
prompt:
|
66 |
+
- "photo of coastline, rocks, storm weather, wind, waves, lightning, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3"
|
67 |
+
|
68 |
+
n_prompt:
|
69 |
+
- "blur, haze, deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, mutated hands and fingers, deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation"
|
70 |
+
|
71 |
+
|
72 |
+
|
73 |
+
PanRight:
|
74 |
+
inference_config: "configs/inference/inference-v2.yaml"
|
75 |
+
motion_module:
|
76 |
+
- "models/Motion_Module/mm_sd_v15_v2.ckpt"
|
77 |
+
|
78 |
+
motion_module_lora_configs:
|
79 |
+
- path: "models/MotionLoRA/v2_lora_PanRight.ckpt"
|
80 |
+
alpha: 1.0
|
81 |
+
|
82 |
+
dreambooth_path: "models/DreamBooth_LoRA/realisticVisionV20_v20.safetensors"
|
83 |
+
lora_model_path: ""
|
84 |
+
|
85 |
+
seed: 45987230
|
86 |
+
steps: 25
|
87 |
+
guidance_scale: 7.5
|
88 |
+
|
89 |
+
prompt:
|
90 |
+
- "photo of coastline, rocks, storm weather, wind, waves, lightning, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3"
|
91 |
+
|
92 |
+
n_prompt:
|
93 |
+
- "blur, haze, deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, mutated hands and fingers, deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation"
|
94 |
+
|
95 |
+
|
96 |
+
|
97 |
+
TiltUp:
|
98 |
+
inference_config: "configs/inference/inference-v2.yaml"
|
99 |
+
motion_module:
|
100 |
+
- "models/Motion_Module/mm_sd_v15_v2.ckpt"
|
101 |
+
|
102 |
+
motion_module_lora_configs:
|
103 |
+
- path: "models/MotionLoRA/v2_lora_TiltUp.ckpt"
|
104 |
+
alpha: 1.0
|
105 |
+
|
106 |
+
dreambooth_path: "models/DreamBooth_LoRA/realisticVisionV20_v20.safetensors"
|
107 |
+
lora_model_path: ""
|
108 |
+
|
109 |
+
seed: 45987230
|
110 |
+
steps: 25
|
111 |
+
guidance_scale: 7.5
|
112 |
+
|
113 |
+
prompt:
|
114 |
+
- "photo of coastline, rocks, storm weather, wind, waves, lightning, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3"
|
115 |
+
|
116 |
+
n_prompt:
|
117 |
+
- "blur, haze, deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, mutated hands and fingers, deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation"
|
118 |
+
|
119 |
+
|
120 |
+
|
121 |
+
TiltDown:
|
122 |
+
inference_config: "configs/inference/inference-v2.yaml"
|
123 |
+
motion_module:
|
124 |
+
- "models/Motion_Module/mm_sd_v15_v2.ckpt"
|
125 |
+
|
126 |
+
motion_module_lora_configs:
|
127 |
+
- path: "models/MotionLoRA/v2_lora_TiltDown.ckpt"
|
128 |
+
alpha: 1.0
|
129 |
+
|
130 |
+
dreambooth_path: "models/DreamBooth_LoRA/realisticVisionV20_v20.safetensors"
|
131 |
+
lora_model_path: ""
|
132 |
+
|
133 |
+
seed: 45987230
|
134 |
+
steps: 25
|
135 |
+
guidance_scale: 7.5
|
136 |
+
|
137 |
+
prompt:
|
138 |
+
- "photo of coastline, rocks, storm weather, wind, waves, lightning, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3"
|
139 |
+
|
140 |
+
n_prompt:
|
141 |
+
- "blur, haze, deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, mutated hands and fingers, deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation"
|
142 |
+
|
143 |
+
|
144 |
+
|
145 |
+
RollingAnticlockwise:
|
146 |
+
inference_config: "configs/inference/inference-v2.yaml"
|
147 |
+
motion_module:
|
148 |
+
- "models/Motion_Module/mm_sd_v15_v2.ckpt"
|
149 |
+
|
150 |
+
motion_module_lora_configs:
|
151 |
+
- path: "models/MotionLoRA/v2_lora_RollingAnticlockwise.ckpt"
|
152 |
+
alpha: 1.0
|
153 |
+
|
154 |
+
dreambooth_path: "models/DreamBooth_LoRA/realisticVisionV20_v20.safetensors"
|
155 |
+
lora_model_path: ""
|
156 |
+
|
157 |
+
seed: 45987230
|
158 |
+
steps: 25
|
159 |
+
guidance_scale: 7.5
|
160 |
+
|
161 |
+
prompt:
|
162 |
+
- "photo of coastline, rocks, storm weather, wind, waves, lightning, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3"
|
163 |
+
|
164 |
+
n_prompt:
|
165 |
+
- "blur, haze, deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, mutated hands and fingers, deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation"
|
166 |
+
|
167 |
+
|
168 |
+
|
169 |
+
RollingClockwise:
|
170 |
+
inference_config: "configs/inference/inference-v2.yaml"
|
171 |
+
motion_module:
|
172 |
+
- "models/Motion_Module/mm_sd_v15_v2.ckpt"
|
173 |
+
|
174 |
+
motion_module_lora_configs:
|
175 |
+
- path: "models/MotionLoRA/v2_lora_RollingClockwise.ckpt"
|
176 |
+
alpha: 1.0
|
177 |
+
|
178 |
+
dreambooth_path: "models/DreamBooth_LoRA/realisticVisionV20_v20.safetensors"
|
179 |
+
lora_model_path: ""
|
180 |
+
|
181 |
+
seed: 45987230
|
182 |
+
steps: 25
|
183 |
+
guidance_scale: 7.5
|
184 |
+
|
185 |
+
prompt:
|
186 |
+
- "photo of coastline, rocks, storm weather, wind, waves, lightning, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3"
|
187 |
+
|
188 |
+
n_prompt:
|
189 |
+
- "blur, haze, deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, mutated hands and fingers, deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation"
|
configs/prompts/v2/5-RealisticVision.yaml
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
RealisticVision:
|
2 |
+
inference_config: "configs/inference/inference-v2.yaml"
|
3 |
+
motion_module:
|
4 |
+
- "models/Motion_Module/mm_sd_v15_v2.ckpt"
|
5 |
+
|
6 |
+
dreambooth_path: "models/DreamBooth_LoRA/realisticVisionV20_v20.safetensors"
|
7 |
+
lora_model_path: ""
|
8 |
+
|
9 |
+
seed: [13100322578370451493, 14752961627088720670, 9329399085567825781, 16987697414827649302]
|
10 |
+
steps: 25
|
11 |
+
guidance_scale: 7.5
|
12 |
+
|
13 |
+
prompt:
|
14 |
+
- "b&w photo of 42 y.o man in black clothes, bald, face, half body, body, high detailed skin, skin pores, coastline, overcast weather, wind, waves, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3"
|
15 |
+
- "close up photo of a rabbit, forest, haze, halation, bloom, dramatic atmosphere, centred, rule of thirds, 200mm 1.4f macro shot"
|
16 |
+
- "photo of coastline, rocks, storm weather, wind, waves, lightning, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3"
|
17 |
+
- "night, b&w photo of old house, post apocalypse, forest, storm weather, wind, rocks, 8k uhd, dslr, soft lighting, high quality, film grain"
|
18 |
+
|
19 |
+
n_prompt:
|
20 |
+
- "semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, text, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck"
|
21 |
+
- "semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, text, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck"
|
22 |
+
- "blur, haze, deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, mutated hands and fingers, deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation"
|
23 |
+
- "blur, haze, deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, art, mutated hands and fingers, deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation"
|
models/DreamBooth_LoRA/Put personalized T2I checkpoints here.txt
ADDED
File without changes
|
models/MotionLoRA/Put MotionLoRA checkpoints here.txt
ADDED
File without changes
|
models/Motion_Module/Put motion module checkpoints here.txt
ADDED
File without changes
|
models/StableDiffusion/Put diffusers stable-diffusion-v1-5 repo here.txt
ADDED
File without changes
|
requirements.txt
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch==1.13.1
|
2 |
+
torchvision==0.14.1
|
3 |
+
torchaudio==0.13.1
|
4 |
+
diffusers==0.11.1
|
5 |
+
transformers==4.25.1
|
6 |
+
xformers==0.0.16
|
7 |
+
imageio==2.27.0
|
8 |
+
gdown
|
9 |
+
einops
|
10 |
+
omegaconf
|
11 |
+
safetensors
|
12 |
+
gradio
|
13 |
+
imageio[ffmpeg]
|
14 |
+
imageio[pyav]
|
15 |
+
accelerate
|