Spaces:
Runtime error
Runtime error
support composable adapter (#5)
Browse files- add composable adapter (0177fec31c28bc1bdccd21e971f1e48dcbaf3d1a)
This view is limited to 50 files because it contains too many changes.
See raw diff
- .gitignore +1 -5
- app.py +277 -63
- {models → configs/mm}/faster_rcnn_r50_fpn_coco.py +182 -182
- {models → configs/mm}/hrnet_w48_coco_256x192.py +169 -169
- configs/stable-diffusion/sd-v1-inference.yaml +65 -0
- configs/stable-diffusion/sd-v1-train.yaml +86 -0
- configs/stable-diffusion/train_keypose.yaml +87 -0
- configs/stable-diffusion/train_mask.yaml +87 -0
- configs/stable-diffusion/train_sketch.yaml +87 -0
- demo/demos.py +0 -309
- demo/model.py +0 -979
- dist_util.py +91 -0
- docs/AdapterZoo.md +16 -0
- docs/FAQ.md +5 -0
- docs/examples.md +41 -0
- environment.yaml +0 -31
- ldm/modules/structure_condition/midas/__init__.py → experiments/README.md +0 -0
- ldm/data/base.py +0 -23
- ldm/data/dataset_coco.py +36 -0
- ldm/data/dataset_depth.py +35 -0
- ldm/data/dataset_laion.py +130 -0
- ldm/data/dataset_wikiart.py +67 -0
- ldm/data/imagenet.py +0 -394
- ldm/data/lsun.py +0 -92
- ldm/data/utils.py +40 -0
- ldm/inference_base.py +282 -0
- ldm/models/autoencoder.py +43 -275
- ldm/models/diffusion/classifier.py +0 -267
- ldm/models/diffusion/ddim.py +68 -17
- ldm/models/diffusion/ddpm.py +251 -384
- ldm/models/diffusion/dpm_solver/dpm_solver.py +152 -119
- ldm/models/diffusion/dpm_solver/sampler.py +8 -3
- ldm/models/diffusion/plms.py +23 -48
- ldm/modules/attention.py +4 -0
- ldm/modules/diffusionmodules/openaimodel.py +85 -263
- ldm/modules/diffusionmodules/util.py +5 -2
- ldm/modules/ema.py +12 -8
- ldm/modules/encoders/adapter.py +84 -76
- ldm/modules/encoders/modules.py +349 -142
- ldm/modules/{structure_condition → extra_condition}/__init__.py +0 -0
- ldm/modules/extra_condition/api.py +269 -0
- ldm/modules/{structure_condition/midas → extra_condition}/midas/__init__.py +0 -0
- ldm/modules/{structure_condition → extra_condition}/midas/api.py +4 -4
- ldm/modules/{structure_condition/openpose → extra_condition/midas/midas}/__init__.py +0 -0
- ldm/modules/{structure_condition → extra_condition}/midas/midas/base_model.py +0 -0
- ldm/modules/{structure_condition → extra_condition}/midas/midas/blocks.py +0 -0
- ldm/modules/{structure_condition → extra_condition}/midas/midas/dpt_depth.py +0 -0
- ldm/modules/{structure_condition → extra_condition}/midas/midas/midas_net.py +0 -0
- ldm/modules/{structure_condition → extra_condition}/midas/midas/midas_net_custom.py +0 -0
- ldm/modules/{structure_condition → extra_condition}/midas/midas/transforms.py +0 -0
.gitignore
CHANGED
@@ -1,6 +1,3 @@
|
|
1 |
-
# ignored folders
|
2 |
-
models
|
3 |
-
|
4 |
# ignored folders
|
5 |
tmp/*
|
6 |
|
@@ -23,7 +20,6 @@ version.py
|
|
23 |
|
24 |
# Byte-compiled / optimized / DLL files
|
25 |
__pycache__/
|
26 |
-
*.pyc
|
27 |
*.py[cod]
|
28 |
*$py.class
|
29 |
|
@@ -125,4 +121,4 @@ venv.bak/
|
|
125 |
/site
|
126 |
|
127 |
# mypy
|
128 |
-
.mypy_cache/
|
|
|
|
|
|
|
|
|
1 |
# ignored folders
|
2 |
tmp/*
|
3 |
|
|
|
20 |
|
21 |
# Byte-compiled / optimized / DLL files
|
22 |
__pycache__/
|
|
|
23 |
*.py[cod]
|
24 |
*$py.class
|
25 |
|
|
|
121 |
/site
|
122 |
|
123 |
# mypy
|
124 |
+
.mypy_cache/
|
app.py
CHANGED
@@ -1,29 +1,44 @@
|
|
|
|
|
|
|
|
1 |
import os
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
from demo.model import Model_all
|
10 |
import gradio as gr
|
11 |
-
from demo.demos import create_demo_keypose, create_demo_sketch, create_demo_draw, create_demo_seg, create_demo_depth, create_demo_depth_keypose, create_demo_color, create_demo_color_sketch, create_demo_openpose, create_demo_style_sketch, create_demo_canny
|
12 |
import torch
|
13 |
-
import
|
14 |
-
import shlex
|
15 |
from huggingface_hub import hf_hub_url
|
|
|
|
|
16 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
urls = {
|
18 |
-
'TencentARC/T2I-Adapter':[
|
19 |
-
|
20 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
}
|
22 |
-
|
23 |
-
'https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r50_fpn_1x_coco/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth',
|
24 |
-
'https://download.openmmlab.com/mmpose/top_down/hrnet/hrnet_w48_coco_256x192-b9e0b3ab_20200708.pth',
|
25 |
-
'https://github.com/kazuto1011/deeplab-pytorch/releases/download/v1.0/deeplabv2_resnet101_msc-cocostuff164k-100000.pth'
|
26 |
-
]
|
27 |
if os.path.exists('models') == False:
|
28 |
os.mkdir('models')
|
29 |
for repo in urls:
|
@@ -31,58 +46,257 @@ for repo in urls:
|
|
31 |
for file in files:
|
32 |
url = hf_hub_url(repo, file)
|
33 |
name_ckp = url.split('/')[-1]
|
34 |
-
save_path = os.path.join('models',name_ckp)
|
35 |
if os.path.exists(save_path) == False:
|
36 |
subprocess.run(shlex.split(f'wget {url} -O {save_path}'))
|
37 |
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
43 |
|
44 |
-
|
45 |
-
|
|
|
|
|
|
|
|
|
46 |
|
47 |
-
|
|
|
48 |
|
49 |
-
|
|
|
|
|
50 |
|
51 |
-
|
|
|
|
|
|
|
|
|
52 |
|
53 |
-
|
54 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
55 |
|
56 |
with gr.Blocks(css='style.css') as demo:
|
57 |
gr.Markdown(DESCRIPTION)
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# demo inspired by https://huggingface.co/spaces/lambdalabs/image-mixer-demo
|
2 |
+
import argparse
|
3 |
+
import copy
|
4 |
import os
|
5 |
+
import shlex
|
6 |
+
import subprocess
|
7 |
+
from functools import partial
|
8 |
+
from itertools import chain
|
9 |
+
|
10 |
+
import cv2
|
|
|
|
|
11 |
import gradio as gr
|
|
|
12 |
import torch
|
13 |
+
from basicsr.utils import tensor2img
|
|
|
14 |
from huggingface_hub import hf_hub_url
|
15 |
+
from pytorch_lightning import seed_everything
|
16 |
+
from torch import autocast
|
17 |
|
18 |
+
from ldm.inference_base import (DEFAULT_NEGATIVE_PROMPT, diffusion_inference,
|
19 |
+
get_adapters, get_sd_models)
|
20 |
+
from ldm.modules.extra_condition import api
|
21 |
+
from ldm.modules.extra_condition.api import (ExtraCondition,
|
22 |
+
get_adapter_feature,
|
23 |
+
get_cond_model)
|
24 |
+
|
25 |
+
torch.set_grad_enabled(False)
|
26 |
+
|
27 |
+
supported_cond = ['style', 'color', 'canny', 'sketch', 'openpose', 'depth']
|
28 |
+
|
29 |
+
# download the checkpoints
|
30 |
urls = {
|
31 |
+
'TencentARC/T2I-Adapter': [
|
32 |
+
'models/t2iadapter_keypose_sd14v1.pth', 'models/t2iadapter_color_sd14v1.pth',
|
33 |
+
'models/t2iadapter_openpose_sd14v1.pth', 'models/t2iadapter_seg_sd14v1.pth',
|
34 |
+
'models/t2iadapter_sketch_sd14v1.pth', 'models/t2iadapter_depth_sd14v1.pth',
|
35 |
+
'third-party-models/body_pose_model.pth', "models/t2iadapter_style_sd14v1.pth",
|
36 |
+
"models/t2iadapter_canny_sd14v1.pth", "third-party-models/table5_pidinet.pth"
|
37 |
+
],
|
38 |
+
'runwayml/stable-diffusion-v1-5': ['v1-5-pruned-emaonly.ckpt'],
|
39 |
+
'andite/anything-v4.0': ['anything-v4.0-pruned.ckpt', 'anything-v4.0.vae.pt'],
|
40 |
}
|
41 |
+
|
|
|
|
|
|
|
|
|
42 |
if os.path.exists('models') == False:
|
43 |
os.mkdir('models')
|
44 |
for repo in urls:
|
|
|
46 |
for file in files:
|
47 |
url = hf_hub_url(repo, file)
|
48 |
name_ckp = url.split('/')[-1]
|
49 |
+
save_path = os.path.join('models', name_ckp)
|
50 |
if os.path.exists(save_path) == False:
|
51 |
subprocess.run(shlex.split(f'wget {url} -O {save_path}'))
|
52 |
|
53 |
+
# config
|
54 |
+
parser = argparse.ArgumentParser()
|
55 |
+
parser.add_argument(
|
56 |
+
'--sd_ckpt',
|
57 |
+
type=str,
|
58 |
+
default='models/v1-5-pruned-emaonly.ckpt',
|
59 |
+
help='path to checkpoint of stable diffusion model, both .ckpt and .safetensor are supported',
|
60 |
+
)
|
61 |
+
parser.add_argument(
|
62 |
+
'--vae_ckpt',
|
63 |
+
type=str,
|
64 |
+
default=None,
|
65 |
+
help='vae checkpoint, anime SD models usually have seperate vae ckpt that need to be loaded',
|
66 |
+
)
|
67 |
+
global_opt = parser.parse_args()
|
68 |
+
global_opt.config = 'configs/stable-diffusion/sd-v1-inference.yaml'
|
69 |
+
for cond_name in supported_cond:
|
70 |
+
setattr(global_opt, f'{cond_name}_adapter_ckpt', f'models/t2iadapter_{cond_name}_sd14v1.pth')
|
71 |
+
global_opt.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
72 |
+
global_opt.max_resolution = 512 * 512
|
73 |
+
global_opt.sampler = 'ddim'
|
74 |
+
global_opt.cond_weight = 1.0
|
75 |
+
global_opt.C = 4
|
76 |
+
global_opt.f = 8
|
77 |
+
|
78 |
+
# stable-diffusion model
|
79 |
+
sd_model, sampler = get_sd_models(global_opt)
|
80 |
+
# adapters and models to processing condition inputs
|
81 |
+
adapters = {}
|
82 |
+
cond_models = {}
|
83 |
+
torch.cuda.empty_cache()
|
84 |
+
|
85 |
+
|
86 |
+
def run(*args):
    """Run one generation request coming from the Gradio "Generate" button.

    ``args`` is the flat tuple of UI values. It starts with four groups of
    ``len(supported_cond)`` values each -- the input-type choice, the
    "Image" upload, the ready-made condition upload and the condition
    weight of every supported condition -- and ends with the eight
    generation settings: prompt, negative prompt, guidance scale, number of
    samples, seed, steps, short-edge resolution and ``cond_tau``.

    Returns:
        ``(ims, output_conds)``: the generated samples and the condition
        maps handed to the adapters, each converted with ``tensor2img``.
    """
    num_conds = len(supported_cond)
    with torch.inference_mode(), \
            sd_model.ema_scope(), \
            autocast('cuda'):

        # Regroup the leading values into the four per-condition groups and
        # transpose them into one (choice, image, condition, weight) tuple
        # per supported condition.
        inps = [args[i:i + num_conds] for i in range(0, len(args) - 8, num_conds)]
        per_cond = list(zip(*inps))

        # Per-request copy so the shared global options are never mutated.
        opt = copy.deepcopy(global_opt)
        (opt.prompt, opt.neg_prompt, opt.scale, opt.n_samples, opt.seed,
         opt.steps, opt.resize_short_edge, opt.cond_tau) = args[-8:]

        conds = []
        activated_conds = []

        # The first upload among the conditions after the first two (index
        # > 1) defines the common size every other upload is resized to.
        ref_size = None
        for _, im1, im2, _ in per_cond[2:]:
            ref = im1 if im1 is not None else im2
            if ref is not None:
                # shape[:2] also accepts 2-D (single-channel) arrays.
                ref_size = ref.shape[:2]
                break

        # Resize all the images to the same size. The first condition is
        # never resized. When no reference upload exists (e.g. only the
        # first two conditions are used) there is nothing to align to, so
        # the uploads are kept at their native size instead of failing on
        # an undefined height/width.
        ims1 = []
        ims2 = []
        for idx, (_, im1, im2, _) in enumerate(per_cond):
            if idx > 0 and ref_size is not None:
                h, w = ref_size
                if im1 is not None:
                    im1 = cv2.resize(im1, (w, h), interpolation=cv2.INTER_CUBIC)
                if im2 is not None:
                    im2 = cv2.resize(im2, (w, h), interpolation=cv2.INTER_CUBIC)
            ims1.append(im1)
            ims2.append(im2)

        for idx, (b, _, _, cond_weight) in enumerate(per_cond):
            cond_name = supported_cond[idx]
            if b == 'Nothing':
                # Unused condition: move an already loaded adapter to the
                # CPU (presumably to free accelerator memory).
                if cond_name in adapters:
                    adapters[cond_name]['model'] = adapters[cond_name]['model'].cpu()
                continue

            activated_conds.append(cond_name)
            if cond_name in adapters:
                adapters[cond_name]['model'] = adapters[cond_name]['model'].to(opt.device)
            else:
                # Adapters are created on first use and cached at module level.
                adapters[cond_name] = get_adapters(opt, getattr(ExtraCondition, cond_name))
            adapters[cond_name]['cond_weight'] = cond_weight

            process_cond_module = getattr(api, f'get_cond_{cond_name}')

            if b == 'Image':
                # A plain image needs the condition model to extract the
                # condition map; that model is cached as well.
                if cond_name not in cond_models:
                    cond_models[cond_name] = get_cond_model(opt, getattr(ExtraCondition, cond_name))
                conds.append(process_cond_module(opt, ims1[idx], 'image', cond_models[cond_name]))
            else:
                # The upload already is the condition map.
                conds.append(process_cond_module(opt, ims2[idx], cond_name, None))

        adapter_features, append_to_context = get_adapter_feature(
            conds, [adapters[cond_name] for cond_name in activated_conds])

        output_conds = [tensor2img(cond, rgb2bgr=False) for cond in conds]

        ims = []
        seed_everything(opt.seed)
        for _ in range(opt.n_samples):
            result = diffusion_inference(opt, sd_model, sampler, adapter_features, append_to_context)
            ims.append(tensor2img(result, rgb2bgr=False))

        # Clear GPU memory cache so less likely to OOM
        torch.cuda.empty_cache()
        return ims, output_conds
|
163 |
+
|
164 |
+
|
165 |
+
def change_visible(im1, im2, val):
    """Build the visibility updates for one condition's two image inputs.

    ``im1`` is shown only when ``val`` is "Image"; ``im2`` is shown for any
    value other than "Image" and "Nothing"; "Nothing" hides both.

    Returns:
        A dict mapping ``im1`` and ``im2`` to their ``gr.update`` objects.
    """
    wants_image = val == "Image"
    wants_nothing = (not wants_image) and val == "Nothing"
    return {
        im1: gr.update(visible=True if wants_image else False),
        im2: gr.update(visible=False if (wants_image or wants_nothing) else True),
    }
|
177 |
+
|
178 |
+
|
179 |
+
DESCRIPTION = '# [Composable T2I-Adapter](https://github.com/TencentARC/T2I-Adapter)'
|
180 |
+
|
181 |
+
DESCRIPTION += f'<p>Gradio demo for **T2I-Adapter**: [[GitHub]](https://github.com/TencentARC/T2I-Adapter), [[Paper]](https://arxiv.org/abs/2302.08453). If T2I-Adapter is helpful, please help to β the [Github Repo](https://github.com/TencentARC/T2I-Adapter) and recommend it to your friends π </p>'
|
182 |
+
|
183 |
+
DESCRIPTION += f'<p>For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings. <a href="https://huggingface.co/spaces/Adapter/T2I-Adapter?duplicate=true"><img style="display: inline; margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space" /></a></p>'
|
184 |
|
185 |
with gr.Blocks(css='style.css') as demo:
|
186 |
gr.Markdown(DESCRIPTION)
|
187 |
+
|
188 |
+
btns = []
|
189 |
+
ims1 = []
|
190 |
+
ims2 = []
|
191 |
+
cond_weights = []
|
192 |
+
|
193 |
+
with gr.Row():
|
194 |
+
with gr.Column(scale=1.9):
|
195 |
+
with gr.Box():
|
196 |
+
gr.Markdown("<h5><center>Style & Color</center></h5>")
|
197 |
+
with gr.Row():
|
198 |
+
for cond_name in supported_cond[:2]:
|
199 |
+
with gr.Box():
|
200 |
+
with gr.Column():
|
201 |
+
if cond_name == 'style':
|
202 |
+
btn1 = gr.Radio(
|
203 |
+
choices=["Image", "Nothing"],
|
204 |
+
label=f"Input type for {cond_name}",
|
205 |
+
interactive=True,
|
206 |
+
value="Nothing",
|
207 |
+
)
|
208 |
+
else:
|
209 |
+
btn1 = gr.Radio(
|
210 |
+
choices=["Image", cond_name, "Nothing"],
|
211 |
+
label=f"Input type for {cond_name}",
|
212 |
+
interactive=True,
|
213 |
+
value="Nothing",
|
214 |
+
)
|
215 |
+
im1 = gr.Image(
|
216 |
+
source='upload', label="Image", interactive=True, visible=False, type="numpy")
|
217 |
+
im2 = gr.Image(
|
218 |
+
source='upload', label=cond_name, interactive=True, visible=False, type="numpy")
|
219 |
+
cond_weight = gr.Slider(
|
220 |
+
label="Condition weight",
|
221 |
+
minimum=0,
|
222 |
+
maximum=5,
|
223 |
+
step=0.05,
|
224 |
+
value=1,
|
225 |
+
interactive=True)
|
226 |
+
|
227 |
+
fn = partial(change_visible, im1, im2)
|
228 |
+
btn1.change(fn=fn, inputs=[btn1], outputs=[im1, im2], queue=False)
|
229 |
+
|
230 |
+
btns.append(btn1)
|
231 |
+
ims1.append(im1)
|
232 |
+
ims2.append(im2)
|
233 |
+
cond_weights.append(cond_weight)
|
234 |
+
with gr.Column(scale=4):
|
235 |
+
with gr.Box():
|
236 |
+
gr.Markdown("<h5><center>Structure</center></h5>")
|
237 |
+
with gr.Row():
|
238 |
+
for cond_name in supported_cond[2:6]:
|
239 |
+
with gr.Box():
|
240 |
+
with gr.Column():
|
241 |
+
if cond_name == 'openpose':
|
242 |
+
btn1 = gr.Radio(
|
243 |
+
choices=["Image", 'pose', "Nothing"],
|
244 |
+
label=f"Input type for {cond_name}",
|
245 |
+
interactive=True,
|
246 |
+
value="Nothing",
|
247 |
+
)
|
248 |
+
else:
|
249 |
+
btn1 = gr.Radio(
|
250 |
+
choices=["Image", cond_name, "Nothing"],
|
251 |
+
label=f"Input type for {cond_name}",
|
252 |
+
interactive=True,
|
253 |
+
value="Nothing",
|
254 |
+
)
|
255 |
+
im1 = gr.Image(
|
256 |
+
source='upload', label="Image", interactive=True, visible=False, type="numpy")
|
257 |
+
im2 = gr.Image(
|
258 |
+
source='upload', label=cond_name, interactive=True, visible=False, type="numpy")
|
259 |
+
cond_weight = gr.Slider(
|
260 |
+
label="Condition weight",
|
261 |
+
minimum=0,
|
262 |
+
maximum=5,
|
263 |
+
step=0.05,
|
264 |
+
value=1,
|
265 |
+
interactive=True)
|
266 |
+
|
267 |
+
fn = partial(change_visible, im1, im2)
|
268 |
+
btn1.change(fn=fn, inputs=[btn1], outputs=[im1, im2], queue=False)
|
269 |
+
|
270 |
+
btns.append(btn1)
|
271 |
+
ims1.append(im1)
|
272 |
+
ims2.append(im2)
|
273 |
+
cond_weights.append(cond_weight)
|
274 |
+
|
275 |
+
with gr.Column():
|
276 |
+
prompt = gr.Textbox(label="Prompt")
|
277 |
+
|
278 |
+
with gr.Accordion('Advanced options', open=False):
|
279 |
+
neg_prompt = gr.Textbox(label="Negative Prompt", value=DEFAULT_NEGATIVE_PROMPT)
|
280 |
+
scale = gr.Slider(
|
281 |
+
label="Guidance Scale (Classifier free guidance)", value=7.5, minimum=1, maximum=20, step=0.1)
|
282 |
+
n_samples = gr.Slider(label="Num samples", value=1, minimum=1, maximum=8, step=1)
|
283 |
+
seed = gr.Slider(label="Seed", value=42, minimum=0, maximum=10000, step=1)
|
284 |
+
steps = gr.Slider(label="Steps", value=50, minimum=10, maximum=100, step=1)
|
285 |
+
resize_short_edge = gr.Slider(label="Image resolution", value=512, minimum=320, maximum=1024, step=1)
|
286 |
+
cond_tau = gr.Slider(
|
287 |
+
label="timestamp parameter that determines until which step the adapter is applied",
|
288 |
+
value=1.0,
|
289 |
+
minimum=0.1,
|
290 |
+
maximum=1.0,
|
291 |
+
step=0.05)
|
292 |
+
|
293 |
+
with gr.Row():
|
294 |
+
submit = gr.Button("Generate")
|
295 |
+
output = gr.Gallery().style(grid=2, height='auto')
|
296 |
+
cond = gr.Gallery().style(grid=2, height='auto')
|
297 |
+
|
298 |
+
inps = list(chain(btns, ims1, ims2, cond_weights))
|
299 |
+
|
300 |
+
inps.extend([prompt, neg_prompt, scale, n_samples, seed, steps, resize_short_edge, cond_tau])
|
301 |
+
submit.click(fn=run, inputs=inps, outputs=[output, cond])
|
302 |
+
demo.launch(server_name='0.0.0.0', share=False, server_port=47313)
|
{models → configs/mm}/faster_rcnn_r50_fpn_coco.py
RENAMED
@@ -1,182 +1,182 @@
|
|
1 |
-
checkpoint_config = dict(interval=1)
|
2 |
-
# yapf:disable
|
3 |
-
log_config = dict(
|
4 |
-
interval=50,
|
5 |
-
hooks=[
|
6 |
-
dict(type='TextLoggerHook'),
|
7 |
-
# dict(type='TensorboardLoggerHook')
|
8 |
-
])
|
9 |
-
# yapf:enable
|
10 |
-
dist_params = dict(backend='nccl')
|
11 |
-
log_level = 'INFO'
|
12 |
-
load_from = None
|
13 |
-
resume_from = None
|
14 |
-
workflow = [('train', 1)]
|
15 |
-
# optimizer
|
16 |
-
optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
|
17 |
-
optimizer_config = dict(grad_clip=None)
|
18 |
-
# learning policy
|
19 |
-
lr_config = dict(
|
20 |
-
policy='step',
|
21 |
-
warmup='linear',
|
22 |
-
warmup_iters=500,
|
23 |
-
warmup_ratio=0.001,
|
24 |
-
step=[8, 11])
|
25 |
-
total_epochs = 12
|
26 |
-
|
27 |
-
model = dict(
|
28 |
-
type='FasterRCNN',
|
29 |
-
pretrained='torchvision://resnet50',
|
30 |
-
backbone=dict(
|
31 |
-
type='ResNet',
|
32 |
-
depth=50,
|
33 |
-
num_stages=4,
|
34 |
-
out_indices=(0, 1, 2, 3),
|
35 |
-
frozen_stages=1,
|
36 |
-
norm_cfg=dict(type='BN', requires_grad=True),
|
37 |
-
norm_eval=True,
|
38 |
-
style='pytorch'),
|
39 |
-
neck=dict(
|
40 |
-
type='FPN',
|
41 |
-
in_channels=[256, 512, 1024, 2048],
|
42 |
-
out_channels=256,
|
43 |
-
num_outs=5),
|
44 |
-
rpn_head=dict(
|
45 |
-
type='RPNHead',
|
46 |
-
in_channels=256,
|
47 |
-
feat_channels=256,
|
48 |
-
anchor_generator=dict(
|
49 |
-
type='AnchorGenerator',
|
50 |
-
scales=[8],
|
51 |
-
ratios=[0.5, 1.0, 2.0],
|
52 |
-
strides=[4, 8, 16, 32, 64]),
|
53 |
-
bbox_coder=dict(
|
54 |
-
type='DeltaXYWHBBoxCoder',
|
55 |
-
target_means=[.0, .0, .0, .0],
|
56 |
-
target_stds=[1.0, 1.0, 1.0, 1.0]),
|
57 |
-
loss_cls=dict(
|
58 |
-
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
|
59 |
-
loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
|
60 |
-
roi_head=dict(
|
61 |
-
type='StandardRoIHead',
|
62 |
-
bbox_roi_extractor=dict(
|
63 |
-
type='SingleRoIExtractor',
|
64 |
-
roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
|
65 |
-
out_channels=256,
|
66 |
-
featmap_strides=[4, 8, 16, 32]),
|
67 |
-
bbox_head=dict(
|
68 |
-
type='Shared2FCBBoxHead',
|
69 |
-
in_channels=256,
|
70 |
-
fc_out_channels=1024,
|
71 |
-
roi_feat_size=7,
|
72 |
-
num_classes=80,
|
73 |
-
bbox_coder=dict(
|
74 |
-
type='DeltaXYWHBBoxCoder',
|
75 |
-
target_means=[0., 0., 0., 0.],
|
76 |
-
target_stds=[0.1, 0.1, 0.2, 0.2]),
|
77 |
-
reg_class_agnostic=False,
|
78 |
-
loss_cls=dict(
|
79 |
-
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
|
80 |
-
loss_bbox=dict(type='L1Loss', loss_weight=1.0))),
|
81 |
-
# model training and testing settings
|
82 |
-
train_cfg=dict(
|
83 |
-
rpn=dict(
|
84 |
-
assigner=dict(
|
85 |
-
type='MaxIoUAssigner',
|
86 |
-
pos_iou_thr=0.7,
|
87 |
-
neg_iou_thr=0.3,
|
88 |
-
min_pos_iou=0.3,
|
89 |
-
match_low_quality=True,
|
90 |
-
ignore_iof_thr=-1),
|
91 |
-
sampler=dict(
|
92 |
-
type='RandomSampler',
|
93 |
-
num=256,
|
94 |
-
pos_fraction=0.5,
|
95 |
-
neg_pos_ub=-1,
|
96 |
-
add_gt_as_proposals=False),
|
97 |
-
allowed_border=-1,
|
98 |
-
pos_weight=-1,
|
99 |
-
debug=False),
|
100 |
-
rpn_proposal=dict(
|
101 |
-
nms_pre=2000,
|
102 |
-
max_per_img=1000,
|
103 |
-
nms=dict(type='nms', iou_threshold=0.7),
|
104 |
-
min_bbox_size=0),
|
105 |
-
rcnn=dict(
|
106 |
-
assigner=dict(
|
107 |
-
type='MaxIoUAssigner',
|
108 |
-
pos_iou_thr=0.5,
|
109 |
-
neg_iou_thr=0.5,
|
110 |
-
min_pos_iou=0.5,
|
111 |
-
match_low_quality=False,
|
112 |
-
ignore_iof_thr=-1),
|
113 |
-
sampler=dict(
|
114 |
-
type='RandomSampler',
|
115 |
-
num=512,
|
116 |
-
pos_fraction=0.25,
|
117 |
-
neg_pos_ub=-1,
|
118 |
-
add_gt_as_proposals=True),
|
119 |
-
pos_weight=-1,
|
120 |
-
debug=False)),
|
121 |
-
test_cfg=dict(
|
122 |
-
rpn=dict(
|
123 |
-
nms_pre=1000,
|
124 |
-
max_per_img=1000,
|
125 |
-
nms=dict(type='nms', iou_threshold=0.7),
|
126 |
-
min_bbox_size=0),
|
127 |
-
rcnn=dict(
|
128 |
-
score_thr=0.05,
|
129 |
-
nms=dict(type='nms', iou_threshold=0.5),
|
130 |
-
max_per_img=100)
|
131 |
-
# soft-nms is also supported for rcnn testing
|
132 |
-
# e.g., nms=dict(type='soft_nms', iou_threshold=0.5, min_score=0.05)
|
133 |
-
))
|
134 |
-
|
135 |
-
dataset_type = 'CocoDataset'
|
136 |
-
data_root = 'data/coco'
|
137 |
-
img_norm_cfg = dict(
|
138 |
-
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
|
139 |
-
train_pipeline = [
|
140 |
-
dict(type='LoadImageFromFile'),
|
141 |
-
dict(type='LoadAnnotations', with_bbox=True),
|
142 |
-
dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
|
143 |
-
dict(type='RandomFlip', flip_ratio=0.5),
|
144 |
-
dict(type='Normalize', **img_norm_cfg),
|
145 |
-
dict(type='Pad', size_divisor=32),
|
146 |
-
dict(type='DefaultFormatBundle'),
|
147 |
-
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
|
148 |
-
]
|
149 |
-
test_pipeline = [
|
150 |
-
dict(type='LoadImageFromFile'),
|
151 |
-
dict(
|
152 |
-
type='MultiScaleFlipAug',
|
153 |
-
img_scale=(1333, 800),
|
154 |
-
flip=False,
|
155 |
-
transforms=[
|
156 |
-
dict(type='Resize', keep_ratio=True),
|
157 |
-
dict(type='RandomFlip'),
|
158 |
-
dict(type='Normalize', **img_norm_cfg),
|
159 |
-
dict(type='Pad', size_divisor=32),
|
160 |
-
dict(type='DefaultFormatBundle'),
|
161 |
-
dict(type='Collect', keys=['img']),
|
162 |
-
])
|
163 |
-
]
|
164 |
-
data = dict(
|
165 |
-
samples_per_gpu=2,
|
166 |
-
workers_per_gpu=2,
|
167 |
-
train=dict(
|
168 |
-
type=dataset_type,
|
169 |
-
ann_file=f'{data_root}/annotations/instances_train2017.json',
|
170 |
-
img_prefix=f'{data_root}/train2017/',
|
171 |
-
pipeline=train_pipeline),
|
172 |
-
val=dict(
|
173 |
-
type=dataset_type,
|
174 |
-
ann_file=f'{data_root}/annotations/instances_val2017.json',
|
175 |
-
img_prefix=f'{data_root}/val2017/',
|
176 |
-
pipeline=test_pipeline),
|
177 |
-
test=dict(
|
178 |
-
type=dataset_type,
|
179 |
-
ann_file=f'{data_root}/annotations/instances_val2017.json',
|
180 |
-
img_prefix=f'{data_root}/val2017/',
|
181 |
-
pipeline=test_pipeline))
|
182 |
-
evaluation = dict(interval=1, metric='bbox')
|
|
|
1 |
+
checkpoint_config = dict(interval=1)
|
2 |
+
# yapf:disable
|
3 |
+
log_config = dict(
|
4 |
+
interval=50,
|
5 |
+
hooks=[
|
6 |
+
dict(type='TextLoggerHook'),
|
7 |
+
# dict(type='TensorboardLoggerHook')
|
8 |
+
])
|
9 |
+
# yapf:enable
|
10 |
+
dist_params = dict(backend='nccl')
|
11 |
+
log_level = 'INFO'
|
12 |
+
load_from = None
|
13 |
+
resume_from = None
|
14 |
+
workflow = [('train', 1)]
|
15 |
+
# optimizer
|
16 |
+
optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
|
17 |
+
optimizer_config = dict(grad_clip=None)
|
18 |
+
# learning policy
|
19 |
+
lr_config = dict(
|
20 |
+
policy='step',
|
21 |
+
warmup='linear',
|
22 |
+
warmup_iters=500,
|
23 |
+
warmup_ratio=0.001,
|
24 |
+
step=[8, 11])
|
25 |
+
total_epochs = 12
|
26 |
+
|
27 |
+
model = dict(
|
28 |
+
type='FasterRCNN',
|
29 |
+
pretrained='torchvision://resnet50',
|
30 |
+
backbone=dict(
|
31 |
+
type='ResNet',
|
32 |
+
depth=50,
|
33 |
+
num_stages=4,
|
34 |
+
out_indices=(0, 1, 2, 3),
|
35 |
+
frozen_stages=1,
|
36 |
+
norm_cfg=dict(type='BN', requires_grad=True),
|
37 |
+
norm_eval=True,
|
38 |
+
style='pytorch'),
|
39 |
+
neck=dict(
|
40 |
+
type='FPN',
|
41 |
+
in_channels=[256, 512, 1024, 2048],
|
42 |
+
out_channels=256,
|
43 |
+
num_outs=5),
|
44 |
+
rpn_head=dict(
|
45 |
+
type='RPNHead',
|
46 |
+
in_channels=256,
|
47 |
+
feat_channels=256,
|
48 |
+
anchor_generator=dict(
|
49 |
+
type='AnchorGenerator',
|
50 |
+
scales=[8],
|
51 |
+
ratios=[0.5, 1.0, 2.0],
|
52 |
+
strides=[4, 8, 16, 32, 64]),
|
53 |
+
bbox_coder=dict(
|
54 |
+
type='DeltaXYWHBBoxCoder',
|
55 |
+
target_means=[.0, .0, .0, .0],
|
56 |
+
target_stds=[1.0, 1.0, 1.0, 1.0]),
|
57 |
+
loss_cls=dict(
|
58 |
+
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
|
59 |
+
loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
|
60 |
+
roi_head=dict(
|
61 |
+
type='StandardRoIHead',
|
62 |
+
bbox_roi_extractor=dict(
|
63 |
+
type='SingleRoIExtractor',
|
64 |
+
roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
|
65 |
+
out_channels=256,
|
66 |
+
featmap_strides=[4, 8, 16, 32]),
|
67 |
+
bbox_head=dict(
|
68 |
+
type='Shared2FCBBoxHead',
|
69 |
+
in_channels=256,
|
70 |
+
fc_out_channels=1024,
|
71 |
+
roi_feat_size=7,
|
72 |
+
num_classes=80,
|
73 |
+
bbox_coder=dict(
|
74 |
+
type='DeltaXYWHBBoxCoder',
|
75 |
+
target_means=[0., 0., 0., 0.],
|
76 |
+
target_stds=[0.1, 0.1, 0.2, 0.2]),
|
77 |
+
reg_class_agnostic=False,
|
78 |
+
loss_cls=dict(
|
79 |
+
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
|
80 |
+
loss_bbox=dict(type='L1Loss', loss_weight=1.0))),
|
81 |
+
# model training and testing settings
|
82 |
+
train_cfg=dict(
|
83 |
+
rpn=dict(
|
84 |
+
assigner=dict(
|
85 |
+
type='MaxIoUAssigner',
|
86 |
+
pos_iou_thr=0.7,
|
87 |
+
neg_iou_thr=0.3,
|
88 |
+
min_pos_iou=0.3,
|
89 |
+
match_low_quality=True,
|
90 |
+
ignore_iof_thr=-1),
|
91 |
+
sampler=dict(
|
92 |
+
type='RandomSampler',
|
93 |
+
num=256,
|
94 |
+
pos_fraction=0.5,
|
95 |
+
neg_pos_ub=-1,
|
96 |
+
add_gt_as_proposals=False),
|
97 |
+
allowed_border=-1,
|
98 |
+
pos_weight=-1,
|
99 |
+
debug=False),
|
100 |
+
rpn_proposal=dict(
|
101 |
+
nms_pre=2000,
|
102 |
+
max_per_img=1000,
|
103 |
+
nms=dict(type='nms', iou_threshold=0.7),
|
104 |
+
min_bbox_size=0),
|
105 |
+
rcnn=dict(
|
106 |
+
assigner=dict(
|
107 |
+
type='MaxIoUAssigner',
|
108 |
+
pos_iou_thr=0.5,
|
109 |
+
neg_iou_thr=0.5,
|
110 |
+
min_pos_iou=0.5,
|
111 |
+
match_low_quality=False,
|
112 |
+
ignore_iof_thr=-1),
|
113 |
+
sampler=dict(
|
114 |
+
type='RandomSampler',
|
115 |
+
num=512,
|
116 |
+
pos_fraction=0.25,
|
117 |
+
neg_pos_ub=-1,
|
118 |
+
add_gt_as_proposals=True),
|
119 |
+
pos_weight=-1,
|
120 |
+
debug=False)),
|
121 |
+
test_cfg=dict(
|
122 |
+
rpn=dict(
|
123 |
+
nms_pre=1000,
|
124 |
+
max_per_img=1000,
|
125 |
+
nms=dict(type='nms', iou_threshold=0.7),
|
126 |
+
min_bbox_size=0),
|
127 |
+
rcnn=dict(
|
128 |
+
score_thr=0.05,
|
129 |
+
nms=dict(type='nms', iou_threshold=0.5),
|
130 |
+
max_per_img=100)
|
131 |
+
# soft-nms is also supported for rcnn testing
|
132 |
+
# e.g., nms=dict(type='soft_nms', iou_threshold=0.5, min_score=0.05)
|
133 |
+
))
|
134 |
+
|
135 |
+
dataset_type = 'CocoDataset'
|
136 |
+
data_root = 'data/coco'
|
137 |
+
img_norm_cfg = dict(
|
138 |
+
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
|
139 |
+
train_pipeline = [
|
140 |
+
dict(type='LoadImageFromFile'),
|
141 |
+
dict(type='LoadAnnotations', with_bbox=True),
|
142 |
+
dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
|
143 |
+
dict(type='RandomFlip', flip_ratio=0.5),
|
144 |
+
dict(type='Normalize', **img_norm_cfg),
|
145 |
+
dict(type='Pad', size_divisor=32),
|
146 |
+
dict(type='DefaultFormatBundle'),
|
147 |
+
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
|
148 |
+
]
|
149 |
+
test_pipeline = [
|
150 |
+
dict(type='LoadImageFromFile'),
|
151 |
+
dict(
|
152 |
+
type='MultiScaleFlipAug',
|
153 |
+
img_scale=(1333, 800),
|
154 |
+
flip=False,
|
155 |
+
transforms=[
|
156 |
+
dict(type='Resize', keep_ratio=True),
|
157 |
+
dict(type='RandomFlip'),
|
158 |
+
dict(type='Normalize', **img_norm_cfg),
|
159 |
+
dict(type='Pad', size_divisor=32),
|
160 |
+
dict(type='DefaultFormatBundle'),
|
161 |
+
dict(type='Collect', keys=['img']),
|
162 |
+
])
|
163 |
+
]
|
164 |
+
data = dict(
|
165 |
+
samples_per_gpu=2,
|
166 |
+
workers_per_gpu=2,
|
167 |
+
train=dict(
|
168 |
+
type=dataset_type,
|
169 |
+
ann_file=f'{data_root}/annotations/instances_train2017.json',
|
170 |
+
img_prefix=f'{data_root}/train2017/',
|
171 |
+
pipeline=train_pipeline),
|
172 |
+
val=dict(
|
173 |
+
type=dataset_type,
|
174 |
+
ann_file=f'{data_root}/annotations/instances_val2017.json',
|
175 |
+
img_prefix=f'{data_root}/val2017/',
|
176 |
+
pipeline=test_pipeline),
|
177 |
+
test=dict(
|
178 |
+
type=dataset_type,
|
179 |
+
ann_file=f'{data_root}/annotations/instances_val2017.json',
|
180 |
+
img_prefix=f'{data_root}/val2017/',
|
181 |
+
pipeline=test_pipeline))
|
182 |
+
evaluation = dict(interval=1, metric='bbox')
|
{models → configs/mm}/hrnet_w48_coco_256x192.py
RENAMED
@@ -1,169 +1,169 @@
|
|
1 |
-
# _base_ = [
|
2 |
-
# '../../../../_base_/default_runtime.py',
|
3 |
-
# '../../../../_base_/datasets/coco.py'
|
4 |
-
# ]
|
5 |
-
evaluation = dict(interval=10, metric='mAP', save_best='AP')
|
6 |
-
|
7 |
-
optimizer = dict(
|
8 |
-
type='Adam',
|
9 |
-
lr=5e-4,
|
10 |
-
)
|
11 |
-
optimizer_config = dict(grad_clip=None)
|
12 |
-
# learning policy
|
13 |
-
lr_config = dict(
|
14 |
-
policy='step',
|
15 |
-
warmup='linear',
|
16 |
-
warmup_iters=500,
|
17 |
-
warmup_ratio=0.001,
|
18 |
-
step=[170, 200])
|
19 |
-
total_epochs = 210
|
20 |
-
channel_cfg = dict(
|
21 |
-
num_output_channels=17,
|
22 |
-
dataset_joints=17,
|
23 |
-
dataset_channel=[
|
24 |
-
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
|
25 |
-
],
|
26 |
-
inference_channel=[
|
27 |
-
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16
|
28 |
-
])
|
29 |
-
|
30 |
-
# model settings
|
31 |
-
model = dict(
|
32 |
-
type='TopDown',
|
33 |
-
pretrained='https://download.openmmlab.com/mmpose/'
|
34 |
-
'pretrain_models/hrnet_w48-8ef0771d.pth',
|
35 |
-
backbone=dict(
|
36 |
-
type='HRNet',
|
37 |
-
in_channels=3,
|
38 |
-
extra=dict(
|
39 |
-
stage1=dict(
|
40 |
-
num_modules=1,
|
41 |
-
num_branches=1,
|
42 |
-
block='BOTTLENECK',
|
43 |
-
num_blocks=(4, ),
|
44 |
-
num_channels=(64, )),
|
45 |
-
stage2=dict(
|
46 |
-
num_modules=1,
|
47 |
-
num_branches=2,
|
48 |
-
block='BASIC',
|
49 |
-
num_blocks=(4, 4),
|
50 |
-
num_channels=(48, 96)),
|
51 |
-
stage3=dict(
|
52 |
-
num_modules=4,
|
53 |
-
num_branches=3,
|
54 |
-
block='BASIC',
|
55 |
-
num_blocks=(4, 4, 4),
|
56 |
-
num_channels=(48, 96, 192)),
|
57 |
-
stage4=dict(
|
58 |
-
num_modules=3,
|
59 |
-
num_branches=4,
|
60 |
-
block='BASIC',
|
61 |
-
num_blocks=(4, 4, 4, 4),
|
62 |
-
num_channels=(48, 96, 192, 384))),
|
63 |
-
),
|
64 |
-
keypoint_head=dict(
|
65 |
-
type='TopdownHeatmapSimpleHead',
|
66 |
-
in_channels=48,
|
67 |
-
out_channels=channel_cfg['num_output_channels'],
|
68 |
-
num_deconv_layers=0,
|
69 |
-
extra=dict(final_conv_kernel=1, ),
|
70 |
-
loss_keypoint=dict(type='JointsMSELoss', use_target_weight=True)),
|
71 |
-
train_cfg=dict(),
|
72 |
-
test_cfg=dict(
|
73 |
-
flip_test=True,
|
74 |
-
post_process='default',
|
75 |
-
shift_heatmap=True,
|
76 |
-
modulate_kernel=11))
|
77 |
-
|
78 |
-
data_cfg = dict(
|
79 |
-
image_size=[192, 256],
|
80 |
-
heatmap_size=[48, 64],
|
81 |
-
num_output_channels=channel_cfg['num_output_channels'],
|
82 |
-
num_joints=channel_cfg['dataset_joints'],
|
83 |
-
dataset_channel=channel_cfg['dataset_channel'],
|
84 |
-
inference_channel=channel_cfg['inference_channel'],
|
85 |
-
soft_nms=False,
|
86 |
-
nms_thr=1.0,
|
87 |
-
oks_thr=0.9,
|
88 |
-
vis_thr=0.2,
|
89 |
-
use_gt_bbox=False,
|
90 |
-
det_bbox_thr=0.0,
|
91 |
-
bbox_file='data/coco/person_detection_results/'
|
92 |
-
'COCO_val2017_detections_AP_H_56_person.json',
|
93 |
-
)
|
94 |
-
|
95 |
-
train_pipeline = [
|
96 |
-
dict(type='LoadImageFromFile'),
|
97 |
-
dict(type='TopDownGetBboxCenterScale', padding=1.25),
|
98 |
-
dict(type='TopDownRandomShiftBboxCenter', shift_factor=0.16, prob=0.3),
|
99 |
-
dict(type='TopDownRandomFlip', flip_prob=0.5),
|
100 |
-
dict(
|
101 |
-
type='TopDownHalfBodyTransform',
|
102 |
-
num_joints_half_body=8,
|
103 |
-
prob_half_body=0.3),
|
104 |
-
dict(
|
105 |
-
type='TopDownGetRandomScaleRotation', rot_factor=40, scale_factor=0.5),
|
106 |
-
dict(type='TopDownAffine'),
|
107 |
-
dict(type='ToTensor'),
|
108 |
-
dict(
|
109 |
-
type='NormalizeTensor',
|
110 |
-
mean=[0.485, 0.456, 0.406],
|
111 |
-
std=[0.229, 0.224, 0.225]),
|
112 |
-
dict(type='TopDownGenerateTarget', sigma=2),
|
113 |
-
dict(
|
114 |
-
type='Collect',
|
115 |
-
keys=['img', 'target', 'target_weight'],
|
116 |
-
meta_keys=[
|
117 |
-
'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale',
|
118 |
-
'rotation', 'bbox_score', 'flip_pairs'
|
119 |
-
]),
|
120 |
-
]
|
121 |
-
|
122 |
-
val_pipeline = [
|
123 |
-
dict(type='LoadImageFromFile'),
|
124 |
-
dict(type='TopDownGetBboxCenterScale', padding=1.25),
|
125 |
-
dict(type='TopDownAffine'),
|
126 |
-
dict(type='ToTensor'),
|
127 |
-
dict(
|
128 |
-
type='NormalizeTensor',
|
129 |
-
mean=[0.485, 0.456, 0.406],
|
130 |
-
std=[0.229, 0.224, 0.225]),
|
131 |
-
dict(
|
132 |
-
type='Collect',
|
133 |
-
keys=['img'],
|
134 |
-
meta_keys=[
|
135 |
-
'image_file', 'center', 'scale', 'rotation', 'bbox_score',
|
136 |
-
'flip_pairs'
|
137 |
-
]),
|
138 |
-
]
|
139 |
-
|
140 |
-
test_pipeline = val_pipeline
|
141 |
-
|
142 |
-
data_root = 'data/coco'
|
143 |
-
data = dict(
|
144 |
-
samples_per_gpu=32,
|
145 |
-
workers_per_gpu=2,
|
146 |
-
val_dataloader=dict(samples_per_gpu=32),
|
147 |
-
test_dataloader=dict(samples_per_gpu=32),
|
148 |
-
train=dict(
|
149 |
-
type='TopDownCocoDataset',
|
150 |
-
ann_file=f'{data_root}/annotations/person_keypoints_train2017.json',
|
151 |
-
img_prefix=f'{data_root}/train2017/',
|
152 |
-
data_cfg=data_cfg,
|
153 |
-
pipeline=train_pipeline,
|
154 |
-
dataset_info={{_base_.dataset_info}}),
|
155 |
-
val=dict(
|
156 |
-
type='TopDownCocoDataset',
|
157 |
-
ann_file=f'{data_root}/annotations/person_keypoints_val2017.json',
|
158 |
-
img_prefix=f'{data_root}/val2017/',
|
159 |
-
data_cfg=data_cfg,
|
160 |
-
pipeline=val_pipeline,
|
161 |
-
dataset_info={{_base_.dataset_info}}),
|
162 |
-
test=dict(
|
163 |
-
type='TopDownCocoDataset',
|
164 |
-
ann_file=f'{data_root}/annotations/person_keypoints_val2017.json',
|
165 |
-
img_prefix=f'{data_root}/val2017/',
|
166 |
-
data_cfg=data_cfg,
|
167 |
-
pipeline=test_pipeline,
|
168 |
-
dataset_info={{_base_.dataset_info}}),
|
169 |
-
)
|
|
|
1 |
+
# _base_ = [
|
2 |
+
# '../../../../_base_/default_runtime.py',
|
3 |
+
# '../../../../_base_/datasets/coco.py'
|
4 |
+
# ]
|
5 |
+
evaluation = dict(interval=10, metric='mAP', save_best='AP')
|
6 |
+
|
7 |
+
optimizer = dict(
|
8 |
+
type='Adam',
|
9 |
+
lr=5e-4,
|
10 |
+
)
|
11 |
+
optimizer_config = dict(grad_clip=None)
|
12 |
+
# learning policy
|
13 |
+
lr_config = dict(
|
14 |
+
policy='step',
|
15 |
+
warmup='linear',
|
16 |
+
warmup_iters=500,
|
17 |
+
warmup_ratio=0.001,
|
18 |
+
step=[170, 200])
|
19 |
+
total_epochs = 210
|
20 |
+
channel_cfg = dict(
|
21 |
+
num_output_channels=17,
|
22 |
+
dataset_joints=17,
|
23 |
+
dataset_channel=[
|
24 |
+
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
|
25 |
+
],
|
26 |
+
inference_channel=[
|
27 |
+
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16
|
28 |
+
])
|
29 |
+
|
30 |
+
# model settings
|
31 |
+
model = dict(
|
32 |
+
type='TopDown',
|
33 |
+
pretrained='https://download.openmmlab.com/mmpose/'
|
34 |
+
'pretrain_models/hrnet_w48-8ef0771d.pth',
|
35 |
+
backbone=dict(
|
36 |
+
type='HRNet',
|
37 |
+
in_channels=3,
|
38 |
+
extra=dict(
|
39 |
+
stage1=dict(
|
40 |
+
num_modules=1,
|
41 |
+
num_branches=1,
|
42 |
+
block='BOTTLENECK',
|
43 |
+
num_blocks=(4, ),
|
44 |
+
num_channels=(64, )),
|
45 |
+
stage2=dict(
|
46 |
+
num_modules=1,
|
47 |
+
num_branches=2,
|
48 |
+
block='BASIC',
|
49 |
+
num_blocks=(4, 4),
|
50 |
+
num_channels=(48, 96)),
|
51 |
+
stage3=dict(
|
52 |
+
num_modules=4,
|
53 |
+
num_branches=3,
|
54 |
+
block='BASIC',
|
55 |
+
num_blocks=(4, 4, 4),
|
56 |
+
num_channels=(48, 96, 192)),
|
57 |
+
stage4=dict(
|
58 |
+
num_modules=3,
|
59 |
+
num_branches=4,
|
60 |
+
block='BASIC',
|
61 |
+
num_blocks=(4, 4, 4, 4),
|
62 |
+
num_channels=(48, 96, 192, 384))),
|
63 |
+
),
|
64 |
+
keypoint_head=dict(
|
65 |
+
type='TopdownHeatmapSimpleHead',
|
66 |
+
in_channels=48,
|
67 |
+
out_channels=channel_cfg['num_output_channels'],
|
68 |
+
num_deconv_layers=0,
|
69 |
+
extra=dict(final_conv_kernel=1, ),
|
70 |
+
loss_keypoint=dict(type='JointsMSELoss', use_target_weight=True)),
|
71 |
+
train_cfg=dict(),
|
72 |
+
test_cfg=dict(
|
73 |
+
flip_test=True,
|
74 |
+
post_process='default',
|
75 |
+
shift_heatmap=True,
|
76 |
+
modulate_kernel=11))
|
77 |
+
|
78 |
+
data_cfg = dict(
|
79 |
+
image_size=[192, 256],
|
80 |
+
heatmap_size=[48, 64],
|
81 |
+
num_output_channels=channel_cfg['num_output_channels'],
|
82 |
+
num_joints=channel_cfg['dataset_joints'],
|
83 |
+
dataset_channel=channel_cfg['dataset_channel'],
|
84 |
+
inference_channel=channel_cfg['inference_channel'],
|
85 |
+
soft_nms=False,
|
86 |
+
nms_thr=1.0,
|
87 |
+
oks_thr=0.9,
|
88 |
+
vis_thr=0.2,
|
89 |
+
use_gt_bbox=False,
|
90 |
+
det_bbox_thr=0.0,
|
91 |
+
bbox_file='data/coco/person_detection_results/'
|
92 |
+
'COCO_val2017_detections_AP_H_56_person.json',
|
93 |
+
)
|
94 |
+
|
95 |
+
train_pipeline = [
|
96 |
+
dict(type='LoadImageFromFile'),
|
97 |
+
dict(type='TopDownGetBboxCenterScale', padding=1.25),
|
98 |
+
dict(type='TopDownRandomShiftBboxCenter', shift_factor=0.16, prob=0.3),
|
99 |
+
dict(type='TopDownRandomFlip', flip_prob=0.5),
|
100 |
+
dict(
|
101 |
+
type='TopDownHalfBodyTransform',
|
102 |
+
num_joints_half_body=8,
|
103 |
+
prob_half_body=0.3),
|
104 |
+
dict(
|
105 |
+
type='TopDownGetRandomScaleRotation', rot_factor=40, scale_factor=0.5),
|
106 |
+
dict(type='TopDownAffine'),
|
107 |
+
dict(type='ToTensor'),
|
108 |
+
dict(
|
109 |
+
type='NormalizeTensor',
|
110 |
+
mean=[0.485, 0.456, 0.406],
|
111 |
+
std=[0.229, 0.224, 0.225]),
|
112 |
+
dict(type='TopDownGenerateTarget', sigma=2),
|
113 |
+
dict(
|
114 |
+
type='Collect',
|
115 |
+
keys=['img', 'target', 'target_weight'],
|
116 |
+
meta_keys=[
|
117 |
+
'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale',
|
118 |
+
'rotation', 'bbox_score', 'flip_pairs'
|
119 |
+
]),
|
120 |
+
]
|
121 |
+
|
122 |
+
val_pipeline = [
|
123 |
+
dict(type='LoadImageFromFile'),
|
124 |
+
dict(type='TopDownGetBboxCenterScale', padding=1.25),
|
125 |
+
dict(type='TopDownAffine'),
|
126 |
+
dict(type='ToTensor'),
|
127 |
+
dict(
|
128 |
+
type='NormalizeTensor',
|
129 |
+
mean=[0.485, 0.456, 0.406],
|
130 |
+
std=[0.229, 0.224, 0.225]),
|
131 |
+
dict(
|
132 |
+
type='Collect',
|
133 |
+
keys=['img'],
|
134 |
+
meta_keys=[
|
135 |
+
'image_file', 'center', 'scale', 'rotation', 'bbox_score',
|
136 |
+
'flip_pairs'
|
137 |
+
]),
|
138 |
+
]
|
139 |
+
|
140 |
+
test_pipeline = val_pipeline
|
141 |
+
|
142 |
+
data_root = 'data/coco'
|
143 |
+
data = dict(
|
144 |
+
samples_per_gpu=32,
|
145 |
+
workers_per_gpu=2,
|
146 |
+
val_dataloader=dict(samples_per_gpu=32),
|
147 |
+
test_dataloader=dict(samples_per_gpu=32),
|
148 |
+
train=dict(
|
149 |
+
type='TopDownCocoDataset',
|
150 |
+
ann_file=f'{data_root}/annotations/person_keypoints_train2017.json',
|
151 |
+
img_prefix=f'{data_root}/train2017/',
|
152 |
+
data_cfg=data_cfg,
|
153 |
+
pipeline=train_pipeline,
|
154 |
+
dataset_info={{_base_.dataset_info}}),
|
155 |
+
val=dict(
|
156 |
+
type='TopDownCocoDataset',
|
157 |
+
ann_file=f'{data_root}/annotations/person_keypoints_val2017.json',
|
158 |
+
img_prefix=f'{data_root}/val2017/',
|
159 |
+
data_cfg=data_cfg,
|
160 |
+
pipeline=val_pipeline,
|
161 |
+
dataset_info={{_base_.dataset_info}}),
|
162 |
+
test=dict(
|
163 |
+
type='TopDownCocoDataset',
|
164 |
+
ann_file=f'{data_root}/annotations/person_keypoints_val2017.json',
|
165 |
+
img_prefix=f'{data_root}/val2017/',
|
166 |
+
data_cfg=data_cfg,
|
167 |
+
pipeline=test_pipeline,
|
168 |
+
dataset_info={{_base_.dataset_info}}),
|
169 |
+
)
|
configs/stable-diffusion/sd-v1-inference.yaml
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model:
|
2 |
+
base_learning_rate: 1.0e-04
|
3 |
+
target: ldm.models.diffusion.ddpm.LatentDiffusion
|
4 |
+
params:
|
5 |
+
linear_start: 0.00085
|
6 |
+
linear_end: 0.0120
|
7 |
+
num_timesteps_cond: 1
|
8 |
+
log_every_t: 200
|
9 |
+
timesteps: 1000
|
10 |
+
first_stage_key: "jpg"
|
11 |
+
cond_stage_key: "txt"
|
12 |
+
image_size: 64
|
13 |
+
channels: 4
|
14 |
+
cond_stage_trainable: false # Note: different from the one we trained before
|
15 |
+
conditioning_key: crossattn
|
16 |
+
monitor: val/loss_simple_ema
|
17 |
+
scale_factor: 0.18215
|
18 |
+
use_ema: False
|
19 |
+
|
20 |
+
unet_config:
|
21 |
+
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
22 |
+
params:
|
23 |
+
use_fp16: True
|
24 |
+
image_size: 32 # unused
|
25 |
+
in_channels: 4
|
26 |
+
out_channels: 4
|
27 |
+
model_channels: 320
|
28 |
+
attention_resolutions: [ 4, 2, 1 ]
|
29 |
+
num_res_blocks: 2
|
30 |
+
channel_mult: [ 1, 2, 4, 4 ]
|
31 |
+
num_heads: 8
|
32 |
+
use_spatial_transformer: True
|
33 |
+
transformer_depth: 1
|
34 |
+
context_dim: 768
|
35 |
+
use_checkpoint: True
|
36 |
+
legacy: False
|
37 |
+
|
38 |
+
first_stage_config:
|
39 |
+
target: ldm.models.autoencoder.AutoencoderKL
|
40 |
+
params:
|
41 |
+
embed_dim: 4
|
42 |
+
monitor: val/rec_loss
|
43 |
+
ddconfig:
|
44 |
+
double_z: true
|
45 |
+
z_channels: 4
|
46 |
+
resolution: 512
|
47 |
+
in_channels: 3
|
48 |
+
out_ch: 3
|
49 |
+
ch: 128
|
50 |
+
ch_mult:
|
51 |
+
- 1
|
52 |
+
- 2
|
53 |
+
- 4
|
54 |
+
- 4
|
55 |
+
num_res_blocks: 2
|
56 |
+
attn_resolutions: []
|
57 |
+
dropout: 0.0
|
58 |
+
lossconfig:
|
59 |
+
target: torch.nn.Identity
|
60 |
+
|
61 |
+
cond_stage_config:
|
62 |
+
target: ldm.modules.encoders.modules.WebUIFrozenCLIPEmebedder
|
63 |
+
params:
|
64 |
+
version: openai/clip-vit-large-patch14
|
65 |
+
layer: last
|
configs/stable-diffusion/sd-v1-train.yaml
ADDED
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model:
|
2 |
+
base_learning_rate: 1.0e-04
|
3 |
+
target: ldm.models.diffusion.ddpm.LatentDiffusion
|
4 |
+
params:
|
5 |
+
linear_start: 0.00085
|
6 |
+
linear_end: 0.0120
|
7 |
+
num_timesteps_cond: 1
|
8 |
+
log_every_t: 200
|
9 |
+
timesteps: 1000
|
10 |
+
first_stage_key: "jpg"
|
11 |
+
cond_stage_key: "txt"
|
12 |
+
image_size: 64
|
13 |
+
channels: 4
|
14 |
+
cond_stage_trainable: false # Note: different from the one we trained before
|
15 |
+
conditioning_key: crossattn
|
16 |
+
monitor: val/loss_simple_ema
|
17 |
+
scale_factor: 0.18215
|
18 |
+
use_ema: False
|
19 |
+
|
20 |
+
scheduler_config: # 10000 warmup steps
|
21 |
+
target: ldm.lr_scheduler.LambdaLinearScheduler
|
22 |
+
params:
|
23 |
+
warm_up_steps: [ 10000 ]
|
24 |
+
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
|
25 |
+
f_start: [ 1.e-6 ]
|
26 |
+
f_max: [ 1. ]
|
27 |
+
f_min: [ 1. ]
|
28 |
+
|
29 |
+
unet_config:
|
30 |
+
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
31 |
+
params:
|
32 |
+
image_size: 32 # unused
|
33 |
+
in_channels: 4
|
34 |
+
out_channels: 4
|
35 |
+
model_channels: 320
|
36 |
+
attention_resolutions: [ 4, 2, 1 ]
|
37 |
+
num_res_blocks: 2
|
38 |
+
channel_mult: [ 1, 2, 4, 4 ]
|
39 |
+
num_heads: 8
|
40 |
+
use_spatial_transformer: True
|
41 |
+
transformer_depth: 1
|
42 |
+
context_dim: 768
|
43 |
+
use_checkpoint: True
|
44 |
+
legacy: False
|
45 |
+
|
46 |
+
first_stage_config:
|
47 |
+
target: ldm.models.autoencoder.AutoencoderKL
|
48 |
+
params:
|
49 |
+
embed_dim: 4
|
50 |
+
monitor: val/rec_loss
|
51 |
+
ddconfig:
|
52 |
+
double_z: true
|
53 |
+
z_channels: 4
|
54 |
+
resolution: 256
|
55 |
+
in_channels: 3
|
56 |
+
out_ch: 3
|
57 |
+
ch: 128
|
58 |
+
ch_mult:
|
59 |
+
- 1
|
60 |
+
- 2
|
61 |
+
- 4
|
62 |
+
- 4
|
63 |
+
num_res_blocks: 2
|
64 |
+
attn_resolutions: []
|
65 |
+
dropout: 0.0
|
66 |
+
lossconfig:
|
67 |
+
target: torch.nn.Identity
|
68 |
+
|
69 |
+
cond_stage_config: #__is_unconditional__
|
70 |
+
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
|
71 |
+
params:
|
72 |
+
version: openai/clip-vit-large-patch14
|
73 |
+
|
74 |
+
logger:
|
75 |
+
print_freq: 100
|
76 |
+
save_checkpoint_freq: !!float 1e4
|
77 |
+
use_tb_logger: true
|
78 |
+
wandb:
|
79 |
+
project: ~
|
80 |
+
resume_id: ~
|
81 |
+
dist_params:
|
82 |
+
backend: nccl
|
83 |
+
port: 29500
|
84 |
+
training:
|
85 |
+
lr: !!float 1e-5
|
86 |
+
save_freq: 1e4
|
configs/stable-diffusion/train_keypose.yaml
ADDED
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: train_keypose
|
2 |
+
model:
|
3 |
+
base_learning_rate: 1.0e-04
|
4 |
+
target: ldm.models.diffusion.ddpm.LatentDiffusion
|
5 |
+
params:
|
6 |
+
linear_start: 0.00085
|
7 |
+
linear_end: 0.0120
|
8 |
+
num_timesteps_cond: 1
|
9 |
+
log_every_t: 200
|
10 |
+
timesteps: 1000
|
11 |
+
first_stage_key: "jpg"
|
12 |
+
cond_stage_key: "txt"
|
13 |
+
image_size: 64
|
14 |
+
channels: 4
|
15 |
+
cond_stage_trainable: false # Note: different from the one we trained before
|
16 |
+
conditioning_key: crossattn
|
17 |
+
monitor: val/loss_simple_ema
|
18 |
+
scale_factor: 0.18215
|
19 |
+
use_ema: False
|
20 |
+
|
21 |
+
scheduler_config: # 10000 warmup steps
|
22 |
+
target: ldm.lr_scheduler.LambdaLinearScheduler
|
23 |
+
params:
|
24 |
+
warm_up_steps: [ 10000 ]
|
25 |
+
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
|
26 |
+
f_start: [ 1.e-6 ]
|
27 |
+
f_max: [ 1. ]
|
28 |
+
f_min: [ 1. ]
|
29 |
+
|
30 |
+
unet_config:
|
31 |
+
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
32 |
+
params:
|
33 |
+
image_size: 32 # unused
|
34 |
+
in_channels: 4
|
35 |
+
out_channels: 4
|
36 |
+
model_channels: 320
|
37 |
+
attention_resolutions: [ 4, 2, 1 ]
|
38 |
+
num_res_blocks: 2
|
39 |
+
channel_mult: [ 1, 2, 4, 4 ]
|
40 |
+
num_heads: 8
|
41 |
+
use_spatial_transformer: True
|
42 |
+
transformer_depth: 1
|
43 |
+
context_dim: 768
|
44 |
+
use_checkpoint: True
|
45 |
+
legacy: False
|
46 |
+
|
47 |
+
first_stage_config:
|
48 |
+
target: ldm.models.autoencoder.AutoencoderKL
|
49 |
+
params:
|
50 |
+
embed_dim: 4
|
51 |
+
monitor: val/rec_loss
|
52 |
+
ddconfig:
|
53 |
+
double_z: true
|
54 |
+
z_channels: 4
|
55 |
+
resolution: 256
|
56 |
+
in_channels: 3
|
57 |
+
out_ch: 3
|
58 |
+
ch: 128
|
59 |
+
ch_mult:
|
60 |
+
- 1
|
61 |
+
- 2
|
62 |
+
- 4
|
63 |
+
- 4
|
64 |
+
num_res_blocks: 2
|
65 |
+
attn_resolutions: []
|
66 |
+
dropout: 0.0
|
67 |
+
lossconfig:
|
68 |
+
target: torch.nn.Identity
|
69 |
+
|
70 |
+
cond_stage_config: #__is_unconditional__
|
71 |
+
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
|
72 |
+
params:
|
73 |
+
version: openai/clip-vit-large-patch14
|
74 |
+
|
75 |
+
logger:
|
76 |
+
print_freq: 100
|
77 |
+
save_checkpoint_freq: !!float 1e4
|
78 |
+
use_tb_logger: true
|
79 |
+
wandb:
|
80 |
+
project: ~
|
81 |
+
resume_id: ~
|
82 |
+
dist_params:
|
83 |
+
backend: nccl
|
84 |
+
port: 29500
|
85 |
+
training:
|
86 |
+
lr: !!float 1e-5
|
87 |
+
save_freq: 1e4
|
configs/stable-diffusion/train_mask.yaml
ADDED
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: train_mask
|
2 |
+
model:
|
3 |
+
base_learning_rate: 1.0e-04
|
4 |
+
target: ldm.models.diffusion.ddpm.LatentDiffusion
|
5 |
+
params:
|
6 |
+
linear_start: 0.00085
|
7 |
+
linear_end: 0.0120
|
8 |
+
num_timesteps_cond: 1
|
9 |
+
log_every_t: 200
|
10 |
+
timesteps: 1000
|
11 |
+
first_stage_key: "jpg"
|
12 |
+
cond_stage_key: "txt"
|
13 |
+
image_size: 64
|
14 |
+
channels: 4
|
15 |
+
cond_stage_trainable: false # Note: different from the one we trained before
|
16 |
+
conditioning_key: crossattn
|
17 |
+
monitor: val/loss_simple_ema
|
18 |
+
scale_factor: 0.18215
|
19 |
+
use_ema: False
|
20 |
+
|
21 |
+
scheduler_config: # 10000 warmup steps
|
22 |
+
target: ldm.lr_scheduler.LambdaLinearScheduler
|
23 |
+
params:
|
24 |
+
warm_up_steps: [ 10000 ]
|
25 |
+
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
|
26 |
+
f_start: [ 1.e-6 ]
|
27 |
+
f_max: [ 1. ]
|
28 |
+
f_min: [ 1. ]
|
29 |
+
|
30 |
+
unet_config:
|
31 |
+
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
32 |
+
params:
|
33 |
+
image_size: 32 # unused
|
34 |
+
in_channels: 4
|
35 |
+
out_channels: 4
|
36 |
+
model_channels: 320
|
37 |
+
attention_resolutions: [ 4, 2, 1 ]
|
38 |
+
num_res_blocks: 2
|
39 |
+
channel_mult: [ 1, 2, 4, 4 ]
|
40 |
+
num_heads: 8
|
41 |
+
use_spatial_transformer: True
|
42 |
+
transformer_depth: 1
|
43 |
+
context_dim: 768
|
44 |
+
use_checkpoint: True
|
45 |
+
legacy: False
|
46 |
+
|
47 |
+
first_stage_config:
|
48 |
+
target: ldm.models.autoencoder.AutoencoderKL
|
49 |
+
params:
|
50 |
+
embed_dim: 4
|
51 |
+
monitor: val/rec_loss
|
52 |
+
ddconfig:
|
53 |
+
double_z: true
|
54 |
+
z_channels: 4
|
55 |
+
resolution: 256
|
56 |
+
in_channels: 3
|
57 |
+
out_ch: 3
|
58 |
+
ch: 128
|
59 |
+
ch_mult:
|
60 |
+
- 1
|
61 |
+
- 2
|
62 |
+
- 4
|
63 |
+
- 4
|
64 |
+
num_res_blocks: 2
|
65 |
+
attn_resolutions: []
|
66 |
+
dropout: 0.0
|
67 |
+
lossconfig:
|
68 |
+
target: torch.nn.Identity
|
69 |
+
|
70 |
+
cond_stage_config: #__is_unconditional__
|
71 |
+
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
|
72 |
+
params:
|
73 |
+
version: openai/clip-vit-large-patch14
|
74 |
+
|
75 |
+
logger:
|
76 |
+
print_freq: 100
|
77 |
+
save_checkpoint_freq: !!float 1e4
|
78 |
+
use_tb_logger: true
|
79 |
+
wandb:
|
80 |
+
project: ~
|
81 |
+
resume_id: ~
|
82 |
+
dist_params:
|
83 |
+
backend: nccl
|
84 |
+
port: 29500
|
85 |
+
training:
|
86 |
+
lr: !!float 1e-5
|
87 |
+
save_freq: 1e4
|
configs/stable-diffusion/train_sketch.yaml
ADDED
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: train_sketch
|
2 |
+
model:
|
3 |
+
base_learning_rate: 1.0e-04
|
4 |
+
target: ldm.models.diffusion.ddpm.LatentDiffusion
|
5 |
+
params:
|
6 |
+
linear_start: 0.00085
|
7 |
+
linear_end: 0.0120
|
8 |
+
num_timesteps_cond: 1
|
9 |
+
log_every_t: 200
|
10 |
+
timesteps: 1000
|
11 |
+
first_stage_key: "jpg"
|
12 |
+
cond_stage_key: "txt"
|
13 |
+
image_size: 64
|
14 |
+
channels: 4
|
15 |
+
cond_stage_trainable: false # Note: different from the one we trained before
|
16 |
+
conditioning_key: crossattn
|
17 |
+
monitor: val/loss_simple_ema
|
18 |
+
scale_factor: 0.18215
|
19 |
+
use_ema: False
|
20 |
+
|
21 |
+
scheduler_config: # 10000 warmup steps
|
22 |
+
target: ldm.lr_scheduler.LambdaLinearScheduler
|
23 |
+
params:
|
24 |
+
warm_up_steps: [ 10000 ]
|
25 |
+
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
|
26 |
+
f_start: [ 1.e-6 ]
|
27 |
+
f_max: [ 1. ]
|
28 |
+
f_min: [ 1. ]
|
29 |
+
|
30 |
+
unet_config:
|
31 |
+
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
32 |
+
params:
|
33 |
+
image_size: 32 # unused
|
34 |
+
in_channels: 4
|
35 |
+
out_channels: 4
|
36 |
+
model_channels: 320
|
37 |
+
attention_resolutions: [ 4, 2, 1 ]
|
38 |
+
num_res_blocks: 2
|
39 |
+
channel_mult: [ 1, 2, 4, 4 ]
|
40 |
+
num_heads: 8
|
41 |
+
use_spatial_transformer: True
|
42 |
+
transformer_depth: 1
|
43 |
+
context_dim: 768
|
44 |
+
use_checkpoint: True
|
45 |
+
legacy: False
|
46 |
+
|
47 |
+
first_stage_config:
|
48 |
+
target: ldm.models.autoencoder.AutoencoderKL
|
49 |
+
params:
|
50 |
+
embed_dim: 4
|
51 |
+
monitor: val/rec_loss
|
52 |
+
ddconfig:
|
53 |
+
double_z: true
|
54 |
+
z_channels: 4
|
55 |
+
resolution: 256
|
56 |
+
in_channels: 3
|
57 |
+
out_ch: 3
|
58 |
+
ch: 128
|
59 |
+
ch_mult:
|
60 |
+
- 1
|
61 |
+
- 2
|
62 |
+
- 4
|
63 |
+
- 4
|
64 |
+
num_res_blocks: 2
|
65 |
+
attn_resolutions: []
|
66 |
+
dropout: 0.0
|
67 |
+
lossconfig:
|
68 |
+
target: torch.nn.Identity
|
69 |
+
|
70 |
+
cond_stage_config: #__is_unconditional__
|
71 |
+
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
|
72 |
+
params:
|
73 |
+
version: openai/clip-vit-large-patch14
|
74 |
+
|
75 |
+
logger:
|
76 |
+
print_freq: 100
|
77 |
+
save_checkpoint_freq: !!float 1e4
|
78 |
+
use_tb_logger: true
|
79 |
+
wandb:
|
80 |
+
project: ~
|
81 |
+
resume_id: ~
|
82 |
+
dist_params:
|
83 |
+
backend: nccl
|
84 |
+
port: 29500
|
85 |
+
training:
|
86 |
+
lr: !!float 1e-5
|
87 |
+
save_freq: 1e4
|
demo/demos.py
DELETED
@@ -1,309 +0,0 @@
|
|
1 |
-
import gradio as gr
|
2 |
-
import numpy as np
|
3 |
-
import psutil
|
4 |
-
|
5 |
-
def create_map():
|
6 |
-
return np.zeros(shape=(512, 512), dtype=np.uint8)+255
|
7 |
-
|
8 |
-
def get_system_memory():
|
9 |
-
memory = psutil.virtual_memory()
|
10 |
-
memory_percent = memory.percent
|
11 |
-
memory_used = memory.used / (1024.0 ** 3)
|
12 |
-
memory_total = memory.total / (1024.0 ** 3)
|
13 |
-
return {"percent": f"{memory_percent}%", "used": f"{memory_used:.3f}GB", "total": f"{memory_total:.3f}GB"}
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
def create_demo_keypose(process):
|
18 |
-
with gr.Blocks() as demo:
|
19 |
-
with gr.Row():
|
20 |
-
gr.Markdown('## T2I-Adapter (Keypose)')
|
21 |
-
with gr.Row():
|
22 |
-
with gr.Column():
|
23 |
-
input_img = gr.Image(source='upload', type="numpy")
|
24 |
-
prompt = gr.Textbox(label="Prompt")
|
25 |
-
neg_prompt = gr.Textbox(label="Negative Prompt",
|
26 |
-
value='ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, bad anatomy, watermark, signature, cut off, low contrast, underexposed, overexposed, bad art, beginner, amateur, distorted face')
|
27 |
-
pos_prompt = gr.Textbox(label="Positive Prompt",
|
28 |
-
value = 'crafted, elegant, meticulous, magnificent, maximum details, extremely hyper aesthetic, intricately detailed')
|
29 |
-
with gr.Row():
|
30 |
-
type_in = gr.inputs.Radio(['Keypose', 'Image'], type="value", default='Image', label='Input Types\n (You can input an image or a keypose map)')
|
31 |
-
fix_sample = gr.inputs.Radio(['True', 'False'], type="value", default='False', label='Fix Sampling\n (Fix the random seed to produce a fixed output)')
|
32 |
-
run_button = gr.Button(label="Run")
|
33 |
-
con_strength = gr.Slider(label="Controling Strength (The guidance strength of the keypose to the result)", minimum=0, maximum=1, value=1, step=0.1)
|
34 |
-
scale = gr.Slider(label="Guidance Scale (Classifier free guidance)", minimum=0.1, maximum=30.0, value=7.5, step=0.1)
|
35 |
-
base_model = gr.inputs.Radio(['sd-v1-4.ckpt', 'anything-v4.0-pruned.ckpt'], type="value", default='sd-v1-4.ckpt', label='The base model you want to use')
|
36 |
-
with gr.Column():
|
37 |
-
result = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
|
38 |
-
ips = [input_img, type_in, prompt, neg_prompt, pos_prompt, fix_sample, scale, con_strength, base_model]
|
39 |
-
run_button.click(fn=process, inputs=ips, outputs=[result])
|
40 |
-
return demo
|
41 |
-
|
42 |
-
def create_demo_openpose(process):
|
43 |
-
with gr.Blocks() as demo:
|
44 |
-
with gr.Row():
|
45 |
-
gr.Markdown('## T2I-Adapter (Openpose)')
|
46 |
-
with gr.Row():
|
47 |
-
with gr.Column():
|
48 |
-
input_img = gr.Image(source='upload', type="numpy")
|
49 |
-
prompt = gr.Textbox(label="Prompt")
|
50 |
-
neg_prompt = gr.Textbox(label="Negative Prompt",
|
51 |
-
value='ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, bad anatomy, watermark, signature, cut off, low contrast, underexposed, overexposed, bad art, beginner, amateur, distorted face')
|
52 |
-
pos_prompt = gr.Textbox(label="Positive Prompt",
|
53 |
-
value = 'crafted, elegant, meticulous, magnificent, maximum details, extremely hyper aesthetic, intricately detailed')
|
54 |
-
with gr.Row():
|
55 |
-
type_in = gr.inputs.Radio(['Openpose', 'Image'], type="value", default='Image', label='Input Types\n (You can input an image or a openpose map)')
|
56 |
-
fix_sample = gr.inputs.Radio(['True', 'False'], type="value", default='False', label='Fix Sampling\n (Fix the random seed to produce a fixed output)')
|
57 |
-
run_button = gr.Button(label="Run")
|
58 |
-
con_strength = gr.Slider(label="Controling Strength (The guidance strength of the openpose to the result)", minimum=0, maximum=1, value=1, step=0.1)
|
59 |
-
scale = gr.Slider(label="Guidance Scale (Classifier free guidance)", minimum=0.1, maximum=30.0, value=7.5, step=0.1)
|
60 |
-
base_model = gr.inputs.Radio(['sd-v1-4.ckpt', 'anything-v4.0-pruned.ckpt'], type="value", default='sd-v1-4.ckpt', label='The base model you want to use')
|
61 |
-
with gr.Column():
|
62 |
-
result = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
|
63 |
-
ips = [input_img, type_in, prompt, neg_prompt, pos_prompt, fix_sample, scale, con_strength, base_model]
|
64 |
-
run_button.click(fn=process, inputs=ips, outputs=[result])
|
65 |
-
return demo
|
66 |
-
|
67 |
-
def create_demo_sketch(process):
|
68 |
-
with gr.Blocks() as demo:
|
69 |
-
with gr.Row():
|
70 |
-
gr.Markdown('## T2I-Adapter (Sketch)')
|
71 |
-
with gr.Row():
|
72 |
-
with gr.Column():
|
73 |
-
input_img = gr.Image(source='upload', type="numpy")
|
74 |
-
prompt = gr.Textbox(label="Prompt")
|
75 |
-
neg_prompt = gr.Textbox(label="Negative Prompt",
|
76 |
-
value='ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, bad anatomy, watermark, signature, cut off, low contrast, underexposed, overexposed, bad art, beginner, amateur, distorted face')
|
77 |
-
pos_prompt = gr.Textbox(label="Positive Prompt",
|
78 |
-
value = 'crafted, elegant, meticulous, magnificent, maximum details, extremely hyper aesthetic, intricately detailed')
|
79 |
-
with gr.Row():
|
80 |
-
type_in = gr.inputs.Radio(['Sketch', 'Image'], type="value", default='Image', label='Input Types\n (You can input an image or a sketch)')
|
81 |
-
color_back = gr.inputs.Radio(['White', 'Black'], type="value", default='Black', label='Color of the sketch background\n (Only work for sketch input)')
|
82 |
-
run_button = gr.Button(label="Run")
|
83 |
-
con_strength = gr.Slider(label="Controling Strength (The guidance strength of the sketch to the result)", minimum=0, maximum=1, value=0.4, step=0.1)
|
84 |
-
scale = gr.Slider(label="Guidance Scale (Classifier free guidance)", minimum=0.1, maximum=30.0, value=7.5, step=0.1)
|
85 |
-
fix_sample = gr.inputs.Radio(['True', 'False'], type="value", default='False', label='Fix Sampling\n (Fix the random seed)')
|
86 |
-
base_model = gr.inputs.Radio(['sd-v1-4.ckpt', 'anything-v4.0-pruned.ckpt'], type="value", default='sd-v1-4.ckpt', label='The base model you want to use')
|
87 |
-
with gr.Column():
|
88 |
-
result = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
|
89 |
-
ips = [input_img, type_in, color_back, prompt, neg_prompt, pos_prompt, fix_sample, scale, con_strength, base_model]
|
90 |
-
run_button.click(fn=process, inputs=ips, outputs=[result])
|
91 |
-
return demo
|
92 |
-
|
93 |
-
def create_demo_canny(process):
|
94 |
-
with gr.Blocks() as demo:
|
95 |
-
with gr.Row():
|
96 |
-
gr.Markdown('## T2I-Adapter (Canny)')
|
97 |
-
with gr.Row():
|
98 |
-
with gr.Column():
|
99 |
-
input_img = gr.Image(source='upload', type="numpy")
|
100 |
-
prompt = gr.Textbox(label="Prompt")
|
101 |
-
neg_prompt = gr.Textbox(label="Negative Prompt",
|
102 |
-
value='ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, bad anatomy, watermark, signature, cut off, low contrast, underexposed, overexposed, bad art, beginner, amateur, distorted face')
|
103 |
-
pos_prompt = gr.Textbox(label="Positive Prompt",
|
104 |
-
value = 'crafted, elegant, meticulous, magnificent, maximum details, extremely hyper aesthetic, intricately detailed')
|
105 |
-
with gr.Row():
|
106 |
-
type_in = gr.inputs.Radio(['Canny', 'Image'], type="value", default='Image', label='Input Types\n (You can input an image or a canny map)')
|
107 |
-
color_back = gr.inputs.Radio(['White', 'Black'], type="value", default='Black', label='Color of the canny background\n (Only work for canny input)')
|
108 |
-
run_button = gr.Button(label="Run")
|
109 |
-
con_strength = gr.Slider(label="Controling Strength (The guidance strength of the canny to the result)", minimum=0, maximum=1, value=1, step=0.1)
|
110 |
-
scale = gr.Slider(label="Guidance Scale (Classifier free guidance)", minimum=0.1, maximum=30.0, value=7.5, step=0.1)
|
111 |
-
fix_sample = gr.inputs.Radio(['True', 'False'], type="value", default='False', label='Fix Sampling\n (Fix the random seed)')
|
112 |
-
base_model = gr.inputs.Radio(['sd-v1-4.ckpt', 'anything-v4.0-pruned.ckpt'], type="value", default='sd-v1-4.ckpt', label='The base model you want to use')
|
113 |
-
with gr.Column():
|
114 |
-
result = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
|
115 |
-
ips = [input_img, type_in, color_back, prompt, neg_prompt, pos_prompt, fix_sample, scale, con_strength, base_model]
|
116 |
-
run_button.click(fn=process, inputs=ips, outputs=[result])
|
117 |
-
return demo
|
118 |
-
|
119 |
-
def create_demo_color_sketch(process):
|
120 |
-
with gr.Blocks() as demo:
|
121 |
-
with gr.Row():
|
122 |
-
gr.Markdown('## T2I-Adapter (Color + Sketch)')
|
123 |
-
with gr.Row():
|
124 |
-
with gr.Column():
|
125 |
-
with gr.Row():
|
126 |
-
input_img_sketch = gr.Image(source='upload', type="numpy", label='Sketch guidance')
|
127 |
-
input_img_color = gr.Image(source='upload', type="numpy", label='Color guidance')
|
128 |
-
prompt = gr.Textbox(label="Prompt")
|
129 |
-
neg_prompt = gr.Textbox(label="Negative Prompt",
|
130 |
-
value='ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, bad anatomy, watermark, signature, cut off, low contrast, underexposed, overexposed, bad art, beginner, amateur, distorted face')
|
131 |
-
pos_prompt = gr.Textbox(label="Positive Prompt",
|
132 |
-
value = 'crafted, elegant, meticulous, magnificent, maximum details, extremely hyper aesthetic, intricately detailed')
|
133 |
-
type_in_color = gr.inputs.Radio(['ColorMap', 'Image'], type="value", default='Image', label='Input Types of Color\n (You can input an image or a color map)')
|
134 |
-
with gr.Row():
|
135 |
-
type_in = gr.inputs.Radio(['Sketch', 'Image'], type="value", default='Image', label='Input Types of Sketch\n (You can input an image or a sketch)')
|
136 |
-
color_back = gr.inputs.Radio(['White', 'Black'], type="value", default='Black', label='Color of the sketch background\n (Only work for sketch input)')
|
137 |
-
with gr.Row():
|
138 |
-
w_sketch = gr.Slider(label="Sketch guidance weight", minimum=0, maximum=2, value=1.0, step=0.1)
|
139 |
-
w_color = gr.Slider(label="Color guidance weight", minimum=0, maximum=2, value=1.2, step=0.1)
|
140 |
-
run_button = gr.Button(label="Run")
|
141 |
-
con_strength = gr.Slider(label="Controling Strength (The guidance strength of the sketch to the result)", minimum=0, maximum=1, value=0.4, step=0.1)
|
142 |
-
scale = gr.Slider(label="Guidance Scale (Classifier free guidance)", minimum=0.1, maximum=30.0, value=7.5, step=0.1)
|
143 |
-
fix_sample = gr.inputs.Radio(['True', 'False'], type="value", default='False', label='Fix Sampling\n (Fix the random seed)')
|
144 |
-
base_model = gr.inputs.Radio(['sd-v1-4.ckpt', 'anything-v4.0-pruned.ckpt'], type="value", default='sd-v1-4.ckpt', label='The base model you want to use')
|
145 |
-
with gr.Column():
|
146 |
-
result = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=3, height='auto')
|
147 |
-
ips = [input_img_sketch, input_img_color, type_in, type_in_color, w_sketch, w_color, color_back, prompt, neg_prompt, pos_prompt, fix_sample, scale, con_strength, base_model]
|
148 |
-
run_button.click(fn=process, inputs=ips, outputs=[result])
|
149 |
-
return demo
|
150 |
-
|
151 |
-
def create_demo_style_sketch(process):
|
152 |
-
with gr.Blocks() as demo:
|
153 |
-
with gr.Row():
|
154 |
-
gr.Markdown('## T2I-Adapter (Style + Sketch)')
|
155 |
-
with gr.Row():
|
156 |
-
with gr.Column():
|
157 |
-
with gr.Row():
|
158 |
-
input_img_sketch = gr.Image(source='upload', type="numpy", label='Sketch guidance')
|
159 |
-
input_img_style = gr.Image(source='upload', type="numpy", label='Style guidance')
|
160 |
-
prompt = gr.Textbox(label="Prompt")
|
161 |
-
neg_prompt = gr.Textbox(label="Negative Prompt",
|
162 |
-
value='ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, bad anatomy, watermark, signature, cut off, low contrast, underexposed, overexposed, bad art, beginner, amateur, distorted face')
|
163 |
-
pos_prompt = gr.Textbox(label="Positive Prompt",
|
164 |
-
value = 'crafted, elegant, meticulous, magnificent, maximum details, extremely hyper aesthetic, intricately detailed')
|
165 |
-
with gr.Row():
|
166 |
-
type_in = gr.inputs.Radio(['Sketch', 'Image'], type="value", default='Image', label='Input Types of Sketch\n (You can input an image or a sketch)')
|
167 |
-
color_back = gr.inputs.Radio(['White', 'Black'], type="value", default='Black', label='Color of the sketch background\n (Only work for sketch input)')
|
168 |
-
run_button = gr.Button(label="Run")
|
169 |
-
con_strength = gr.Slider(label="Controling Strength (The guidance strength of the sketch to the result)", minimum=0, maximum=1, value=1, step=0.1)
|
170 |
-
scale = gr.Slider(label="Guidance Scale (Classifier free guidance)", minimum=0.1, maximum=30.0, value=7.5, step=0.1)
|
171 |
-
fix_sample = gr.inputs.Radio(['True', 'False'], type="value", default='False', label='Fix Sampling\n (Fix the random seed)')
|
172 |
-
base_model = gr.inputs.Radio(['sd-v1-4.ckpt', 'anything-v4.0-pruned.ckpt'], type="value", default='sd-v1-4.ckpt', label='The base model you want to use')
|
173 |
-
with gr.Column():
|
174 |
-
result = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
|
175 |
-
ips = [input_img_sketch, input_img_style, type_in, color_back, prompt, neg_prompt, pos_prompt, fix_sample, scale, con_strength, base_model]
|
176 |
-
run_button.click(fn=process, inputs=ips, outputs=[result])
|
177 |
-
return demo
|
178 |
-
|
179 |
-
def create_demo_color(process):
|
180 |
-
with gr.Blocks() as demo:
|
181 |
-
with gr.Row():
|
182 |
-
gr.Markdown('## T2I-Adapter (Color)')
|
183 |
-
with gr.Row():
|
184 |
-
with gr.Column():
|
185 |
-
input_img = gr.Image(source='upload', type="numpy", label='Color guidance')
|
186 |
-
prompt = gr.Textbox(label="Prompt")
|
187 |
-
neg_prompt = gr.Textbox(label="Negative Prompt",
|
188 |
-
value='ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, bad anatomy, watermark, signature, cut off, low contrast, underexposed, overexposed, bad art, beginner, amateur, distorted face')
|
189 |
-
pos_prompt = gr.Textbox(label="Positive Prompt",
|
190 |
-
value = 'crafted, elegant, meticulous, magnificent, maximum details, extremely hyper aesthetic, intricately detailed')
|
191 |
-
type_in_color = gr.inputs.Radio(['ColorMap', 'Image'], type="value", default='Image', label='Input Types of Color\n (You can input an image or a color map)')
|
192 |
-
w_color = gr.Slider(label="Color guidance weight", minimum=0, maximum=2, value=1, step=0.1)
|
193 |
-
run_button = gr.Button(label="Run")
|
194 |
-
con_strength = gr.Slider(label="Controling Strength (The guidance strength of the sketch to the result)", minimum=0, maximum=1, value=1, step=0.1)
|
195 |
-
scale = gr.Slider(label="Guidance Scale (Classifier free guidance)", minimum=0.1, maximum=30.0, value=7.5, step=0.1)
|
196 |
-
fix_sample = gr.inputs.Radio(['True', 'False'], type="value", default='False', label='Fix Sampling\n (Fix the random seed)')
|
197 |
-
base_model = gr.inputs.Radio(['sd-v1-4.ckpt', 'anything-v4.0-pruned.ckpt'], type="value", default='sd-v1-4.ckpt', label='The base model you want to use')
|
198 |
-
with gr.Column():
|
199 |
-
result = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
|
200 |
-
ips = [input_img, prompt, neg_prompt, pos_prompt, w_color, type_in_color, fix_sample, scale, con_strength, base_model]
|
201 |
-
run_button.click(fn=process, inputs=ips, outputs=[result])
|
202 |
-
return demo
|
203 |
-
|
204 |
-
def create_demo_seg(process):
|
205 |
-
with gr.Blocks() as demo:
|
206 |
-
with gr.Row():
|
207 |
-
gr.Markdown('## T2I-Adapter (Segmentation)')
|
208 |
-
with gr.Row():
|
209 |
-
with gr.Column():
|
210 |
-
input_img = gr.Image(source='upload', type="numpy")
|
211 |
-
prompt = gr.Textbox(label="Prompt")
|
212 |
-
neg_prompt = gr.Textbox(label="Negative Prompt",
|
213 |
-
value='ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, bad anatomy, watermark, signature, cut off, low contrast, underexposed, overexposed, bad art, beginner, amateur, distorted face')
|
214 |
-
pos_prompt = gr.Textbox(label="Positive Prompt",
|
215 |
-
value = 'crafted, elegant, meticulous, magnificent, maximum details, extremely hyper aesthetic, intricately detailed')
|
216 |
-
with gr.Row():
|
217 |
-
type_in = gr.inputs.Radio(['Segmentation', 'Image'], type="value", default='Image', label='You can input an image or a segmentation. If you choose to input a segmentation, it must correspond to the coco-stuff')
|
218 |
-
run_button = gr.Button(label="Run")
|
219 |
-
con_strength = gr.Slider(label="Controling Strength (The guidance strength of the segmentation to the result)", minimum=0, maximum=1, value=1, step=0.1)
|
220 |
-
scale = gr.Slider(label="Guidance Scale (Classifier free guidance)", minimum=0.1, maximum=30.0, value=7.5, step=0.1)
|
221 |
-
fix_sample = gr.inputs.Radio(['True', 'False'], type="value", default='False', label='Fix Sampling\n (Fix the random seed)')
|
222 |
-
base_model = gr.inputs.Radio(['sd-v1-4.ckpt', 'anything-v4.0-pruned.ckpt'], type="value", default='sd-v1-4.ckpt', label='The base model you want to use')
|
223 |
-
with gr.Column():
|
224 |
-
result = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
|
225 |
-
ips = [input_img, type_in, prompt, neg_prompt, pos_prompt, fix_sample, scale, con_strength, base_model]
|
226 |
-
run_button.click(fn=process, inputs=ips, outputs=[result])
|
227 |
-
return demo
|
228 |
-
|
229 |
-
def create_demo_depth(process):
|
230 |
-
with gr.Blocks() as demo:
|
231 |
-
with gr.Row():
|
232 |
-
gr.Markdown('## T2I-Adapter (Depth)')
|
233 |
-
with gr.Row():
|
234 |
-
with gr.Column():
|
235 |
-
input_img = gr.Image(source='upload', type="numpy")
|
236 |
-
prompt = gr.Textbox(label="Prompt")
|
237 |
-
neg_prompt = gr.Textbox(label="Negative Prompt",
|
238 |
-
value='ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, bad anatomy, watermark, signature, cut off, low contrast, underexposed, overexposed, bad art, beginner, amateur, distorted face')
|
239 |
-
pos_prompt = gr.Textbox(label="Positive Prompt",
|
240 |
-
value = 'crafted, elegant, meticulous, magnificent, maximum details, extremely hyper aesthetic, intricately detailed')
|
241 |
-
with gr.Row():
|
242 |
-
type_in = gr.inputs.Radio(['Depth', 'Image'], type="value", default='Image', label='You can input an image or a depth map')
|
243 |
-
run_button = gr.Button(label="Run")
|
244 |
-
con_strength = gr.Slider(label="Controling Strength (The guidance strength of the depth map to the result)", minimum=0, maximum=1, value=1, step=0.1)
|
245 |
-
scale = gr.Slider(label="Guidance Scale (Classifier free guidance)", minimum=0.1, maximum=30.0, value=7.5, step=0.1)
|
246 |
-
fix_sample = gr.inputs.Radio(['True', 'False'], type="value", default='False', label='Fix Sampling\n (Fix the random seed)')
|
247 |
-
base_model = gr.inputs.Radio(['sd-v1-4.ckpt', 'anything-v4.0-pruned.ckpt'], type="value", default='sd-v1-4.ckpt', label='The base model you want to use')
|
248 |
-
with gr.Column():
|
249 |
-
result = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
|
250 |
-
ips = [input_img, type_in, prompt, neg_prompt, pos_prompt, fix_sample, scale, con_strength, base_model]
|
251 |
-
run_button.click(fn=process, inputs=ips, outputs=[result])
|
252 |
-
return demo
|
253 |
-
|
254 |
-
def create_demo_depth_keypose(process):
|
255 |
-
with gr.Blocks() as demo:
|
256 |
-
with gr.Row():
|
257 |
-
gr.Markdown('## T2I-Adapter (Depth & Keypose)')
|
258 |
-
with gr.Row():
|
259 |
-
with gr.Column():
|
260 |
-
with gr.Row():
|
261 |
-
input_img_depth = gr.Image(source='upload', type="numpy", label='Depth guidance')
|
262 |
-
input_img_keypose = gr.Image(source='upload', type="numpy", label='Keypose guidance')
|
263 |
-
|
264 |
-
prompt = gr.Textbox(label="Prompt")
|
265 |
-
neg_prompt = gr.Textbox(label="Negative Prompt",
|
266 |
-
value='ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, bad anatomy, watermark, signature, cut off, low contrast, underexposed, overexposed, bad art, beginner, amateur, distorted face')
|
267 |
-
pos_prompt = gr.Textbox(label="Positive Prompt",
|
268 |
-
value = 'crafted, elegant, meticulous, magnificent, maximum details, extremely hyper aesthetic, intricately detailed')
|
269 |
-
with gr.Row():
|
270 |
-
type_in_depth = gr.inputs.Radio(['Depth', 'Image'], type="value", default='Image', label='You can input an image or a depth map')
|
271 |
-
type_in_keypose = gr.inputs.Radio(['Keypose', 'Image'], type="value", default='Image', label='You can input an image or a keypose map (mmpose style)')
|
272 |
-
with gr.Row():
|
273 |
-
w_depth = gr.Slider(label="Depth guidance weight", minimum=0, maximum=2, value=1.0, step=0.1)
|
274 |
-
w_keypose = gr.Slider(label="Keypose guidance weight", minimum=0, maximum=2, value=1.5, step=0.1)
|
275 |
-
run_button = gr.Button(label="Run")
|
276 |
-
con_strength = gr.Slider(label="Controling Strength (The guidance strength of the multi-guidance to the result)", minimum=0, maximum=1, value=1, step=0.1)
|
277 |
-
scale = gr.Slider(label="Guidance Scale (Classifier free guidance)", minimum=0.1, maximum=30.0, value=7.5, step=0.1)
|
278 |
-
fix_sample = gr.inputs.Radio(['True', 'False'], type="value", default='False', label='Fix Sampling\n (Fix the random seed)')
|
279 |
-
base_model = gr.inputs.Radio(['sd-v1-4.ckpt', 'anything-v4.0-pruned.ckpt'], type="value", default='sd-v1-4.ckpt', label='The base model you want to use')
|
280 |
-
with gr.Column():
|
281 |
-
result = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=3, height='auto')
|
282 |
-
ips = [input_img_depth, input_img_keypose, type_in_depth, type_in_keypose, w_depth, w_keypose, prompt, neg_prompt, pos_prompt, fix_sample, scale, con_strength, base_model]
|
283 |
-
run_button.click(fn=process, inputs=ips, outputs=[result])
|
284 |
-
return demo
|
285 |
-
|
286 |
-
def create_demo_draw(process):
|
287 |
-
with gr.Blocks() as demo:
|
288 |
-
with gr.Row():
|
289 |
-
gr.Markdown('## T2I-Adapter (Hand-free drawing)')
|
290 |
-
with gr.Row():
|
291 |
-
with gr.Column():
|
292 |
-
create_button = gr.Button(label="Start", value='Hand-free drawing')
|
293 |
-
input_img = gr.Image(source='upload', type="numpy",tool='sketch')
|
294 |
-
create_button.click(fn=create_map, outputs=[input_img], queue=False)
|
295 |
-
prompt = gr.Textbox(label="Prompt")
|
296 |
-
neg_prompt = gr.Textbox(label="Negative Prompt",
|
297 |
-
value='ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, bad anatomy, watermark, signature, cut off, low contrast, underexposed, overexposed, bad art, beginner, amateur, distorted face')
|
298 |
-
pos_prompt = gr.Textbox(label="Positive Prompt",
|
299 |
-
value = 'crafted, elegant, meticulous, magnificent, maximum details, extremely hyper aesthetic, intricately detailed')
|
300 |
-
run_button = gr.Button(label="Run")
|
301 |
-
con_strength = gr.Slider(label="Controling Strength (The guidance strength of the sketch to the result)", minimum=0, maximum=1, value=0.4, step=0.1)
|
302 |
-
scale = gr.Slider(label="Guidance Scale (Classifier free guidance)", minimum=0.1, maximum=30.0, value=7.5, step=0.1)
|
303 |
-
fix_sample = gr.inputs.Radio(['True', 'False'], type="value", default='False', label='Fix Sampling\n (Fix the random seed)')
|
304 |
-
base_model = gr.inputs.Radio(['sd-v1-4.ckpt', 'anything-v4.0-pruned.ckpt'], type="value", default='sd-v1-4.ckpt', label='The base model you want to use')
|
305 |
-
with gr.Column():
|
306 |
-
result = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
|
307 |
-
ips = [input_img, prompt, neg_prompt, pos_prompt, fix_sample, scale, con_strength, base_model]
|
308 |
-
run_button.click(fn=process, inputs=ips, outputs=[result])
|
309 |
-
return demo
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
demo/model.py
DELETED
@@ -1,979 +0,0 @@
|
|
1 |
-
import torch
|
2 |
-
from basicsr.utils import img2tensor, tensor2img
|
3 |
-
from pytorch_lightning import seed_everything
|
4 |
-
from ldm.models.diffusion.plms import PLMSSampler
|
5 |
-
from ldm.modules.encoders.adapter import Adapter, Adapter_light, StyleAdapter
|
6 |
-
from ldm.util import instantiate_from_config
|
7 |
-
from ldm.modules.structure_condition.model_edge import pidinet
|
8 |
-
from ldm.modules.structure_condition.model_seg import seger, Colorize
|
9 |
-
from ldm.modules.structure_condition.midas.api import MiDaSInference
|
10 |
-
import gradio as gr
|
11 |
-
from omegaconf import OmegaConf
|
12 |
-
import mmcv
|
13 |
-
from mmdet.apis import inference_detector, init_detector
|
14 |
-
from mmpose.apis import (inference_top_down_pose_model, init_pose_model, process_mmdet_results, vis_pose_result)
|
15 |
-
import os
|
16 |
-
import cv2
|
17 |
-
import numpy as np
|
18 |
-
import torch.nn.functional as F
|
19 |
-
from transformers import CLIPProcessor, CLIPVisionModel
|
20 |
-
from PIL import Image
|
21 |
-
|
22 |
-
|
23 |
-
def preprocessing(image, device):
|
24 |
-
# Resize
|
25 |
-
scale = 640 / max(image.shape[:2])
|
26 |
-
image = cv2.resize(image, dsize=None, fx=scale, fy=scale)
|
27 |
-
raw_image = image.astype(np.uint8)
|
28 |
-
|
29 |
-
# Subtract mean values
|
30 |
-
image = image.astype(np.float32)
|
31 |
-
image -= np.array(
|
32 |
-
[
|
33 |
-
float(104.008),
|
34 |
-
float(116.669),
|
35 |
-
float(122.675),
|
36 |
-
]
|
37 |
-
)
|
38 |
-
|
39 |
-
# Convert to torch.Tensor and add "batch" axis
|
40 |
-
image = torch.from_numpy(image.transpose(2, 0, 1)).float().unsqueeze(0)
|
41 |
-
image = image.to(device)
|
42 |
-
|
43 |
-
return image, raw_image
|
44 |
-
|
45 |
-
|
46 |
-
def imshow_keypoints(img,
|
47 |
-
pose_result,
|
48 |
-
skeleton=None,
|
49 |
-
kpt_score_thr=0.1,
|
50 |
-
pose_kpt_color=None,
|
51 |
-
pose_link_color=None,
|
52 |
-
radius=4,
|
53 |
-
thickness=1):
|
54 |
-
"""Draw keypoints and links on an image.
|
55 |
-
|
56 |
-
Args:
|
57 |
-
img (ndarry): The image to draw poses on.
|
58 |
-
pose_result (list[kpts]): The poses to draw. Each element kpts is
|
59 |
-
a set of K keypoints as an Kx3 numpy.ndarray, where each
|
60 |
-
keypoint is represented as x, y, score.
|
61 |
-
kpt_score_thr (float, optional): Minimum score of keypoints
|
62 |
-
to be shown. Default: 0.3.
|
63 |
-
pose_kpt_color (np.array[Nx3]`): Color of N keypoints. If None,
|
64 |
-
the keypoint will not be drawn.
|
65 |
-
pose_link_color (np.array[Mx3]): Color of M links. If None, the
|
66 |
-
links will not be drawn.
|
67 |
-
thickness (int): Thickness of lines.
|
68 |
-
"""
|
69 |
-
|
70 |
-
img_h, img_w, _ = img.shape
|
71 |
-
img = np.zeros(img.shape)
|
72 |
-
|
73 |
-
for idx, kpts in enumerate(pose_result):
|
74 |
-
if idx > 1:
|
75 |
-
continue
|
76 |
-
kpts = kpts['keypoints']
|
77 |
-
kpts = np.array(kpts, copy=False)
|
78 |
-
|
79 |
-
# draw each point on image
|
80 |
-
if pose_kpt_color is not None:
|
81 |
-
assert len(pose_kpt_color) == len(kpts)
|
82 |
-
|
83 |
-
for kid, kpt in enumerate(kpts):
|
84 |
-
x_coord, y_coord, kpt_score = int(kpt[0]), int(kpt[1]), kpt[2]
|
85 |
-
|
86 |
-
if kpt_score < kpt_score_thr or pose_kpt_color[kid] is None:
|
87 |
-
# skip the point that should not be drawn
|
88 |
-
continue
|
89 |
-
|
90 |
-
color = tuple(int(c) for c in pose_kpt_color[kid])
|
91 |
-
cv2.circle(img, (int(x_coord), int(y_coord)), radius, color, -1)
|
92 |
-
|
93 |
-
# draw links
|
94 |
-
if skeleton is not None and pose_link_color is not None:
|
95 |
-
assert len(pose_link_color) == len(skeleton)
|
96 |
-
|
97 |
-
for sk_id, sk in enumerate(skeleton):
|
98 |
-
pos1 = (int(kpts[sk[0], 0]), int(kpts[sk[0], 1]))
|
99 |
-
pos2 = (int(kpts[sk[1], 0]), int(kpts[sk[1], 1]))
|
100 |
-
|
101 |
-
if (pos1[0] <= 0 or pos1[0] >= img_w or pos1[1] <= 0 or pos1[1] >= img_h or pos2[0] <= 0
|
102 |
-
or pos2[0] >= img_w or pos2[1] <= 0 or pos2[1] >= img_h or kpts[sk[0], 2] < kpt_score_thr
|
103 |
-
or kpts[sk[1], 2] < kpt_score_thr or pose_link_color[sk_id] is None):
|
104 |
-
# skip the link that should not be drawn
|
105 |
-
continue
|
106 |
-
color = tuple(int(c) for c in pose_link_color[sk_id])
|
107 |
-
cv2.line(img, pos1, pos2, color, thickness=thickness)
|
108 |
-
|
109 |
-
return img
|
110 |
-
|
111 |
-
|
112 |
-
def load_model_from_config(config, ckpt, verbose=False):
|
113 |
-
print(f"Loading model from {ckpt}")
|
114 |
-
pl_sd = torch.load(ckpt, map_location="cpu")
|
115 |
-
if "global_step" in pl_sd:
|
116 |
-
print(f"Global Step: {pl_sd['global_step']}")
|
117 |
-
if "state_dict" in pl_sd:
|
118 |
-
sd = pl_sd["state_dict"]
|
119 |
-
else:
|
120 |
-
sd = pl_sd
|
121 |
-
model = instantiate_from_config(config.model)
|
122 |
-
_, _ = model.load_state_dict(sd, strict=False)
|
123 |
-
|
124 |
-
model.cuda()
|
125 |
-
model.eval()
|
126 |
-
return model
|
127 |
-
|
128 |
-
|
129 |
-
class Model_all:
|
130 |
-
def __init__(self, device='cpu'):
|
131 |
-
# common part
|
132 |
-
self.device = device
|
133 |
-
self.config = OmegaConf.load("configs/stable-diffusion/app.yaml")
|
134 |
-
self.config.model.params.cond_stage_config.params.device = device
|
135 |
-
self.base_model = load_model_from_config(self.config, "models/sd-v1-4.ckpt").to(device)
|
136 |
-
self.current_base = 'sd-v1-4.ckpt'
|
137 |
-
self.sampler = PLMSSampler(self.base_model)
|
138 |
-
|
139 |
-
# sketch part
|
140 |
-
self.model_canny = Adapter(channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True,
|
141 |
-
use_conv=False).to(device)
|
142 |
-
self.model_canny.load_state_dict(torch.load("models/t2iadapter_canny_sd14v1.pth", map_location=device))
|
143 |
-
self.model_sketch = Adapter(channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True,
|
144 |
-
use_conv=False).to(device)
|
145 |
-
self.model_sketch.load_state_dict(torch.load("models/t2iadapter_sketch_sd14v1.pth", map_location=device))
|
146 |
-
self.model_edge = pidinet().to(device)
|
147 |
-
self.model_edge.load_state_dict({k.replace('module.', ''): v for k, v in
|
148 |
-
torch.load('models/table5_pidinet.pth', map_location=device)[
|
149 |
-
'state_dict'].items()})
|
150 |
-
|
151 |
-
# segmentation part
|
152 |
-
self.model_seger = seger().to(device)
|
153 |
-
self.model_seger.eval()
|
154 |
-
self.coler = Colorize(n=182)
|
155 |
-
self.model_seg = Adapter(cin=int(3 * 64), channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True,
|
156 |
-
use_conv=False).to(device)
|
157 |
-
self.model_seg.load_state_dict(torch.load("models/t2iadapter_seg_sd14v1.pth", map_location=device))
|
158 |
-
|
159 |
-
# depth part
|
160 |
-
self.depth_model = MiDaSInference(model_type='dpt_hybrid').to(device)
|
161 |
-
self.model_depth = Adapter(cin=3 * 64, channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True,
|
162 |
-
use_conv=False).to(device)
|
163 |
-
self.model_depth.load_state_dict(torch.load("models/t2iadapter_depth_sd14v1.pth", map_location=device))
|
164 |
-
|
165 |
-
# keypose part
|
166 |
-
self.model_pose = Adapter(cin=int(3 * 64), channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True,
|
167 |
-
use_conv=False).to(device)
|
168 |
-
self.model_pose.load_state_dict(torch.load("models/t2iadapter_keypose_sd14v1.pth", map_location=device))
|
169 |
-
|
170 |
-
# openpose part
|
171 |
-
self.model_openpose = Adapter(cin=int(3 * 64), channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True,
|
172 |
-
use_conv=False).to(device)
|
173 |
-
self.model_openpose.load_state_dict(torch.load("models/t2iadapter_openpose_sd14v1.pth", map_location=device))
|
174 |
-
|
175 |
-
# color part
|
176 |
-
self.model_color = Adapter_light(cin=int(3 * 64), channels=[320, 640, 1280, 1280], nums_rb=4).to(device)
|
177 |
-
self.model_color.load_state_dict(torch.load("models/t2iadapter_color_sd14v1.pth", map_location=device))
|
178 |
-
|
179 |
-
# style part
|
180 |
-
self.model_style = StyleAdapter(width=1024, context_dim=768, num_head=8, n_layes=3, num_token=8).to(device)
|
181 |
-
self.model_style.load_state_dict(torch.load("models/t2iadapter_style_sd14v1.pth", map_location=device))
|
182 |
-
self.clip_processor = CLIPProcessor.from_pretrained('openai/clip-vit-large-patch14')
|
183 |
-
self.clip_vision_model = CLIPVisionModel.from_pretrained('openai/clip-vit-large-patch14').to(device)
|
184 |
-
|
185 |
-
device = 'cpu'
|
186 |
-
## mmpose
|
187 |
-
det_config = 'models/faster_rcnn_r50_fpn_coco.py'
|
188 |
-
det_checkpoint = 'models/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth'
|
189 |
-
pose_config = 'models/hrnet_w48_coco_256x192.py'
|
190 |
-
pose_checkpoint = 'models/hrnet_w48_coco_256x192-b9e0b3ab_20200708.pth'
|
191 |
-
self.det_cat_id = 1
|
192 |
-
self.bbox_thr = 0.2
|
193 |
-
## detector
|
194 |
-
det_config_mmcv = mmcv.Config.fromfile(det_config)
|
195 |
-
self.det_model = init_detector(det_config_mmcv, det_checkpoint, device=device)
|
196 |
-
pose_config_mmcv = mmcv.Config.fromfile(pose_config)
|
197 |
-
self.pose_model = init_pose_model(pose_config_mmcv, pose_checkpoint, device=device)
|
198 |
-
## color
|
199 |
-
self.skeleton = [[15, 13], [13, 11], [16, 14], [14, 12], [11, 12], [5, 11], [6, 12], [5, 6], [5, 7], [6, 8],
|
200 |
-
[7, 9], [8, 10],
|
201 |
-
[1, 2], [0, 1], [0, 2], [1, 3], [2, 4], [3, 5], [4, 6]]
|
202 |
-
self.pose_kpt_color = [[51, 153, 255], [51, 153, 255], [51, 153, 255], [51, 153, 255], [51, 153, 255],
|
203 |
-
[0, 255, 0],
|
204 |
-
[255, 128, 0], [0, 255, 0], [255, 128, 0], [0, 255, 0], [255, 128, 0], [0, 255, 0],
|
205 |
-
[255, 128, 0],
|
206 |
-
[0, 255, 0], [255, 128, 0], [0, 255, 0], [255, 128, 0]]
|
207 |
-
self.pose_link_color = [[0, 255, 0], [0, 255, 0], [255, 128, 0], [255, 128, 0],
|
208 |
-
[51, 153, 255], [51, 153, 255], [51, 153, 255], [51, 153, 255], [0, 255, 0],
|
209 |
-
[255, 128, 0],
|
210 |
-
[0, 255, 0], [255, 128, 0], [51, 153, 255], [51, 153, 255], [51, 153, 255],
|
211 |
-
[51, 153, 255],
|
212 |
-
[51, 153, 255], [51, 153, 255], [51, 153, 255]]
|
213 |
-
|
214 |
-
def load_vae(self):
|
215 |
-
vae_sd = torch.load(os.path.join('models', 'anything-v4.0.vae.pt'), map_location="cuda")
|
216 |
-
sd = vae_sd["state_dict"]
|
217 |
-
self.base_model.first_stage_model.load_state_dict(sd, strict=False)
|
218 |
-
|
219 |
-
@torch.no_grad()
|
220 |
-
def process_sketch(self, input_img, type_in, color_back, prompt, neg_prompt, pos_prompt, fix_sample, scale,
|
221 |
-
con_strength, base_model):
|
222 |
-
if self.current_base != base_model:
|
223 |
-
ckpt = os.path.join("models", base_model)
|
224 |
-
pl_sd = torch.load(ckpt, map_location="cuda")
|
225 |
-
if "state_dict" in pl_sd:
|
226 |
-
sd = pl_sd["state_dict"]
|
227 |
-
else:
|
228 |
-
sd = pl_sd
|
229 |
-
self.base_model.load_state_dict(sd, strict=False)
|
230 |
-
self.current_base = base_model
|
231 |
-
if 'anything' in base_model.lower():
|
232 |
-
self.load_vae()
|
233 |
-
|
234 |
-
con_strength = int((1 - con_strength) * 50)
|
235 |
-
if fix_sample == 'True':
|
236 |
-
seed_everything(42)
|
237 |
-
im = cv2.resize(input_img, (512, 512))
|
238 |
-
|
239 |
-
if type_in == 'Sketch':
|
240 |
-
if color_back == 'White':
|
241 |
-
im = 255 - im
|
242 |
-
im_edge = im.copy()
|
243 |
-
im = img2tensor(im)[0].unsqueeze(0).unsqueeze(0) / 255.
|
244 |
-
im = im > 0.5
|
245 |
-
im = im.float()
|
246 |
-
elif type_in == 'Image':
|
247 |
-
im = img2tensor(im).unsqueeze(0) / 255.
|
248 |
-
im = self.model_edge(im.to(self.device))[-1]
|
249 |
-
im = im > 0.5
|
250 |
-
im = im.float()
|
251 |
-
im_edge = tensor2img(im)
|
252 |
-
|
253 |
-
# extract condition features
|
254 |
-
c = self.base_model.get_learned_conditioning([prompt + ', ' + pos_prompt])
|
255 |
-
nc = self.base_model.get_learned_conditioning([neg_prompt])
|
256 |
-
features_adapter = self.model_sketch(im.to(self.device))
|
257 |
-
shape = [4, 64, 64]
|
258 |
-
|
259 |
-
# sampling
|
260 |
-
samples_ddim, _ = self.sampler.sample(S=50,
|
261 |
-
conditioning=c,
|
262 |
-
batch_size=1,
|
263 |
-
shape=shape,
|
264 |
-
verbose=False,
|
265 |
-
unconditional_guidance_scale=scale,
|
266 |
-
unconditional_conditioning=nc,
|
267 |
-
eta=0.0,
|
268 |
-
x_T=None,
|
269 |
-
features_adapter1=features_adapter,
|
270 |
-
mode='sketch',
|
271 |
-
con_strength=con_strength)
|
272 |
-
|
273 |
-
x_samples_ddim = self.base_model.decode_first_stage(samples_ddim)
|
274 |
-
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
275 |
-
x_samples_ddim = x_samples_ddim.to('cpu')
|
276 |
-
x_samples_ddim = x_samples_ddim.permute(0, 2, 3, 1).numpy()[0]
|
277 |
-
x_samples_ddim = 255. * x_samples_ddim
|
278 |
-
x_samples_ddim = x_samples_ddim.astype(np.uint8)
|
279 |
-
|
280 |
-
return [im_edge, x_samples_ddim]
|
281 |
-
|
282 |
-
@torch.no_grad()
|
283 |
-
def process_canny(self, input_img, type_in, color_back, prompt, neg_prompt, pos_prompt, fix_sample, scale,
|
284 |
-
con_strength, base_model):
|
285 |
-
if self.current_base != base_model:
|
286 |
-
ckpt = os.path.join("models", base_model)
|
287 |
-
pl_sd = torch.load(ckpt, map_location="cuda")
|
288 |
-
if "state_dict" in pl_sd:
|
289 |
-
sd = pl_sd["state_dict"]
|
290 |
-
else:
|
291 |
-
sd = pl_sd
|
292 |
-
self.base_model.load_state_dict(sd, strict=False)
|
293 |
-
self.current_base = base_model
|
294 |
-
if 'anything' in base_model.lower():
|
295 |
-
self.load_vae()
|
296 |
-
|
297 |
-
con_strength = int((1 - con_strength) * 50)
|
298 |
-
if fix_sample == 'True':
|
299 |
-
seed_everything(42)
|
300 |
-
im = cv2.resize(input_img, (512, 512))
|
301 |
-
|
302 |
-
if type_in == 'Canny':
|
303 |
-
if color_back == 'White':
|
304 |
-
im = 255 - im
|
305 |
-
im_edge = im.copy()
|
306 |
-
im = img2tensor(im)[0].unsqueeze(0).unsqueeze(0) / 255.
|
307 |
-
elif type_in == 'Image':
|
308 |
-
im = cv2.Canny(im,100,200)
|
309 |
-
im = img2tensor(im[..., None], bgr2rgb=True, float32=True).unsqueeze(0) / 255.
|
310 |
-
im_edge = tensor2img(im)
|
311 |
-
|
312 |
-
# extract condition features
|
313 |
-
c = self.base_model.get_learned_conditioning([prompt + ', ' + pos_prompt])
|
314 |
-
nc = self.base_model.get_learned_conditioning([neg_prompt])
|
315 |
-
features_adapter = self.model_canny(im.to(self.device))
|
316 |
-
shape = [4, 64, 64]
|
317 |
-
|
318 |
-
# sampling
|
319 |
-
samples_ddim, _ = self.sampler.sample(S=50,
|
320 |
-
conditioning=c,
|
321 |
-
batch_size=1,
|
322 |
-
shape=shape,
|
323 |
-
verbose=False,
|
324 |
-
unconditional_guidance_scale=scale,
|
325 |
-
unconditional_conditioning=nc,
|
326 |
-
eta=0.0,
|
327 |
-
x_T=None,
|
328 |
-
features_adapter1=features_adapter,
|
329 |
-
mode='sketch',
|
330 |
-
con_strength=con_strength)
|
331 |
-
|
332 |
-
x_samples_ddim = self.base_model.decode_first_stage(samples_ddim)
|
333 |
-
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
334 |
-
x_samples_ddim = x_samples_ddim.to('cpu')
|
335 |
-
x_samples_ddim = x_samples_ddim.permute(0, 2, 3, 1).numpy()[0]
|
336 |
-
x_samples_ddim = 255. * x_samples_ddim
|
337 |
-
x_samples_ddim = x_samples_ddim.astype(np.uint8)
|
338 |
-
|
339 |
-
return [im_edge, x_samples_ddim]
|
340 |
-
|
341 |
-
@torch.no_grad()
|
342 |
-
def process_color_sketch(self, input_img_sketch, input_img_color, type_in, type_in_color, w_sketch, w_color, color_back, prompt, neg_prompt, pos_prompt, fix_sample, scale, con_strength, base_model):
|
343 |
-
if self.current_base != base_model:
|
344 |
-
ckpt = os.path.join("models", base_model)
|
345 |
-
pl_sd = torch.load(ckpt, map_location="cuda")
|
346 |
-
if "state_dict" in pl_sd:
|
347 |
-
sd = pl_sd["state_dict"]
|
348 |
-
else:
|
349 |
-
sd = pl_sd
|
350 |
-
self.base_model.load_state_dict(sd, strict=False)
|
351 |
-
self.current_base = base_model
|
352 |
-
if 'anything' in base_model.lower():
|
353 |
-
self.load_vae()
|
354 |
-
|
355 |
-
con_strength = int((1 - con_strength) * 50)
|
356 |
-
if fix_sample == 'True':
|
357 |
-
seed_everything(42)
|
358 |
-
im = cv2.resize(input_img_sketch, (512, 512))
|
359 |
-
|
360 |
-
if type_in == 'Sketch':
|
361 |
-
if color_back == 'White':
|
362 |
-
im = 255 - im
|
363 |
-
im_edge = im.copy()
|
364 |
-
im = img2tensor(im)[0].unsqueeze(0).unsqueeze(0) / 255.
|
365 |
-
im = im > 0.5
|
366 |
-
im = im.float()
|
367 |
-
elif type_in == 'Image':
|
368 |
-
im = img2tensor(im).unsqueeze(0) / 255.
|
369 |
-
im = self.model_edge(im.to(self.device))[-1]#.cuda()
|
370 |
-
im = im > 0.5
|
371 |
-
im = im.float()
|
372 |
-
im_edge = tensor2img(im)
|
373 |
-
if type_in_color == 'Image':
|
374 |
-
input_img_color = cv2.resize(input_img_color,(512//64, 512//64), interpolation=cv2.INTER_CUBIC)
|
375 |
-
input_img_color = cv2.resize(input_img_color,(512,512), interpolation=cv2.INTER_NEAREST)
|
376 |
-
else:
|
377 |
-
input_img_color = cv2.resize(input_img_color, (512, 512))
|
378 |
-
im_color = input_img_color.copy()
|
379 |
-
im_color_tensor = img2tensor(input_img_color, bgr2rgb=False).unsqueeze(0) / 255.
|
380 |
-
|
381 |
-
# extract condition features
|
382 |
-
c = self.base_model.get_learned_conditioning([prompt + ', ' + pos_prompt])
|
383 |
-
nc = self.base_model.get_learned_conditioning([neg_prompt])
|
384 |
-
features_adapter_sketch = self.model_sketch(im.to(self.device))
|
385 |
-
features_adapter_color = self.model_color(im_color_tensor.to(self.device))
|
386 |
-
features_adapter = [fs*w_sketch+fc*w_color for fs, fc in zip(features_adapter_sketch,features_adapter_color)]
|
387 |
-
shape = [4, 64, 64]
|
388 |
-
|
389 |
-
# sampling
|
390 |
-
samples_ddim, _ = self.sampler.sample(S=50,
|
391 |
-
conditioning=c,
|
392 |
-
batch_size=1,
|
393 |
-
shape=shape,
|
394 |
-
verbose=False,
|
395 |
-
unconditional_guidance_scale=scale,
|
396 |
-
unconditional_conditioning=nc,
|
397 |
-
eta=0.0,
|
398 |
-
x_T=None,
|
399 |
-
features_adapter1=features_adapter,
|
400 |
-
mode='sketch',
|
401 |
-
con_strength=con_strength)
|
402 |
-
|
403 |
-
x_samples_ddim = self.base_model.decode_first_stage(samples_ddim)
|
404 |
-
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
405 |
-
x_samples_ddim = x_samples_ddim.to('cpu')
|
406 |
-
x_samples_ddim = x_samples_ddim.permute(0, 2, 3, 1).numpy()[0]
|
407 |
-
x_samples_ddim = 255. * x_samples_ddim
|
408 |
-
x_samples_ddim = x_samples_ddim.astype(np.uint8)
|
409 |
-
|
410 |
-
return [im_edge, im_color, x_samples_ddim]
|
411 |
-
|
412 |
-
@torch.no_grad()
|
413 |
-
def process_style_sketch(self, input_img_sketch, input_img_style, type_in, color_back, prompt, neg_prompt, pos_prompt, fix_sample, scale, con_strength, base_model):
|
414 |
-
if self.current_base != base_model:
|
415 |
-
ckpt = os.path.join("models", base_model)
|
416 |
-
pl_sd = torch.load(ckpt, map_location="cuda")
|
417 |
-
if "state_dict" in pl_sd:
|
418 |
-
sd = pl_sd["state_dict"]
|
419 |
-
else:
|
420 |
-
sd = pl_sd
|
421 |
-
self.base_model.load_state_dict(sd, strict=False)
|
422 |
-
self.current_base = base_model
|
423 |
-
if 'anything' in base_model.lower():
|
424 |
-
self.load_vae()
|
425 |
-
|
426 |
-
con_strength = int((1 - con_strength) * 50)
|
427 |
-
if fix_sample == 'True':
|
428 |
-
seed_everything(42)
|
429 |
-
im = cv2.resize(input_img_sketch, (512, 512))
|
430 |
-
|
431 |
-
if type_in == 'Sketch':
|
432 |
-
if color_back == 'White':
|
433 |
-
im = 255 - im
|
434 |
-
im_edge = im.copy()
|
435 |
-
im = img2tensor(im)[0].unsqueeze(0).unsqueeze(0) / 255.
|
436 |
-
im = im > 0.5
|
437 |
-
im = im.float()
|
438 |
-
elif type_in == 'Image':
|
439 |
-
im = img2tensor(im).unsqueeze(0) / 255.
|
440 |
-
im = self.model_edge(im.to(self.device))[-1]#.cuda()
|
441 |
-
im = im > 0.5
|
442 |
-
im = im.float()
|
443 |
-
im_edge = tensor2img(im)
|
444 |
-
|
445 |
-
style = Image.fromarray(input_img_style)
|
446 |
-
style_for_clip = self.clip_processor(images=style, return_tensors="pt")['pixel_values']
|
447 |
-
style_feat = self.clip_vision_model(style_for_clip.to(self.device))['last_hidden_state']
|
448 |
-
style_feat = self.model_style(style_feat)
|
449 |
-
|
450 |
-
# extract condition features
|
451 |
-
c = self.base_model.get_learned_conditioning([prompt + ', ' + pos_prompt])
|
452 |
-
nc = self.base_model.get_learned_conditioning([neg_prompt])
|
453 |
-
features_adapter = self.model_sketch(im.to(self.device))
|
454 |
-
shape = [4, 64, 64]
|
455 |
-
|
456 |
-
# sampling
|
457 |
-
samples_ddim, _ = self.sampler.sample(S=50,
|
458 |
-
conditioning=c,
|
459 |
-
batch_size=1,
|
460 |
-
shape=shape,
|
461 |
-
verbose=False,
|
462 |
-
unconditional_guidance_scale=scale,
|
463 |
-
unconditional_conditioning=nc,
|
464 |
-
eta=0.0,
|
465 |
-
x_T=None,
|
466 |
-
features_adapter1=features_adapter,
|
467 |
-
mode='style',
|
468 |
-
con_strength=con_strength,
|
469 |
-
style_feature=style_feat)
|
470 |
-
|
471 |
-
x_samples_ddim = self.base_model.decode_first_stage(samples_ddim)
|
472 |
-
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
473 |
-
x_samples_ddim = x_samples_ddim.to('cpu')
|
474 |
-
x_samples_ddim = x_samples_ddim.permute(0, 2, 3, 1).numpy()[0]
|
475 |
-
x_samples_ddim = 255. * x_samples_ddim
|
476 |
-
x_samples_ddim = x_samples_ddim.astype(np.uint8)
|
477 |
-
|
478 |
-
return [im_edge, x_samples_ddim]
|
479 |
-
|
480 |
-
@torch.no_grad()
|
481 |
-
def process_color(self, input_img, prompt, neg_prompt, pos_prompt, w_color, type_in_color, fix_sample, scale, con_strength, base_model):
|
482 |
-
if self.current_base != base_model:
|
483 |
-
ckpt = os.path.join("models", base_model)
|
484 |
-
pl_sd = torch.load(ckpt, map_location="cuda")
|
485 |
-
if "state_dict" in pl_sd:
|
486 |
-
sd = pl_sd["state_dict"]
|
487 |
-
else:
|
488 |
-
sd = pl_sd
|
489 |
-
self.base_model.load_state_dict(sd, strict=False)
|
490 |
-
self.current_base = base_model
|
491 |
-
if 'anything' in base_model.lower():
|
492 |
-
self.load_vae()
|
493 |
-
|
494 |
-
con_strength = int((1 - con_strength) * 50)
|
495 |
-
if fix_sample == 'True':
|
496 |
-
seed_everything(42)
|
497 |
-
if type_in_color == 'Image':
|
498 |
-
input_img = cv2.resize(input_img,(512//64, 512//64), interpolation=cv2.INTER_CUBIC)
|
499 |
-
input_img = cv2.resize(input_img,(512,512), interpolation=cv2.INTER_NEAREST)
|
500 |
-
else:
|
501 |
-
input_img = cv2.resize(input_img, (512, 512))
|
502 |
-
|
503 |
-
im_color = input_img.copy()
|
504 |
-
im = img2tensor(input_img, bgr2rgb=False).unsqueeze(0) / 255.
|
505 |
-
|
506 |
-
# extract condition features
|
507 |
-
c = self.base_model.get_learned_conditioning([prompt + ', ' + pos_prompt])
|
508 |
-
nc = self.base_model.get_learned_conditioning([neg_prompt])
|
509 |
-
features_adapter = self.model_color(im.to(self.device))
|
510 |
-
features_adapter = [fi*w_color for fi in features_adapter]
|
511 |
-
shape = [4, 64, 64]
|
512 |
-
|
513 |
-
# sampling
|
514 |
-
samples_ddim, _ = self.sampler.sample(S=50,
|
515 |
-
conditioning=c,
|
516 |
-
batch_size=1,
|
517 |
-
shape=shape,
|
518 |
-
verbose=False,
|
519 |
-
unconditional_guidance_scale=scale,
|
520 |
-
unconditional_conditioning=nc,
|
521 |
-
eta=0.0,
|
522 |
-
x_T=None,
|
523 |
-
features_adapter1=features_adapter,
|
524 |
-
mode='sketch',
|
525 |
-
con_strength=con_strength)
|
526 |
-
|
527 |
-
x_samples_ddim = self.base_model.decode_first_stage(samples_ddim)
|
528 |
-
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
529 |
-
x_samples_ddim = x_samples_ddim.to('cpu')
|
530 |
-
x_samples_ddim = x_samples_ddim.permute(0, 2, 3, 1).numpy()[0]
|
531 |
-
x_samples_ddim = 255. * x_samples_ddim
|
532 |
-
x_samples_ddim = x_samples_ddim.astype(np.uint8)
|
533 |
-
|
534 |
-
return [im_color, x_samples_ddim]
|
535 |
-
|
536 |
-
@torch.no_grad()
|
537 |
-
def process_depth(self, input_img, type_in, prompt, neg_prompt, pos_prompt, fix_sample, scale,
|
538 |
-
con_strength, base_model):
|
539 |
-
if self.current_base != base_model:
|
540 |
-
ckpt = os.path.join("models", base_model)
|
541 |
-
pl_sd = torch.load(ckpt, map_location="cuda")
|
542 |
-
if "state_dict" in pl_sd:
|
543 |
-
sd = pl_sd["state_dict"]
|
544 |
-
else:
|
545 |
-
sd = pl_sd
|
546 |
-
self.base_model.load_state_dict(sd, strict=False)
|
547 |
-
self.current_base = base_model
|
548 |
-
if 'anything' in base_model.lower():
|
549 |
-
self.load_vae()
|
550 |
-
|
551 |
-
con_strength = int((1 - con_strength) * 50)
|
552 |
-
if fix_sample == 'True':
|
553 |
-
seed_everything(42)
|
554 |
-
im = cv2.resize(input_img, (512, 512))
|
555 |
-
|
556 |
-
if type_in == 'Depth':
|
557 |
-
im_depth = im.copy()
|
558 |
-
depth = img2tensor(im).unsqueeze(0) / 255.
|
559 |
-
elif type_in == 'Image':
|
560 |
-
im = img2tensor(im).unsqueeze(0) / 127.5 - 1.0
|
561 |
-
depth = self.depth_model(im.to(self.device)).repeat(1, 3, 1, 1)
|
562 |
-
depth -= torch.min(depth)
|
563 |
-
depth /= torch.max(depth)
|
564 |
-
im_depth = tensor2img(depth)
|
565 |
-
|
566 |
-
# extract condition features
|
567 |
-
c = self.base_model.get_learned_conditioning([prompt + ', ' + pos_prompt])
|
568 |
-
nc = self.base_model.get_learned_conditioning([neg_prompt])
|
569 |
-
features_adapter = self.model_depth(depth.to(self.device))
|
570 |
-
shape = [4, 64, 64]
|
571 |
-
|
572 |
-
# sampling
|
573 |
-
samples_ddim, _ = self.sampler.sample(S=50,
|
574 |
-
conditioning=c,
|
575 |
-
batch_size=1,
|
576 |
-
shape=shape,
|
577 |
-
verbose=False,
|
578 |
-
unconditional_guidance_scale=scale,
|
579 |
-
unconditional_conditioning=nc,
|
580 |
-
eta=0.0,
|
581 |
-
x_T=None,
|
582 |
-
features_adapter1=features_adapter,
|
583 |
-
mode='sketch',
|
584 |
-
con_strength=con_strength)
|
585 |
-
|
586 |
-
x_samples_ddim = self.base_model.decode_first_stage(samples_ddim)
|
587 |
-
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
588 |
-
x_samples_ddim = x_samples_ddim.to('cpu')
|
589 |
-
x_samples_ddim = x_samples_ddim.permute(0, 2, 3, 1).numpy()[0]
|
590 |
-
x_samples_ddim = 255. * x_samples_ddim
|
591 |
-
x_samples_ddim = x_samples_ddim.astype(np.uint8)
|
592 |
-
|
593 |
-
return [im_depth, x_samples_ddim]
|
594 |
-
|
595 |
-
@torch.no_grad()
|
596 |
-
def process_depth_keypose(self, input_img_depth, input_img_keypose, type_in_depth, type_in_keypose, w_depth,
|
597 |
-
w_keypose, prompt, neg_prompt, pos_prompt, fix_sample, scale, con_strength, base_model):
|
598 |
-
if self.current_base != base_model:
|
599 |
-
ckpt = os.path.join("models", base_model)
|
600 |
-
pl_sd = torch.load(ckpt, map_location="cuda")
|
601 |
-
if "state_dict" in pl_sd:
|
602 |
-
sd = pl_sd["state_dict"]
|
603 |
-
else:
|
604 |
-
sd = pl_sd
|
605 |
-
self.base_model.load_state_dict(sd, strict=False)
|
606 |
-
self.current_base = base_model
|
607 |
-
if 'anything' in base_model.lower():
|
608 |
-
self.load_vae()
|
609 |
-
|
610 |
-
if fix_sample == 'True':
|
611 |
-
seed_everything(42)
|
612 |
-
im_depth = cv2.resize(input_img_depth, (512, 512))
|
613 |
-
im_keypose = cv2.resize(input_img_keypose, (512, 512))
|
614 |
-
|
615 |
-
# get depth
|
616 |
-
if type_in_depth == 'Depth':
|
617 |
-
im_depth_out = im_depth.copy()
|
618 |
-
depth = img2tensor(im_depth).unsqueeze(0) / 255.
|
619 |
-
elif type_in_depth == 'Image':
|
620 |
-
im_depth = img2tensor(im_depth).unsqueeze(0) / 127.5 - 1.0
|
621 |
-
depth = self.depth_model(im_depth.to(self.device)).repeat(1, 3, 1, 1)
|
622 |
-
depth -= torch.min(depth)
|
623 |
-
depth /= torch.max(depth)
|
624 |
-
im_depth_out = tensor2img(depth)
|
625 |
-
|
626 |
-
# get keypose
|
627 |
-
if type_in_keypose == 'Keypose':
|
628 |
-
im_keypose_out = im_keypose.copy()[:,:,::-1]
|
629 |
-
elif type_in_keypose == 'Image':
|
630 |
-
image = im_keypose.copy()
|
631 |
-
im_keypose = img2tensor(im_keypose).unsqueeze(0) / 255.
|
632 |
-
mmdet_results = inference_detector(self.det_model, image)
|
633 |
-
# keep the person class bounding boxes.
|
634 |
-
person_results = process_mmdet_results(mmdet_results, self.det_cat_id)
|
635 |
-
|
636 |
-
# optional
|
637 |
-
return_heatmap = False
|
638 |
-
dataset = self.pose_model.cfg.data['test']['type']
|
639 |
-
|
640 |
-
# e.g. use ('backbone', ) to return backbone feature
|
641 |
-
output_layer_names = None
|
642 |
-
pose_results, _ = inference_top_down_pose_model(
|
643 |
-
self.pose_model,
|
644 |
-
image,
|
645 |
-
person_results,
|
646 |
-
bbox_thr=self.bbox_thr,
|
647 |
-
format='xyxy',
|
648 |
-
dataset=dataset,
|
649 |
-
dataset_info=None,
|
650 |
-
return_heatmap=return_heatmap,
|
651 |
-
outputs=output_layer_names)
|
652 |
-
|
653 |
-
# show the results
|
654 |
-
im_keypose_out = imshow_keypoints(
|
655 |
-
image,
|
656 |
-
pose_results,
|
657 |
-
skeleton=self.skeleton,
|
658 |
-
pose_kpt_color=self.pose_kpt_color,
|
659 |
-
pose_link_color=self.pose_link_color,
|
660 |
-
radius=2,
|
661 |
-
thickness=2)
|
662 |
-
im_keypose_out = im_keypose_out.astype(np.uint8)
|
663 |
-
|
664 |
-
# extract condition features
|
665 |
-
c = self.base_model.get_learned_conditioning([prompt + ', ' + pos_prompt])
|
666 |
-
nc = self.base_model.get_learned_conditioning([neg_prompt])
|
667 |
-
features_adapter_depth = self.model_depth(depth.to(self.device))
|
668 |
-
pose = img2tensor(im_keypose_out, bgr2rgb=True, float32=True) / 255.
|
669 |
-
pose = pose.unsqueeze(0)
|
670 |
-
features_adapter_keypose = self.model_pose(pose.to(self.device))
|
671 |
-
features_adapter = [f_d * w_depth + f_k * w_keypose for f_d, f_k in
|
672 |
-
zip(features_adapter_depth, features_adapter_keypose)]
|
673 |
-
shape = [4, 64, 64]
|
674 |
-
|
675 |
-
# sampling
|
676 |
-
con_strength = int((1 - con_strength) * 50)
|
677 |
-
samples_ddim, _ = self.sampler.sample(S=50,
|
678 |
-
conditioning=c,
|
679 |
-
batch_size=1,
|
680 |
-
shape=shape,
|
681 |
-
verbose=False,
|
682 |
-
unconditional_guidance_scale=scale,
|
683 |
-
unconditional_conditioning=nc,
|
684 |
-
eta=0.0,
|
685 |
-
x_T=None,
|
686 |
-
features_adapter1=features_adapter,
|
687 |
-
mode='sketch',
|
688 |
-
con_strength=con_strength)
|
689 |
-
|
690 |
-
x_samples_ddim = self.base_model.decode_first_stage(samples_ddim)
|
691 |
-
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
692 |
-
x_samples_ddim = x_samples_ddim.to('cpu')
|
693 |
-
x_samples_ddim = x_samples_ddim.permute(0, 2, 3, 1).numpy()[0]
|
694 |
-
x_samples_ddim = 255. * x_samples_ddim
|
695 |
-
x_samples_ddim = x_samples_ddim.astype(np.uint8)
|
696 |
-
|
697 |
-
return [im_depth_out, im_keypose_out[:, :, ::-1], x_samples_ddim]
|
698 |
-
|
699 |
-
@torch.no_grad()
|
700 |
-
def process_seg(self, input_img, type_in, prompt, neg_prompt, pos_prompt, fix_sample, scale,
|
701 |
-
con_strength, base_model):
|
702 |
-
if self.current_base != base_model:
|
703 |
-
ckpt = os.path.join("models", base_model)
|
704 |
-
pl_sd = torch.load(ckpt, map_location="cuda")
|
705 |
-
if "state_dict" in pl_sd:
|
706 |
-
sd = pl_sd["state_dict"]
|
707 |
-
else:
|
708 |
-
sd = pl_sd
|
709 |
-
self.base_model.load_state_dict(sd, strict=False)
|
710 |
-
self.current_base = base_model
|
711 |
-
if 'anything' in base_model.lower():
|
712 |
-
self.load_vae()
|
713 |
-
|
714 |
-
con_strength = int((1 - con_strength) * 50)
|
715 |
-
if fix_sample == 'True':
|
716 |
-
seed_everything(42)
|
717 |
-
im = cv2.resize(input_img, (512, 512))
|
718 |
-
|
719 |
-
if type_in == 'Segmentation':
|
720 |
-
im_seg = im.copy()
|
721 |
-
im = img2tensor(im).unsqueeze(0) / 255.
|
722 |
-
labelmap = im.float()
|
723 |
-
elif type_in == 'Image':
|
724 |
-
im, _ = preprocessing(im, self.device)
|
725 |
-
_, _, H, W = im.shape
|
726 |
-
|
727 |
-
# Image -> Probability map
|
728 |
-
logits = self.model_seger(im)
|
729 |
-
logits = F.interpolate(logits, size=(H, W), mode="bilinear", align_corners=False)
|
730 |
-
probs = F.softmax(logits, dim=1)[0]
|
731 |
-
probs = probs.cpu().data.numpy()
|
732 |
-
labelmap = np.argmax(probs, axis=0)
|
733 |
-
|
734 |
-
labelmap = self.coler(labelmap)
|
735 |
-
labelmap = np.transpose(labelmap, (1, 2, 0))
|
736 |
-
labelmap = cv2.resize(labelmap, (512, 512))
|
737 |
-
labelmap = img2tensor(labelmap, bgr2rgb=False, float32=True) / 255.
|
738 |
-
im_seg = tensor2img(labelmap)[:, :, ::-1]
|
739 |
-
labelmap = labelmap.unsqueeze(0)
|
740 |
-
|
741 |
-
# extract condition features
|
742 |
-
c = self.base_model.get_learned_conditioning([prompt + ', ' + pos_prompt])
|
743 |
-
nc = self.base_model.get_learned_conditioning([neg_prompt])
|
744 |
-
features_adapter = self.model_seg(labelmap.to(self.device))
|
745 |
-
shape = [4, 64, 64]
|
746 |
-
|
747 |
-
# sampling
|
748 |
-
samples_ddim, _ = self.sampler.sample(S=50,
|
749 |
-
conditioning=c,
|
750 |
-
batch_size=1,
|
751 |
-
shape=shape,
|
752 |
-
verbose=False,
|
753 |
-
unconditional_guidance_scale=scale,
|
754 |
-
unconditional_conditioning=nc,
|
755 |
-
eta=0.0,
|
756 |
-
x_T=None,
|
757 |
-
features_adapter1=features_adapter,
|
758 |
-
mode='sketch',
|
759 |
-
con_strength=con_strength)
|
760 |
-
|
761 |
-
x_samples_ddim = self.base_model.decode_first_stage(samples_ddim)
|
762 |
-
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
763 |
-
x_samples_ddim = x_samples_ddim.to('cpu')
|
764 |
-
x_samples_ddim = x_samples_ddim.permute(0, 2, 3, 1).numpy()[0]
|
765 |
-
x_samples_ddim = 255. * x_samples_ddim
|
766 |
-
x_samples_ddim = x_samples_ddim.astype(np.uint8)
|
767 |
-
|
768 |
-
return [im_seg, x_samples_ddim]
|
769 |
-
|
770 |
-
@torch.no_grad()
|
771 |
-
def process_draw(self, input_img, prompt, neg_prompt, pos_prompt, fix_sample, scale, con_strength, base_model):
|
772 |
-
if self.current_base != base_model:
|
773 |
-
ckpt = os.path.join("models", base_model)
|
774 |
-
pl_sd = torch.load(ckpt, map_location="cuda")
|
775 |
-
if "state_dict" in pl_sd:
|
776 |
-
sd = pl_sd["state_dict"]
|
777 |
-
else:
|
778 |
-
sd = pl_sd
|
779 |
-
self.base_model.load_state_dict(sd, strict=False)
|
780 |
-
self.current_base = base_model
|
781 |
-
if 'anything' in base_model.lower():
|
782 |
-
self.load_vae()
|
783 |
-
|
784 |
-
con_strength = int((1 - con_strength) * 50)
|
785 |
-
if fix_sample == 'True':
|
786 |
-
seed_everything(42)
|
787 |
-
input_img = input_img['mask']
|
788 |
-
c = input_img[:, :, 0:3].astype(np.float32)
|
789 |
-
a = input_img[:, :, 3:4].astype(np.float32) / 255.0
|
790 |
-
im = c * a + 255.0 * (1.0 - a)
|
791 |
-
im = im.clip(0, 255).astype(np.uint8)
|
792 |
-
im = cv2.resize(im, (512, 512))
|
793 |
-
|
794 |
-
im_edge = im.copy()
|
795 |
-
im = img2tensor(im)[0].unsqueeze(0).unsqueeze(0) / 255.
|
796 |
-
im = im > 0.5
|
797 |
-
im = im.float()
|
798 |
-
|
799 |
-
# extract condition features
|
800 |
-
c = self.base_model.get_learned_conditioning([prompt + ', ' + pos_prompt])
|
801 |
-
nc = self.base_model.get_learned_conditioning([neg_prompt])
|
802 |
-
features_adapter = self.model_sketch(im.to(self.device))
|
803 |
-
shape = [4, 64, 64]
|
804 |
-
|
805 |
-
# sampling
|
806 |
-
samples_ddim, _ = self.sampler.sample(S=50,
|
807 |
-
conditioning=c,
|
808 |
-
batch_size=1,
|
809 |
-
shape=shape,
|
810 |
-
verbose=False,
|
811 |
-
unconditional_guidance_scale=scale,
|
812 |
-
unconditional_conditioning=nc,
|
813 |
-
eta=0.0,
|
814 |
-
x_T=None,
|
815 |
-
features_adapter1=features_adapter,
|
816 |
-
mode='sketch',
|
817 |
-
con_strength=con_strength)
|
818 |
-
|
819 |
-
x_samples_ddim = self.base_model.decode_first_stage(samples_ddim)
|
820 |
-
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
821 |
-
x_samples_ddim = x_samples_ddim.to('cpu')
|
822 |
-
x_samples_ddim = x_samples_ddim.permute(0, 2, 3, 1).numpy()[0]
|
823 |
-
x_samples_ddim = 255. * x_samples_ddim
|
824 |
-
x_samples_ddim = x_samples_ddim.astype(np.uint8)
|
825 |
-
|
826 |
-
return [im_edge, x_samples_ddim]
|
827 |
-
|
828 |
-
@torch.no_grad()
|
829 |
-
def process_keypose(self, input_img, type_in, prompt, neg_prompt, pos_prompt, fix_sample, scale, con_strength,
|
830 |
-
base_model):
|
831 |
-
if self.current_base != base_model:
|
832 |
-
ckpt = os.path.join("models", base_model)
|
833 |
-
pl_sd = torch.load(ckpt, map_location="cuda")
|
834 |
-
if "state_dict" in pl_sd:
|
835 |
-
sd = pl_sd["state_dict"]
|
836 |
-
else:
|
837 |
-
sd = pl_sd
|
838 |
-
self.base_model.load_state_dict(sd, strict=False)
|
839 |
-
self.current_base = base_model
|
840 |
-
if 'anything' in base_model.lower():
|
841 |
-
self.load_vae()
|
842 |
-
|
843 |
-
con_strength = int((1 - con_strength) * 50)
|
844 |
-
if fix_sample == 'True':
|
845 |
-
seed_everything(42)
|
846 |
-
im = cv2.resize(input_img, (512, 512))
|
847 |
-
|
848 |
-
if type_in == 'Keypose':
|
849 |
-
im_pose = im.copy()[:,:,::-1]
|
850 |
-
elif type_in == 'Image':
|
851 |
-
image = im.copy()
|
852 |
-
im = img2tensor(im).unsqueeze(0) / 255.
|
853 |
-
mmdet_results = inference_detector(self.det_model, image)
|
854 |
-
# keep the person class bounding boxes.
|
855 |
-
person_results = process_mmdet_results(mmdet_results, self.det_cat_id)
|
856 |
-
|
857 |
-
# optional
|
858 |
-
return_heatmap = False
|
859 |
-
dataset = self.pose_model.cfg.data['test']['type']
|
860 |
-
|
861 |
-
# e.g. use ('backbone', ) to return backbone feature
|
862 |
-
output_layer_names = None
|
863 |
-
pose_results, _ = inference_top_down_pose_model(
|
864 |
-
self.pose_model,
|
865 |
-
image,
|
866 |
-
person_results,
|
867 |
-
bbox_thr=self.bbox_thr,
|
868 |
-
format='xyxy',
|
869 |
-
dataset=dataset,
|
870 |
-
dataset_info=None,
|
871 |
-
return_heatmap=return_heatmap,
|
872 |
-
outputs=output_layer_names)
|
873 |
-
|
874 |
-
# show the results
|
875 |
-
im_pose = imshow_keypoints(
|
876 |
-
image,
|
877 |
-
pose_results,
|
878 |
-
skeleton=self.skeleton,
|
879 |
-
pose_kpt_color=self.pose_kpt_color,
|
880 |
-
pose_link_color=self.pose_link_color,
|
881 |
-
radius=2,
|
882 |
-
thickness=2)
|
883 |
-
# im_pose = cv2.resize(im_pose, (512, 512))
|
884 |
-
|
885 |
-
# extract condition features
|
886 |
-
c = self.base_model.get_learned_conditioning([prompt + ', ' + pos_prompt])
|
887 |
-
nc = self.base_model.get_learned_conditioning([neg_prompt])
|
888 |
-
pose = img2tensor(im_pose, bgr2rgb=True, float32=True) / 255.
|
889 |
-
pose = pose.unsqueeze(0)
|
890 |
-
features_adapter = self.model_pose(pose.to(self.device))
|
891 |
-
|
892 |
-
shape = [4, 64, 64]
|
893 |
-
|
894 |
-
# sampling
|
895 |
-
samples_ddim, _ = self.sampler.sample(S=50,
|
896 |
-
conditioning=c,
|
897 |
-
batch_size=1,
|
898 |
-
shape=shape,
|
899 |
-
verbose=False,
|
900 |
-
unconditional_guidance_scale=scale,
|
901 |
-
unconditional_conditioning=nc,
|
902 |
-
eta=0.0,
|
903 |
-
x_T=None,
|
904 |
-
features_adapter1=features_adapter,
|
905 |
-
mode='sketch',
|
906 |
-
con_strength=con_strength)
|
907 |
-
|
908 |
-
x_samples_ddim = self.base_model.decode_first_stage(samples_ddim)
|
909 |
-
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
910 |
-
x_samples_ddim = x_samples_ddim.to('cpu')
|
911 |
-
x_samples_ddim = x_samples_ddim.permute(0, 2, 3, 1).numpy()[0]
|
912 |
-
x_samples_ddim = 255. * x_samples_ddim
|
913 |
-
x_samples_ddim = x_samples_ddim.astype(np.uint8)
|
914 |
-
|
915 |
-
return [im_pose[:, :, ::-1].astype(np.uint8), x_samples_ddim]
|
916 |
-
|
917 |
-
@torch.no_grad()
|
918 |
-
def process_openpose(self, input_img, type_in, prompt, neg_prompt, pos_prompt, fix_sample, scale, con_strength,
|
919 |
-
base_model):
|
920 |
-
if self.current_base != base_model:
|
921 |
-
ckpt = os.path.join("models", base_model)
|
922 |
-
pl_sd = torch.load(ckpt, map_location="cuda")
|
923 |
-
if "state_dict" in pl_sd:
|
924 |
-
sd = pl_sd["state_dict"]
|
925 |
-
else:
|
926 |
-
sd = pl_sd
|
927 |
-
self.base_model.load_state_dict(sd, strict=False)
|
928 |
-
self.current_base = base_model
|
929 |
-
if 'anything' in base_model.lower():
|
930 |
-
self.load_vae()
|
931 |
-
|
932 |
-
con_strength = int((1 - con_strength) * 50)
|
933 |
-
if fix_sample == 'True':
|
934 |
-
seed_everything(42)
|
935 |
-
im = cv2.resize(input_img, (512, 512))
|
936 |
-
|
937 |
-
if type_in == 'Openpose':
|
938 |
-
im_pose = im.copy()[:,:,::-1]
|
939 |
-
elif type_in == 'Image':
|
940 |
-
from ldm.modules.structure_condition.openpose.api import OpenposeInference
|
941 |
-
model = OpenposeInference()
|
942 |
-
keypose = model(im[:,:,::-1])
|
943 |
-
im_pose = keypose.copy()
|
944 |
-
|
945 |
-
# extract condition features
|
946 |
-
c = self.base_model.get_learned_conditioning([prompt + ', ' + pos_prompt])
|
947 |
-
nc = self.base_model.get_learned_conditioning([neg_prompt])
|
948 |
-
pose = img2tensor(im_pose, bgr2rgb=True, float32=True) / 255.
|
949 |
-
pose = pose.unsqueeze(0)
|
950 |
-
features_adapter = self.model_openpose(pose.to(self.device))
|
951 |
-
|
952 |
-
shape = [4, 64, 64]
|
953 |
-
|
954 |
-
# sampling
|
955 |
-
samples_ddim, _ = self.sampler.sample(S=50,
|
956 |
-
conditioning=c,
|
957 |
-
batch_size=1,
|
958 |
-
shape=shape,
|
959 |
-
verbose=False,
|
960 |
-
unconditional_guidance_scale=scale,
|
961 |
-
unconditional_conditioning=nc,
|
962 |
-
eta=0.0,
|
963 |
-
x_T=None,
|
964 |
-
features_adapter1=features_adapter,
|
965 |
-
mode='sketch',
|
966 |
-
con_strength=con_strength)
|
967 |
-
|
968 |
-
x_samples_ddim = self.base_model.decode_first_stage(samples_ddim)
|
969 |
-
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
970 |
-
x_samples_ddim = x_samples_ddim.to('cpu')
|
971 |
-
x_samples_ddim = x_samples_ddim.permute(0, 2, 3, 1).numpy()[0]
|
972 |
-
x_samples_ddim = 255. * x_samples_ddim
|
973 |
-
x_samples_ddim = x_samples_ddim.astype(np.uint8)
|
974 |
-
|
975 |
-
return [im_pose[:, :, ::-1].astype(np.uint8), x_samples_ddim]
|
976 |
-
|
977 |
-
|
978 |
-
if __name__ == '__main__':
|
979 |
-
model = Model_all('cpu')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dist_util.py
ADDED
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py # noqa: E501
|
2 |
+
import functools
|
3 |
+
import os
|
4 |
+
import subprocess
|
5 |
+
import torch
|
6 |
+
import torch.distributed as dist
|
7 |
+
import torch.multiprocessing as mp
|
8 |
+
from torch.nn.parallel import DataParallel, DistributedDataParallel
|
9 |
+
|
10 |
+
|
11 |
+
def init_dist(launcher, backend='nccl', **kwargs):
    """Initialize torch.distributed through the launcher-specific routine.

    Args:
        launcher (str): Either ``'pytorch'`` or ``'slurm'``.
        backend (str): Backend name handed to the launcher-specific
            initializer. Defaults to ``'nccl'``.
        **kwargs: Forwarded unchanged to the launcher-specific initializer.

    Raises:
        ValueError: If ``launcher`` is neither ``'pytorch'`` nor ``'slurm'``.
    """
    # Only pick a multiprocessing start method if none has been fixed yet.
    if mp.get_start_method(allow_none=True) is None:
        mp.set_start_method('spawn')

    if launcher not in ('pytorch', 'slurm'):
        raise ValueError(f'Invalid launcher type: {launcher}')

    if launcher == 'pytorch':
        _init_dist_pytorch(backend, **kwargs)
    else:
        _init_dist_slurm(backend, **kwargs)
|
20 |
+
|
21 |
+
|
22 |
+
def _init_dist_pytorch(backend, **kwargs):
    """Bind this process to a CUDA device and create the process group.

    The device index is the ``RANK`` environment variable modulo the number
    of CUDA devices reported by torch.

    Args:
        backend (str): Backend passed to ``dist.init_process_group``.
        **kwargs: Extra keyword arguments for ``dist.init_process_group``.
    """
    global_rank = int(os.environ['RANK'])
    device_count = torch.cuda.device_count()
    device_index = global_rank % device_count
    torch.cuda.set_device(device_index)
    dist.init_process_group(backend=backend, **kwargs)
|
27 |
+
|
28 |
+
|
29 |
+
def _init_dist_slurm(backend, port=None):
    """Initialize slurm distributed training environment.

    The master port is taken from ``port`` when given, otherwise from the
    ``MASTER_PORT`` environment variable when it is already set, otherwise
    ``29500`` is used.

    Args:
        backend (str): Backend of torch.distributed.
        port (int, optional): Master port. Defaults to None.
    """
    env = os.environ
    proc_id = int(env['SLURM_PROCID'])
    ntasks = int(env['SLURM_NTASKS'])
    node_list = env['SLURM_NODELIST']

    gpu_count = torch.cuda.device_count()
    local_rank = proc_id % gpu_count
    torch.cuda.set_device(local_rank)

    # The first hostname that scontrol prints for the node list is the master.
    master_addr = subprocess.getoutput(f'scontrol show hostname {node_list} | head -n1')

    if port is not None:
        env['MASTER_PORT'] = str(port)
    else:
        # Keep an existing MASTER_PORT; 29500 is torch.distributed default port
        env.setdefault('MASTER_PORT', '29500')

    env['MASTER_ADDR'] = master_addr
    env['WORLD_SIZE'] = str(ntasks)
    env['LOCAL_RANK'] = str(local_rank)
    env['RANK'] = str(proc_id)
    dist.init_process_group(backend=backend)
|
59 |
+
|
60 |
+
|
61 |
+
def get_dist_info():
    """Return ``(rank, world_size)`` of the current process.

    Falls back to ``(0, 1)`` when torch.distributed is unavailable or no
    process group has been initialized.
    """
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank(), dist.get_world_size()
    return 0, 1
|
73 |
+
|
74 |
+
|
75 |
+
def master_only(func):
    """Decorator that runs ``func`` only in the rank-0 process.

    In every other rank the wrapped call does nothing and returns ``None``.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        current_rank = get_dist_info()[0]
        if current_rank != 0:
            return None
        return func(*args, **kwargs)

    return wrapper
|
84 |
+
|
85 |
+
def get_bare_model(net):
    """Return the underlying model of a parallel wrapper.

    ``DataParallel`` and ``DistributedDataParallel`` instances are unwrapped
    via their ``module`` attribute; any other object is returned unchanged.
    """
    is_wrapped = isinstance(net, (DataParallel, DistributedDataParallel))
    return net.module if is_wrapped else net
|
docs/AdapterZoo.md
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Adapter Zoo
|
2 |
+
|
3 |
+
You can download the adapters from <https://huggingface.co/TencentARC/T2I-Adapter/tree/main>
|
4 |
+
|
5 |
+
All the following adapters are trained with Stable Diffusion (SD) V1.4, and they can be directly used on custom models as long as they are fine-tuned from the same text-to-image models, such as Anything-4.0 or models on the <https://civitai.com/>.
|
6 |
+
|
7 |
+
| Adapter Name | Adapter Description | Demos|Model Parameters| Model Storage | |
|
8 |
+
| --- | --- |--- |--- |--- |---|
|
9 |
+
| t2iadapter_color_sd14v1.pth | Spatial color palette → image | [Demos](examples.md#color-adapter-spatial-palette) |18 M | 75 MB | |
|
10 |
+
| t2iadapter_style_sd14v1.pth | Image style → image | [Demos](examples.md#style-adapter)|| 154MB | Preliminary model. Style adapters with finer controls are on the way|
|
11 |
+
| t2iadapter_openpose_sd14v1.pth | Openpose → image| [Demos](examples.md#openpose-adapter) |77 M| 309 MB | |
|
12 |
+
| t2iadapter_canny_sd14v1.pth | Canny edges → image | [Demos](examples.md#canny-adapter-edge)|77 M | 309 MB ||
|
13 |
+
| t2iadapter_sketch_sd14v1.pth | sketch → image ||77 M| 308 MB | |
|
14 |
+
| t2iadapter_keypose_sd14v1.pth | keypose → image || 77 M| 309 MB | mmpose style |
|
15 |
+
| t2iadapter_seg_sd14v1.pth | segmentation → image ||77 M| 309 MB ||
|
16 |
+
| t2iadapter_depth_sd14v1.pth | depth maps → image ||77 M | 309 MB | Not the final model, still under training|
|
docs/FAQ.md
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# FAQ
|
2 |
+
|
3 |
+
- **Q: The openpose adapter (t2iadapter_openpose_sd14v1) outputs gray-scale images.**
|
4 |
+
|
5 |
+
**A:** You can add `colorful` in the prompt to avoid this problem.
|
docs/examples.md
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Demos
|
2 |
+
|
3 |
+
## Style Adapter
|
4 |
+
|
5 |
+
<p align="center">
|
6 |
+
<img src="https://user-images.githubusercontent.com/17445847/222734169-d47789e8-e83c-48c2-80ef-a896c2bafbb0.png" height=450>
|
7 |
+
</p>
|
8 |
+
|
9 |
+
## Color Adapter (Spatial Palette)
|
10 |
+
|
11 |
+
<p align="center">
|
12 |
+
<img src="https://user-images.githubusercontent.com/17445847/222915829-ccfb0366-13a8-484a-9561-627fabd87d29.png" height=450>
|
13 |
+
</p>
|
14 |
+
|
15 |
+
## Openpose Adapter
|
16 |
+
|
17 |
+
<p align="center">
|
18 |
+
<img src="https://user-images.githubusercontent.com/17445847/222733916-dc26a66e-d786-4407-8889-b81804862b1a.png" height=450>
|
19 |
+
</p>
|
20 |
+
|
21 |
+
## Canny Adapter (Edge)
|
22 |
+
|
23 |
+
<p align="center">
|
24 |
+
<img src="https://user-images.githubusercontent.com/17445847/222915813-c8f264bd-1be6-4496-97ff-aec4f6b53788.png" height=450>
|
25 |
+
</p>
|
26 |
+
|
27 |
+
## Multi-adapters
|
28 |
+
<p align="center">
|
29 |
+
<img src="https://user-images.githubusercontent.com/17445847/220939329-379f88b7-444f-4a3a-9de0-8f90605d1d34.png" height=450>
|
30 |
+
</p>
|
31 |
+
|
32 |
+
<div align="center">
|
33 |
+
|
34 |
+
*T2I adapters naturally support using multiple adapters together.*
|
35 |
+
|
36 |
+
</div><br />
|
37 |
+
The testing script usage for this example is similar to the command line given below, except that we replaced the pretrained SD model with Anything 4.5 and Kenshi
|
38 |
+
|
39 |
+
>python test_composable_adapters.py --prompt "1gril, computer desk, best quality, extremely detailed" --neg_prompt "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality" --depth_cond_path examples/depth/desk_depth.png --depth_cond_weight 1.0 --depth_ckpt models/t2iadapter_depth_sd14v1.pth --depth_type_in depth --pose_cond_path examples/keypose/person_keypose.png --pose_cond_weight 1.5 --ckpt models/anything-v4.0-pruned.ckpt --n_sample 4 --max_resolution 524288
|
40 |
+
|
41 |
+
[Image source](https://twitter.com/toyxyz3/status/1628375164781211648)
|
environment.yaml
DELETED
@@ -1,31 +0,0 @@
|
|
1 |
-
name: ldm
|
2 |
-
channels:
|
3 |
-
- pytorch
|
4 |
-
- defaults
|
5 |
-
dependencies:
|
6 |
-
- python=3.8.5
|
7 |
-
- pip=20.3
|
8 |
-
- cudatoolkit=11.3
|
9 |
-
- pytorch=1.11.0
|
10 |
-
- torchvision=0.12.0
|
11 |
-
- numpy=1.19.2
|
12 |
-
- pip:
|
13 |
-
- albumentations==0.4.3
|
14 |
-
- diffusers
|
15 |
-
- opencv-python==4.1.2.30
|
16 |
-
- pudb==2019.2
|
17 |
-
- invisible-watermark
|
18 |
-
- imageio==2.9.0
|
19 |
-
- imageio-ffmpeg==0.4.2
|
20 |
-
- pytorch-lightning==1.4.2
|
21 |
-
- omegaconf==2.1.1
|
22 |
-
- test-tube>=0.7.5
|
23 |
-
- streamlit>=0.73.1
|
24 |
-
- einops==0.3.0
|
25 |
-
- torch-fidelity==0.3.0
|
26 |
-
- transformers==4.19.2
|
27 |
-
- torchmetrics==0.6.0
|
28 |
-
- kornia==0.6
|
29 |
-
- -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
|
30 |
-
- -e git+https://github.com/openai/CLIP.git@main#egg=clip
|
31 |
-
- -e .
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ldm/modules/structure_condition/midas/__init__.py → experiments/README.md
RENAMED
File without changes
|
ldm/data/base.py
DELETED
@@ -1,23 +0,0 @@
|
|
1 |
-
from abc import abstractmethod
|
2 |
-
from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset
|
3 |
-
|
4 |
-
|
5 |
-
class Txt2ImgIterableBaseDataset(IterableDataset):
    '''
    Base interface that lets IterableDatasets for text2img data be chained.

    Subclasses are expected to provide ``__iter__``; the reported length is
    whatever ``num_records`` was given at construction time.
    '''

    def __init__(self, num_records=0, valid_ids=None, size=256):
        super().__init__()
        self.num_records = num_records
        # valid_ids and sample_ids start out as the very same object.
        self.valid_ids = self.sample_ids = valid_ids
        self.size = size

        print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.')

    def __len__(self):
        return self.num_records

    @abstractmethod
    def __iter__(self):
        pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ldm/data/dataset_coco.py
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import cv2
|
3 |
+
import os
|
4 |
+
from basicsr.utils import img2tensor
|
5 |
+
|
6 |
+
|
7 |
+
class dataset_coco_mask_color():
    """Caption / image / color-mask triples read from a COCO-style annotation file.

    Each entry of the JSON ``annotations`` list contributes one sample whose
    file name is the zero-padded 12-digit ``image_id``. Images are looked up
    with a ``.jpg`` extension under ``root_path_im`` and masks with a ``.png``
    extension under ``root_path_mask``.

    NOTE(review): ``image_size`` is accepted but never used; both image and
    mask are always resized to 512x512.
    """

    def __init__(self, path_json, root_path_im, root_path_mask, image_size):
        super(dataset_coco_mask_color, self).__init__()
        with open(path_json, 'r', encoding='utf-8') as fp:
            annotations = json.load(fp)['annotations']
        self.root_path_im = root_path_im
        self.root_path_mask = root_path_mask
        self.files = [{
            'name': "%012d.png" % ann['image_id'],
            'sentence': ann['caption']
        } for ann in annotations]

    @staticmethod
    def _load_rgb_512(path):
        # Read with OpenCV (BGR), resize to 512x512, convert to an RGB float tensor in [0, 1].
        bgr = cv2.resize(cv2.imread(path), (512, 512))
        return img2tensor(bgr, bgr2rgb=True, float32=True) / 255.

    def __getitem__(self, idx):
        entry = self.files[idx]
        mask_name = entry['name']
        image_name = mask_name.replace('.png', '.jpg')

        im = self._load_rgb_512(os.path.join(self.root_path_im, image_name))
        mask = self._load_rgb_512(os.path.join(self.root_path_mask, mask_name))

        return {'im': im, 'mask': mask, 'sentence': entry['sentence']}

    def __len__(self):
        return len(self.files)
|
ldm/data/dataset_depth.py
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import cv2
|
3 |
+
import os
|
4 |
+
from basicsr.utils import img2tensor
|
5 |
+
|
6 |
+
|
7 |
+
class DepthDataset():
    """Image / depth-map / caption triples listed in a plain-text meta file.

    Every non-empty line of ``meta_file`` is an image path. The depth map and
    the caption are expected next to it, with the image extension replaced by
    ``.depth.png`` and ``.txt`` respectively.
    """

    def __init__(self, meta_file):
        super(DepthDataset, self).__init__()

        self.files = []
        with open(meta_file, 'r') as f:
            for line in f:
                img_path = line.strip()
                # Blank lines (e.g. a trailing newline) would otherwise become
                # bogus entries such as '.depth.png' / '.txt'.
                if not img_path:
                    continue
                stem = img_path.rsplit('.', 1)[0]
                self.files.append({
                    'img_path': img_path,
                    'depth_img_path': stem + '.depth.png',
                    'txt_path': stem + '.txt'
                })

    @staticmethod
    def _read_image(path):
        """Load ``path`` as an RGB float tensor scaled to [0, 1].

        Raises:
            OSError: If OpenCV cannot read the file.
        """
        img = cv2.imread(path)
        # cv2.imread signals failure by returning None instead of raising.
        if img is None:
            raise OSError(f'Failed to read image: {path}')
        return img2tensor(img, bgr2rgb=True, float32=True) / 255.

    def __getitem__(self, idx):
        file = self.files[idx]

        im = self._read_image(file['img_path'])
        depth = self._read_image(file['depth_img_path'])

        # Only the first line of the caption file is used.
        with open(file['txt_path'], 'r') as fs:
            sentence = fs.readline().strip()

        return {'im': im, 'depth': depth, 'sentence': sentence}

    def __len__(self):
        return len(self.files)
|
ldm/data/dataset_laion.py
ADDED
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import os
|
5 |
+
import pytorch_lightning as pl
|
6 |
+
import torch
|
7 |
+
import webdataset as wds
|
8 |
+
from torchvision.transforms import transforms
|
9 |
+
|
10 |
+
from ldm.util import instantiate_from_config
|
11 |
+
|
12 |
+
|
13 |
+
def dict_collation_fn(samples, combine_tensors=True, combine_scalars=True):
    """Take a list of samples (as dictionary) and create a batch, preserving the keys.

    Only keys present in every sample are kept; they appear in the order of
    the first sample. Values are grouped by the type of the first sample's
    value: ``int``/``float`` become one ``ndarray``, ``torch.Tensor`` values
    are stacked, ``ndarray`` values become one ``ndarray``, and anything else
    is returned as a plain list. A key whose combining flag is disabled is
    omitted from the result.

    :param list samples: list of sample dictionaries
    :param bool combine_tensors: whether to combine tensor / ndarray values
    :param bool combine_scalars: whether to combine int / float values
    :returns: single sample consisting of a batch; empty dict for no samples
    :rtype: dict
    """
    if not samples:
        return {}

    first = samples[0]
    shared_keys = [key for key in first if all(key in sample for sample in samples)]

    result = {}
    for key in shared_keys:
        values = [sample[key] for sample in samples]
        head = values[0]
        if isinstance(head, (int, float)):
            if combine_scalars:
                result[key] = np.array(values)
        elif isinstance(head, torch.Tensor):
            if combine_tensors:
                result[key] = torch.stack(values)
        elif isinstance(head, np.ndarray):
            if combine_tensors:
                result[key] = np.array(values)
        else:
            result[key] = values
    return result
|
42 |
+
|
43 |
+
|
44 |
+
class WebDataModuleFromConfig(pl.LightningDataModule):
    """Lightning data module that streams batches from webdataset tar shards.

    Args:
        tar_base: Base directory joined with the shard pattern of a dataset config.
        batch_size: Number of samples per batch.
        train: Dataset config used by ``train_dataloader``.
        validation: Stored but not used by any loader in this class.
        test: Stored but not used by any loader in this class.
        num_workers: Worker count passed to ``wds.WebLoader``.
        multinode: Selects ``split_by_node`` (True) or ``single_node_only`` (False).
        min_size: Minimum original width and height; ``None`` disables the size filter.
        max_pwatermark: Maximum allowed ``pwatermark`` value of a sample.
        **kwargs: Accepted and ignored.
    """

    def __init__(self,
                 tar_base,
                 batch_size,
                 train=None,
                 validation=None,
                 test=None,
                 num_workers=4,
                 multinode=True,
                 min_size=None,
                 max_pwatermark=1.0,
                 **kwargs):
        super().__init__()
        print(f'Setting tar base to {tar_base}')
        self.tar_base = tar_base
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.train = train
        self.validation = validation
        self.test = test
        self.multinode = multinode
        self.min_size = min_size  # filter out very small images
        self.max_pwatermark = max_pwatermark  # filter out watermarked images

    def make_loader(self, dataset_config):
        """Build a ``wds.WebLoader`` for ``dataset_config``.

        The config supplies ``image_transforms``, ``process``, ``shards`` and
        optionally ``shuffle`` (default 0).
        """
        image_transforms = [instantiate_from_config(tt) for tt in dataset_config.image_transforms]
        image_transforms = transforms.Compose(image_transforms)

        process = instantiate_from_config(dataset_config['process'])

        shuffle = dataset_config.get('shuffle', 0)
        shardshuffle = shuffle > 0

        nodesplitter = wds.shardlists.split_by_node if self.multinode else wds.shardlists.single_node_only

        tars = os.path.join(self.tar_base, dataset_config.shards)

        dset = wds.WebDataset(
            tars, nodesplitter=nodesplitter, shardshuffle=shardshuffle,
            handler=wds.warn_and_continue).repeat().shuffle(shuffle)
        print(f'Loading webdataset with {len(dset.pipeline[0].urls)} shards.')

        # Drop incomplete samples, decode, drop filtered samples, transform the
        # image, then apply the configured per-sample processing.
        dset = (
            dset.select(self.filter_keys).decode('pil',
                                                 handler=wds.warn_and_continue).select(self.filter_size).map_dict(
                                                     jpg=image_transforms, handler=wds.warn_and_continue).map(process))
        dset = (dset.batched(self.batch_size, partial=False, collation_fn=dict_collation_fn))

        # Batching already happened in the dataset pipeline, hence batch_size=None.
        loader = wds.WebLoader(dset, batch_size=None, shuffle=False, num_workers=self.num_workers)

        return loader

    def filter_size(self, x):
        """Return True when sample ``x`` passes the size and watermark filters.

        With ``min_size`` set, width, height and ``pwatermark`` are all
        checked. With ``min_size`` unset, the watermark limit is still
        enforced when ``max_pwatermark`` is below 1.0; previously it was
        silently ignored in that case. Samples whose metadata is missing or
        malformed are rejected.
        """
        check_size = self.min_size is not None
        check_watermark = check_size or (self.max_pwatermark is not None and self.max_pwatermark < 1.0)
        if not check_watermark:
            return True
        try:
            meta = x['json']
            if check_size and not (meta['original_width'] >= self.min_size
                                   and meta['original_height'] >= self.min_size):
                return False
            return meta['pwatermark'] <= self.max_pwatermark
        except Exception:
            # Best-effort filter: any metadata problem drops the sample.
            return False

    def filter_keys(self, x):
        """Return True when sample ``x`` contains both a ``jpg`` and a ``txt`` entry."""
        try:
            return ("jpg" in x) and ("txt" in x)
        except Exception:
            return False

    def train_dataloader(self):
        return self.make_loader(self.train)

    def val_dataloader(self):
        return None

    def test_dataloader(self):
        return None
|
120 |
+
|
121 |
+
|
122 |
+
if __name__ == '__main__':
    # Manual smoke test: build the training loader from a config file and
    # print the keys and image-batch shape of every batch it yields.
    from omegaconf import OmegaConf

    cfg = OmegaConf.load("configs/stable-diffusion/train_canny_sd_v1.yaml")
    data_module = WebDataModuleFromConfig(**cfg["data"]["params"])

    for batch in data_module.train_dataloader():
        print(batch.keys())
        print(batch['jpg'].shape)
|
ldm/data/dataset_wikiart.py
ADDED
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import os.path
|
3 |
+
|
4 |
+
from PIL import Image
|
5 |
+
from torch.utils.data import DataLoader
|
6 |
+
|
7 |
+
from transformers import CLIPProcessor
|
8 |
+
from torchvision.transforms import transforms
|
9 |
+
|
10 |
+
import pytorch_lightning as pl
|
11 |
+
|
12 |
+
|
13 |
+
class WikiArtDataset():
    """WikiArt images paired with captions derived from their file names.

    ``meta_file`` is a JSON file that iterates over relative image paths. The
    caption is the part of the file name after the last ``_``, split on
    ``-``, with trailing all-digit tokens removed. Paths with no remaining
    caption tokens are skipped.

    Args:
        meta_file: Path to the JSON list of relative image paths.
        img_root: Directory the relative paths are joined to. Defaults to
            ``'datasets/wikiart'``.
    """

    def __init__(self, meta_file, img_root='datasets/wikiart'):
        super(WikiArtDataset, self).__init__()

        self.files = []
        with open(meta_file, 'r') as f:
            js = json.load(f)
        for img_path in js:
            sentence = self._caption_from_path(img_path)
            if sentence is None:
                continue
            self.files.append({'img_path': os.path.join(img_root, img_path), 'sentence': sentence})

        version = 'openai/clip-vit-large-patch14'
        self.processor = CLIPProcessor.from_pretrained(version)

        self.jpg_transform = transforms.Compose([
            transforms.Resize(512),
            transforms.RandomCrop(512),
            transforms.ToTensor(),
        ])

    @staticmethod
    def _caption_from_path(img_path):
        """Return the caption encoded in ``img_path``, or None if nothing is left."""
        img_name = os.path.splitext(os.path.basename(img_path))[0]
        words = img_name.split('_')[-1].split('-')
        # Strip trailing purely numeric tokens.
        while words and words[-1].isdigit():
            words.pop()
        return ' '.join(words) if words else None

    def __getitem__(self, idx):
        file = self.files[idx]

        # Force three channels so grayscale / RGBA / CMYK files yield tensors
        # of the same shape as RGB ones, and close the file handle promptly.
        with Image.open(file['img_path']) as raw:
            im = raw.convert('RGB')

        im_tensor = self.jpg_transform(im)

        clip_im = self.processor(images=im, return_tensors="pt")['pixel_values'][0]

        return {'jpg': im_tensor, 'style': clip_im, 'txt': file['sentence']}

    def __len__(self):
        return len(self.files)
|
56 |
+
|
57 |
+
|
58 |
+
class WikiArtDataModule(pl.LightningDataModule):
    """Lightning data module exposing a shuffled training loader over ``WikiArtDataset``."""

    def __init__(self, meta_file, batch_size, num_workers):
        super(WikiArtDataModule, self).__init__()
        self.train_dataset = WikiArtDataset(meta_file)
        self.batch_size, self.num_workers = batch_size, num_workers

    def train_dataloader(self):
        loader_options = {
            'batch_size': self.batch_size,
            'shuffle': True,
            'num_workers': self.num_workers,
            'pin_memory': True,
        }
        return DataLoader(self.train_dataset, **loader_options)
|
ldm/data/imagenet.py
DELETED
@@ -1,394 +0,0 @@
|
|
1 |
-
import os, yaml, pickle, shutil, tarfile, glob
|
2 |
-
import cv2
|
3 |
-
import albumentations
|
4 |
-
import PIL
|
5 |
-
import numpy as np
|
6 |
-
import torchvision.transforms.functional as TF
|
7 |
-
from omegaconf import OmegaConf
|
8 |
-
from functools import partial
|
9 |
-
from PIL import Image
|
10 |
-
from tqdm import tqdm
|
11 |
-
from torch.utils.data import Dataset, Subset
|
12 |
-
|
13 |
-
import taming.data.utils as tdu
|
14 |
-
from taming.data.imagenet import str_to_indices, give_synsets_from_indices, download, retrieve
|
15 |
-
from taming.data.imagenet import ImagePaths
|
16 |
-
|
17 |
-
from ldm.modules.image_degradation import degradation_fn_bsr, degradation_fn_bsr_light
|
18 |
-
|
19 |
-
|
20 |
-
def synset2idx(path_to_yaml="data/index_synset.yaml"):
    """Load an index-to-synset YAML mapping and return it inverted.

    Args:
        path_to_yaml: Path of the YAML file mapping index to synset.

    Returns:
        dict: Mapping from synset to index.
    """
    with open(path_to_yaml) as f:
        # safe_load: yaml.load() without a Loader is unsafe and is rejected
        # by PyYAML >= 6; the file only holds a plain mapping.
        di2s = yaml.safe_load(f)
    return {v: k for k, v in di2s.items()}
|
24 |
-
|
25 |
-
|
26 |
-
class ImageNetBase(Dataset):
    """Shared loading logic for the ImageNet train / validation datasets.

    Subclasses implement ``_prepare`` and are expected to set ``root``,
    ``datadir``, ``txt_filelist`` and ``random_crop`` there; those attributes
    are read by the methods below.
    """

    def __init__(self, config=None):
        self.config = config or OmegaConf.create()
        if not isinstance(self.config, dict):
            self.config = OmegaConf.to_container(self.config)
        self.keep_orig_class_label = self.config.get("keep_orig_class_label", False)
        # if False we skip loading & processing images and self.data contains filepaths.
        # Subclasses assign this before calling super().__init__(); keep their
        # choice instead of overwriting it with True.
        self.process_images = getattr(self, "process_images", True)
        self._prepare()
        self._prepare_synset_to_human()
        self._prepare_idx_to_synset()
        self._prepare_human_to_integer_label()
        self._load()

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        return self.data[i]

    def _prepare(self):
        raise NotImplementedError()

    def _filter_relpaths(self, relpaths):
        """Drop ignored files and, if ``sub_indices`` is configured, keep only those synsets."""
        ignore = {
            "n06596364_9591.JPEG",
        }
        relpaths = [rpath for rpath in relpaths if rpath.split("/")[-1] not in ignore]
        if "sub_indices" not in self.config:
            return relpaths
        indices = str_to_indices(self.config["sub_indices"])
        synsets = set(give_synsets_from_indices(indices, path_to_yaml=self.idx2syn))  # set of strings
        self.synset2idx = synset2idx(path_to_yaml=self.idx2syn)
        return [rpath for rpath in relpaths if rpath.split("/")[0] in synsets]

    def _prepare_synset_to_human(self):
        """Download the synset-to-human-label file unless present with the expected size."""
        SIZE = 2655750
        URL = "https://heibox.uni-heidelberg.de/f/9f28e956cd304264bb82/?dl=1"
        self.human_dict = os.path.join(self.root, "synset_human.txt")
        if (not os.path.exists(self.human_dict) or
                not os.path.getsize(self.human_dict) == SIZE):
            download(URL, self.human_dict)

    def _prepare_idx_to_synset(self):
        """Download the index-to-synset YAML file if it is missing."""
        URL = "https://heibox.uni-heidelberg.de/f/d835d5b6ceda4d3aa910/?dl=1"
        self.idx2syn = os.path.join(self.root, "index_synset.yaml")
        if (not os.path.exists(self.idx2syn)):
            download(URL, self.idx2syn)

    def _prepare_human_to_integer_label(self):
        """Download (if missing) and parse the ``index:label`` file into ``human2integer_dict``."""
        URL = "https://heibox.uni-heidelberg.de/f/2362b797d5be43b883f6/?dl=1"
        self.human2integer = os.path.join(self.root, "imagenet1000_clsidx_to_labels.txt")
        if (not os.path.exists(self.human2integer)):
            download(URL, self.human2integer)
        with open(self.human2integer, "r") as f:
            lines = f.read().splitlines()
            assert len(lines) == 1000
            self.human2integer_dict = dict()
            for line in lines:
                value, key = line.split(":")
                self.human2integer_dict[key] = int(value)

    def _load(self):
        """Read the file list, derive labels and build ``self.data``."""
        with open(self.txt_filelist, "r") as f:
            self.relpaths = f.read().splitlines()
            l1 = len(self.relpaths)
            self.relpaths = self._filter_relpaths(self.relpaths)
            print("Removed {} files from filelist during filtering.".format(l1 - len(self.relpaths)))

        # The first path component of every relative path is its synset.
        self.synsets = [p.split("/")[0] for p in self.relpaths]
        self.abspaths = [os.path.join(self.datadir, p) for p in self.relpaths]

        unique_synsets = np.unique(self.synsets)
        class_dict = dict((synset, i) for i, synset in enumerate(unique_synsets))
        if not self.keep_orig_class_label:
            self.class_labels = [class_dict[s] for s in self.synsets]
        else:
            # The mapping is only populated by _filter_relpaths when
            # "sub_indices" is configured; load it here otherwise.
            if not hasattr(self, "synset2idx"):
                self.synset2idx = synset2idx(path_to_yaml=self.idx2syn)
            self.class_labels = [self.synset2idx[s] for s in self.synsets]

        with open(self.human_dict, "r") as f:
            human_dict = f.read().splitlines()
            human_dict = dict(line.split(maxsplit=1) for line in human_dict)

        self.human_labels = [human_dict[s] for s in self.synsets]

        labels = {
            "relpath": np.array(self.relpaths),
            "synsets": np.array(self.synsets),
            "class_label": np.array(self.class_labels),
            "human_label": np.array(self.human_labels),
        }

        if self.process_images:
            self.size = retrieve(self.config, "size", default=256)
            self.data = ImagePaths(self.abspaths,
                                   labels=labels,
                                   size=self.size,
                                   random_crop=self.random_crop,
                                   )
        else:
            self.data = self.abspaths
|
132 |
-
|
133 |
-
|
134 |
-
class ImageNetTrain(ImageNetBase):
    """ILSVRC2012 training split; fetched and unpacked on first use."""

    NAME = "ILSVRC2012_train"
    URL = "http://www.image-net.org/challenges/LSVRC/2012/"
    AT_HASH = "a306397ccf9c2ead27155983c254227c0fd938e2"
    FILES = [
        "ILSVRC2012_img_train.tar",
    ]
    SIZES = [
        147897477120,
    ]

    def __init__(self, process_images=True, data_root=None, **kwargs):
        self.process_images = process_images
        self.data_root = data_root
        super().__init__(**kwargs)

    def _prepare(self):
        """Resolve dataset paths and, unless already prepared, extract the archive and write the file list."""
        if self.data_root:
            self.root = os.path.join(self.data_root, self.NAME)
        else:
            cache_home = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
            self.root = os.path.join(cache_home, "autoencoders/data", self.NAME)

        self.datadir = os.path.join(self.root, "data")
        self.txt_filelist = os.path.join(self.root, "filelist.txt")
        self.expected_length = 1281167
        self.random_crop = retrieve(self.config, "ImageNetTrain/random_crop",
                                    default=True)
        if tdu.is_prepared(self.root):
            return

        # prep
        print("Preparing dataset {} in {}".format(self.NAME, self.root))

        datadir = self.datadir
        if not os.path.exists(datadir):
            archive = os.path.join(self.root, self.FILES[0])
            if not os.path.exists(archive) or os.path.getsize(archive) != self.SIZES[0]:
                import academictorrents as at
                atpath = at.get(self.AT_HASH, datastore=self.root)
                assert atpath == archive

            print("Extracting {} to {}".format(archive, datadir))
            os.makedirs(datadir, exist_ok=True)
            with tarfile.open(archive, "r:") as tar:
                tar.extractall(path=datadir)

            # The outer archive holds further tar files; each is unpacked
            # into a directory named after it.
            print("Extracting sub-tars.")
            for subpath in tqdm(sorted(glob.glob(os.path.join(datadir, "*.tar")))):
                subdir = subpath[:-len(".tar")]
                os.makedirs(subdir, exist_ok=True)
                with tarfile.open(subpath, "r:") as tar:
                    tar.extractall(path=subdir)

        found = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
        relative = sorted(os.path.relpath(p, start=datadir) for p in found)
        with open(self.txt_filelist, "w") as f:
            f.write("\n".join(relative) + "\n")

        tdu.mark_prepared(self.root)
|
195 |
-
|
196 |
-
|
197 |
-
class ImageNetValidation(ImageNetBase):
    """ILSVRC2012 validation split; fetched, unpacked and sorted into synset folders on first use."""

    NAME = "ILSVRC2012_validation"
    URL = "http://www.image-net.org/challenges/LSVRC/2012/"
    AT_HASH = "5d6d0df7ed81efd49ca99ea4737e0ae5e3a5f2e5"
    VS_URL = "https://heibox.uni-heidelberg.de/f/3e0f6e9c624e45f2bd73/?dl=1"
    FILES = [
        "ILSVRC2012_img_val.tar",
        "validation_synset.txt",
    ]
    SIZES = [
        6744924160,
        1950000,
    ]

    def __init__(self, process_images=True, data_root=None, **kwargs):
        self.data_root = data_root
        self.process_images = process_images
        super().__init__(**kwargs)

    def _prepare(self):
        """Resolve dataset paths and, unless already prepared, extract, reorganize and index the images."""
        if self.data_root:
            self.root = os.path.join(self.data_root, self.NAME)
        else:
            cache_home = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
            self.root = os.path.join(cache_home, "autoencoders/data", self.NAME)
        self.datadir = os.path.join(self.root, "data")
        self.txt_filelist = os.path.join(self.root, "filelist.txt")
        self.expected_length = 50000
        self.random_crop = retrieve(self.config, "ImageNetValidation/random_crop",
                                    default=False)
        if tdu.is_prepared(self.root):
            return

        # prep
        print("Preparing dataset {} in {}".format(self.NAME, self.root))

        datadir = self.datadir
        if not os.path.exists(datadir):
            archive = os.path.join(self.root, self.FILES[0])
            if not os.path.exists(archive) or os.path.getsize(archive) != self.SIZES[0]:
                import academictorrents as at
                atpath = at.get(self.AT_HASH, datastore=self.root)
                assert atpath == archive

            print("Extracting {} to {}".format(archive, datadir))
            os.makedirs(datadir, exist_ok=True)
            with tarfile.open(archive, "r:") as tar:
                tar.extractall(path=datadir)

            vspath = os.path.join(self.root, self.FILES[1])
            if not os.path.exists(vspath) or os.path.getsize(vspath) != self.SIZES[1]:
                download(self.VS_URL, vspath)

            # Each line of the synset file is "<file name> <synset>".
            with open(vspath, "r") as f:
                synset_dict = dict(line.split() for line in f.read().splitlines())

            print("Reorganizing into synset folders")
            for synset in np.unique(list(synset_dict.values())):
                os.makedirs(os.path.join(datadir, synset), exist_ok=True)
            for file_name, synset in synset_dict.items():
                shutil.move(os.path.join(datadir, file_name), os.path.join(datadir, synset))

        found = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
        relative = sorted(os.path.relpath(p, start=datadir) for p in found)
        with open(self.txt_filelist, "w") as f:
            f.write("\n".join(relative) + "\n")

        tdu.mark_prepared(self.root)
|
269 |
-
|
270 |
-
|
271 |
-
|
272 |
-
class ImageNetSR(Dataset):
    def __init__(self, size=None,
                 degradation=None, downscale_f=4, min_crop_f=0.5, max_crop_f=1.,
                 random_crop=True):
        """
        ImageNet super-resolution dataset producing (image, LR_image) pairs.

        Performs the following ops in order for every item:
        1.  crops a square of side s from the image (random or center crop)
        2.  resizes the crop so its smallest side is ``size`` (cv2.INTER_AREA)
        3.  degrades the resized crop with the selected degradation process

        :param size: target size after cropping; required, and ``size / downscale_f``
            must be a whole number
        :param degradation: ``"bsrgan"``, ``"bsrgan_light"``, or one of the
            ``cv_*`` / ``pil_*`` interpolation names listed in the table below.
            NOTE(review): the default ``None`` is not a key of that table, so it
            raises ``KeyError`` - callers apparently must always pass a value.
        :param downscale_f: low-resolution downsample factor
        :param min_crop_f: lower bound of the crop factor c, where
            s = c * min_img_side_len and c is drawn uniformly from
            (min_crop_f, max_crop_f)
        :param max_crop_f: upper bound of the crop factor c; must be <= 1
        :param random_crop: use a random crop when True, a center crop otherwise
        """
        # get_base() is not defined in this class; the subclasses in this file provide it.
        self.base = self.get_base()
        assert size
        assert (size / downscale_f).is_integer()
        self.size = size
        self.LR_size = int(size / downscale_f)
        self.min_crop_f = min_crop_f
        self.max_crop_f = max_crop_f
        assert(max_crop_f <= 1.)
        self.center_crop = not random_crop

        self.image_rescaler = albumentations.SmallestMaxSize(max_size=size, interpolation=cv2.INTER_AREA)

        # Set to True further down when a "pil_*" degradation is selected.
        self.pil_interpolation = False

        if degradation == "bsrgan":
            self.degradation_process = partial(degradation_fn_bsr, sf=downscale_f)

        elif degradation == "bsrgan_light":
            self.degradation_process = partial(degradation_fn_bsr_light, sf=downscale_f)

        else:
            # Plain resampling: map the name to the cv2 / PIL interpolation constant.
            interpolation_fn = {
                "cv_nearest": cv2.INTER_NEAREST,
                "cv_bilinear": cv2.INTER_LINEAR,
                "cv_bicubic": cv2.INTER_CUBIC,
                "cv_area": cv2.INTER_AREA,
                "cv_lanczos": cv2.INTER_LANCZOS4,
                "pil_nearest": PIL.Image.NEAREST,
                "pil_bilinear": PIL.Image.BILINEAR,
                "pil_bicubic": PIL.Image.BICUBIC,
                "pil_box": PIL.Image.BOX,
                "pil_hamming": PIL.Image.HAMMING,
                "pil_lanczos": PIL.Image.LANCZOS,
            }[degradation]

            self.pil_interpolation = degradation.startswith("pil_")

            if self.pil_interpolation:
                # PIL path: the process is called with a PIL image in __getitem__.
                self.degradation_process = partial(TF.resize, size=self.LR_size, interpolation=interpolation_fn)

            else:
                # cv2 path: albumentations transform called with image=<ndarray>.
                self.degradation_process = albumentations.SmallestMaxSize(max_size=self.LR_size,
                                                                          interpolation=interpolation_fn)

    def __len__(self):
        """Return the number of items in the base dataset."""
        return len(self.base)

    def __getitem__(self, i):
        """Load item ``i`` and add ``"image"`` and ``"LR_image"`` float32 arrays to it."""
        example = self.base[i]
        image = Image.open(example["file_path_"])

        if not image.mode == "RGB":
            image = image.convert("RGB")

        image = np.array(image).astype(np.uint8)

        # Crop side is a random fraction of the shorter image side.
        min_side_len = min(image.shape[:2])
        crop_side_len = min_side_len * np.random.uniform(self.min_crop_f, self.max_crop_f, size=None)
        crop_side_len = int(crop_side_len)

        # The cropper is rebuilt for every item (the crop size changes) and stored on self.
        if self.center_crop:
            self.cropper = albumentations.CenterCrop(height=crop_side_len, width=crop_side_len)

        else:
            self.cropper = albumentations.RandomCrop(height=crop_side_len, width=crop_side_len)

        image = self.cropper(image=image)["image"]
        image = self.image_rescaler(image=image)["image"]

        if self.pil_interpolation:
            image_pil = PIL.Image.fromarray(image)
            LR_image = self.degradation_process(image_pil)
            LR_image = np.array(LR_image).astype(np.uint8)

        else:
            LR_image = self.degradation_process(image=image)["image"]

        # Map the 0..255 pixel range to -1..1.
        example["image"] = (image/127.5 - 1.0).astype(np.float32)
        example["LR_image"] = (LR_image/127.5 - 1.0).astype(np.float32)

        return example
|
373 |
-
|
374 |
-
|
375 |
-
class ImageNetSRTrain(ImageNetSR):
    """ImageNetSR variant whose base data is an index subset of ``ImageNetTrain``."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def get_base(self):
        """Return ``ImageNetTrain`` restricted to the indices stored in the pickle file."""
        # NOTE(review): pickle.load can execute arbitrary code; only use a trusted index file.
        with open("data/imagenet_train_hr_indices.p", "rb") as f:
            indices = pickle.load(f)
        dset = ImageNetTrain(process_images=False,)
        return Subset(dset, indices)
|
384 |
-
|
385 |
-
|
386 |
-
class ImageNetSRValidation(ImageNetSR):
    """ImageNetSR variant whose base data is an index subset of ``ImageNetValidation``."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def get_base(self):
        """Return ``ImageNetValidation`` restricted to the indices stored in the pickle file."""
        # NOTE(review): pickle.load can execute arbitrary code; only use a trusted index file.
        with open("data/imagenet_val_hr_indices.p", "rb") as f:
            indices = pickle.load(f)
        dset = ImageNetValidation(process_images=False,)
        return Subset(dset, indices)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ldm/data/lsun.py
DELETED
@@ -1,92 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
import numpy as np
|
3 |
-
import PIL
|
4 |
-
from PIL import Image
|
5 |
-
from torch.utils.data import Dataset
|
6 |
-
from torchvision import transforms
|
7 |
-
|
8 |
-
|
9 |
-
class LSUNBase(Dataset):
    """Image dataset driven by a text file that lists one relative image path per line.

    Each item is a dict with the relative path, the full path and an ``"image"``
    float32 array: the RGB image is center-cropped to a square, optionally
    resized to ``size`` x ``size``, randomly flipped horizontally, and scaled
    from 0..255 to -1..1.
    """

    def __init__(self,
                 txt_file,
                 data_root,
                 size=None,
                 interpolation="bicubic",
                 flip_p=0.5
                 ):
        """
        :param txt_file: path of the text file listing image paths relative to ``data_root``
        :param data_root: directory the listed paths are joined onto
        :param size: side length to resize the square crop to; ``None`` keeps the crop size
        :param interpolation: one of ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"lanczos"``
        :param flip_p: probability of a horizontal flip
        :raises KeyError: if ``interpolation`` is not one of the names above
        """
        self.data_paths = txt_file
        self.data_root = data_root
        with open(self.data_paths, "r") as f:
            self.image_paths = f.read().splitlines()
        self._length = len(self.image_paths)
        self.labels = {
            "relative_file_path_": list(self.image_paths),
            "file_path_": [os.path.join(self.data_root, rel_path)
                           for rel_path in self.image_paths],
        }

        self.size = size
        # "linear" used to map to PIL.Image.LINEAR, an alias of BILINEAR that was
        # removed in Pillow 10. Because the whole table is built eagerly, the old
        # reference broke construction for every interpolation mode on new Pillow.
        self.interpolation = {"linear": PIL.Image.BILINEAR,
                              "bilinear": PIL.Image.BILINEAR,
                              "bicubic": PIL.Image.BICUBIC,
                              "lanczos": PIL.Image.LANCZOS,
                              }[interpolation]
        self.flip = transforms.RandomHorizontalFlip(p=flip_p)

    def __len__(self):
        """Return the number of listed images."""
        return self._length

    def __getitem__(self, i):
        """Load, crop, resize, flip and normalise image ``i``; return the example dict."""
        example = dict((k, self.labels[k][i]) for k in self.labels)
        image = Image.open(example["file_path_"])
        if not image.mode == "RGB":
            image = image.convert("RGB")

        # default to score-sde preprocessing
        img = np.array(image).astype(np.uint8)
        h, w = img.shape[0], img.shape[1]
        crop = min(h, w)
        # Center crop to the largest square that fits.
        img = img[(h - crop) // 2:(h + crop) // 2,
                  (w - crop) // 2:(w + crop) // 2]

        image = Image.fromarray(img)
        if self.size is not None:
            image = image.resize((self.size, self.size), resample=self.interpolation)

        image = self.flip(image)
        image = np.array(image).astype(np.uint8)
        # Map the 0..255 pixel range to -1..1.
        example["image"] = (image / 127.5 - 1.0).astype(np.float32)
        return example
|
60 |
-
|
61 |
-
|
62 |
-
class LSUNChurchesTrain(LSUNBase):
    """LSUNBase bound to the outdoor-church training file list and image directory."""

    def __init__(self, **kwargs):
        super().__init__(txt_file="data/lsun/church_outdoor_train.txt", data_root="data/lsun/churches", **kwargs)
|
65 |
-
|
66 |
-
|
67 |
-
class LSUNChurchesValidation(LSUNBase):
    """LSUNBase bound to the outdoor-church validation file list; flipping is off by default."""

    def __init__(self, flip_p=0., **kwargs):
        super().__init__(txt_file="data/lsun/church_outdoor_val.txt", data_root="data/lsun/churches",
                         flip_p=flip_p, **kwargs)
|
71 |
-
|
72 |
-
|
73 |
-
class LSUNBedroomsTrain(LSUNBase):
    """LSUNBase bound to the bedrooms training file list and image directory."""

    def __init__(self, **kwargs):
        super().__init__(txt_file="data/lsun/bedrooms_train.txt", data_root="data/lsun/bedrooms", **kwargs)
|
76 |
-
|
77 |
-
|
78 |
-
class LSUNBedroomsValidation(LSUNBase):
    """LSUNBase bound to the bedrooms validation file list; flipping is off by default."""

    def __init__(self, flip_p=0.0, **kwargs):
        super().__init__(txt_file="data/lsun/bedrooms_val.txt", data_root="data/lsun/bedrooms",
                         flip_p=flip_p, **kwargs)
|
82 |
-
|
83 |
-
|
84 |
-
class LSUNCatsTrain(LSUNBase):
    """LSUNBase bound to the cats training file list and image directory."""

    def __init__(self, **kwargs):
        super().__init__(txt_file="data/lsun/cat_train.txt", data_root="data/lsun/cats", **kwargs)
|
87 |
-
|
88 |
-
|
89 |
-
class LSUNCatsValidation(LSUNBase):
    """LSUNBase bound to the cats validation file list; flipping is off by default."""

    def __init__(self, flip_p=0., **kwargs):
        super().__init__(txt_file="data/lsun/cat_val.txt", data_root="data/lsun/cats",
                         flip_p=flip_p, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ldm/data/utils.py
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
|
3 |
+
import cv2
|
4 |
+
import numpy as np
|
5 |
+
from torchvision.transforms import transforms
|
6 |
+
from torchvision.transforms.functional import to_tensor
|
7 |
+
from transformers import CLIPProcessor
|
8 |
+
|
9 |
+
from basicsr.utils import img2tensor
|
10 |
+
|
11 |
+
|
12 |
+
class AddCannyFreezeThreshold(object):
    """Sample transform that adds a Canny edge map computed with fixed thresholds.

    Reads the PIL image stored under ``sample['jpg']``, writes the edge map
    (divided by 255) to ``sample['canny']`` and replaces ``sample['jpg']``
    with its tensor form.
    """

    def __init__(self, low_threshold=100, high_threshold=200):
        self.low_threshold, self.high_threshold = low_threshold, high_threshold

    def __call__(self, sample):
        # sample['jpg'] is a PIL image
        pil_image = sample['jpg']
        bgr_image = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)
        edges = cv2.Canny(bgr_image, self.low_threshold, self.high_threshold)
        # Add a trailing channel axis before converting to a tensor.
        edge_map = edges[..., None]
        sample['canny'] = img2tensor(edge_map, bgr2rgb=True, float32=True) / 255.
        sample['jpg'] = to_tensor(pil_image)
        return sample
|
26 |
+
|
27 |
+
|
28 |
+
class AddStyle(object):
    """Sample transform that adds CLIP-preprocessed pixel values as ``sample['style']``.

    Reads the PIL image stored under ``sample['jpg']`` and replaces it with
    its tensor form after the style input has been computed.
    """

    def __init__(self, version):
        # ``version`` is passed straight to CLIPProcessor.from_pretrained.
        self.processor = CLIPProcessor.from_pretrained(version)
        self.pil_to_tensor = transforms.ToTensor()

    def __call__(self, sample):
        # sample['jpg'] is a PIL image
        pil_image = sample['jpg']
        processed = self.processor(images=pil_image, return_tensors="pt")
        # Keep the first (only) entry of the processed batch.
        sample['style'] = processed['pixel_values'][0]
        sample['jpg'] = to_tensor(pil_image)
        return sample
|
ldm/inference_base.py
ADDED
@@ -0,0 +1,282 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import torch
|
3 |
+
from omegaconf import OmegaConf
|
4 |
+
|
5 |
+
from ldm.models.diffusion.ddim import DDIMSampler
|
6 |
+
from ldm.models.diffusion.plms import PLMSSampler
|
7 |
+
from ldm.modules.encoders.adapter import Adapter, StyleAdapter, Adapter_light
|
8 |
+
from ldm.modules.extra_condition.api import ExtraCondition
|
9 |
+
from ldm.util import fix_cond_shapes, load_model_from_config, read_state_dict
|
10 |
+
|
11 |
+
# Negative prompt used when the caller does not supply --neg_prompt.
DEFAULT_NEGATIVE_PROMPT = (
    'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, '
    'fewer digits, cropped, worst quality, low quality'
)
|
13 |
+
|
14 |
+
|
15 |
+
def get_base_argument_parser() -> argparse.ArgumentParser:
    """Build the argument parser shared by the inference scripts.

    Returns:
        An ``argparse.ArgumentParser`` with the common output, prompt,
        condition, sampler, checkpoint, resolution and sampling options
        registered. Callers may add script-specific options on top.
    """
    parser = argparse.ArgumentParser()
    add = parser.add_argument

    # --- output and prompts ---
    add('--outdir', type=str, default=None, help='dir to write results to')
    add('--prompt', type=str, nargs='?', default=None, help='positive prompt')
    add('--neg_prompt', type=str, default=DEFAULT_NEGATIVE_PROMPT, help='negative prompt')

    # --- condition input ---
    add('--cond_path', type=str, default=None, help='condition image path')
    add(
        '--cond_inp_type',
        type=str,
        default='image',
        help='the type of the input condition image, take depth T2I as example, the input can be raw image, '
        'which depth will be calculated, or the input can be a directly a depth map image',
    )

    # --- sampler ---
    add(
        '--sampler',
        type=str,
        default='ddim',
        choices=['ddim', 'plms'],
        help='sampling algorithm, currently, only ddim and plms are supported, more are on the way',
    )
    add('--steps', type=int, default=50, help='number of sampling steps')

    # --- checkpoints and model config ---
    add(
        '--sd_ckpt',
        type=str,
        default='models/sd-v1-4.ckpt',
        help='path to checkpoint of stable diffusion model, both .ckpt and .safetensor are supported',
    )
    add(
        '--vae_ckpt',
        type=str,
        default=None,
        help='vae checkpoint, anime SD models usually have seperate vae ckpt that need to be loaded',
    )
    add('--adapter_ckpt', type=str, default=None, help='path to checkpoint of adapter')
    add(
        '--config',
        type=str,
        default='configs/stable-diffusion/sd-v1-inference.yaml',
        help='path to config which constructs SD model',
    )

    # --- resolution and latent geometry ---
    add(
        '--max_resolution',
        type=float,
        default=512 * 512,
        help='max image height * width, only for computer with limited vram',
    )
    add(
        '--resize_short_edge',
        type=int,
        default=None,
        help='resize short edge of the input image, if this arg is set, max_resolution will not be used',
    )
    add('--C', type=int, default=4, help='latent channels')
    add('--f', type=int, default=8, help='downsampling factor')

    # --- guidance and adapter strength ---
    add(
        '--scale',
        type=float,
        default=7.5,
        help='unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))',
    )
    add(
        '--cond_tau',
        type=float,
        default=1.0,
        help='timestamp parameter that determines until which step the adapter is applied, '
        'similar as Prompt-to-Prompt tau',
    )
    add(
        '--cond_weight',
        type=float,
        default=1.0,
        help='the adapter features are multiplied by the cond_weight. The larger the cond_weight, the more aligned '
        'the generated image and condition will be, but the generated quality may be reduced',
    )

    # --- reproducibility and batch size ---
    add('--seed', type=int, default=42)
    add('--n_samples', type=int, default=4, help='# of samples to generate')

    return parser
|
162 |
+
|
163 |
+
|
164 |
+
def get_sd_models(opt):
    """Build the stable diffusion model and its sampler.

    Reads ``opt.config``, ``opt.sd_ckpt``, ``opt.vae_ckpt``, ``opt.device``
    and ``opt.sampler``.

    Returns:
        ``(sd_model, sampler)``: the model moved to ``opt.device`` and a
        PLMS or DDIM sampler built around it.

    Raises:
        NotImplementedError: if ``opt.sampler`` is neither ``'plms'`` nor ``'ddim'``.
    """
    # SD
    sd_config = OmegaConf.load(f"{opt.config}")
    model = load_model_from_config(sd_config, opt.sd_ckpt, opt.vae_ckpt)
    sd_model = model.to(opt.device)

    # sampler
    if opt.sampler not in ('plms', 'ddim'):
        raise NotImplementedError
    sampler_cls = PLMSSampler if opt.sampler == 'plms' else DDIMSampler
    sampler = sampler_cls(model)

    return sd_model, sampler
|
182 |
+
|
183 |
+
|
184 |
+
def get_t2i_adapter_models(opt):
    """Build the stable diffusion model with adapter weights loaded, plus a sampler.

    The adapter checkpoint path is taken from ``opt.<which_cond>_adapter_ckpt``
    when that attribute exists and is not ``None``, otherwise from
    ``opt.adapter_ckpt``. Checkpoint keys that lack the ``adapter.`` prefix
    get it prepended before a non-strict load into the model.

    Returns:
        ``(model, sampler)``: the model on ``opt.device`` and a PLMS or DDIM sampler.

    Raises:
        NotImplementedError: if ``opt.sampler`` is neither ``'plms'`` nor ``'ddim'``.
    """
    sd_config = OmegaConf.load(f"{opt.config}")
    model = load_model_from_config(sd_config, opt.sd_ckpt, opt.vae_ckpt)

    adapter_ckpt_path = getattr(opt, f'{opt.which_cond}_adapter_ckpt', None)
    if adapter_ckpt_path is None:
        adapter_ckpt_path = opt.adapter_ckpt

    raw_weights = read_state_dict(adapter_ckpt_path)
    prefixed_weights = {
        (name if name.startswith('adapter.') else f'adapter.{name}'): tensor
        for name, tensor in raw_weights.items()
    }
    _missing, unexpected = model.load_state_dict(prefixed_weights, strict=False)
    if len(unexpected) > 0:
        print(f"unexpected keys in loading adapter ckpt {adapter_ckpt_path}:")
        print(unexpected)

    model = model.to(opt.device)

    # sampler
    if opt.sampler not in ('plms', 'ddim'):
        raise NotImplementedError
    sampler_cls = PLMSSampler if opt.sampler == 'plms' else DDIMSampler
    sampler = sampler_cls(model)

    return model, sampler
|
213 |
+
|
214 |
+
|
215 |
+
def get_cond_ch(cond_type: ExtraCondition):
    """Return the channel count for a condition: 1 for sketch and canny, 3 otherwise."""
    single_channel_conditions = (ExtraCondition.sketch, ExtraCondition.canny)
    return 1 if cond_type in single_channel_conditions else 3
|
219 |
+
|
220 |
+
|
221 |
+
def get_adapters(opt, cond_type: ExtraCondition):
    """Build the adapter for ``cond_type`` and load its checkpoint.

    The weight is read from ``opt.<cond>_weight`` and the checkpoint path from
    ``opt.<cond>_adapter_ckpt`` (``<cond>`` being ``cond_type.name``). Each
    falls back to ``opt.cond_weight`` / ``opt.adapter_ckpt`` when the specific
    attribute is missing or ``None``.

    Returns:
        dict with ``'cond_weight'`` and ``'model'`` (the adapter module on
        ``opt.device`` with its weights loaded).
    """
    adapter = {}
    cond_weight = getattr(opt, f'{cond_type.name}_weight', None)
    if cond_weight is None:
        cond_weight = getattr(opt, 'cond_weight')
    adapter['cond_weight'] = cond_weight

    if cond_type == ExtraCondition.style:
        adapter['model'] = StyleAdapter(width=1024, context_dim=768, num_head=8, n_layes=3, num_token=8).to(opt.device)
    elif cond_type == ExtraCondition.color:
        adapter['model'] = Adapter_light(
            cin=64 * get_cond_ch(cond_type),
            channels=[320, 640, 1280, 1280],
            nums_rb=4).to(opt.device)
    else:
        adapter['model'] = Adapter(
            cin=64 * get_cond_ch(cond_type),
            channels=[320, 640, 1280, 1280],
            nums_rb=2,
            ksize=1,
            sk=True,
            use_conv=False).to(opt.device)

    ckpt_path = getattr(opt, f'{cond_type.name}_adapter_ckpt', None)
    if ckpt_path is None:
        ckpt_path = getattr(opt, 'adapter_ckpt')
    # Deserialize onto the CPU so a checkpoint saved from a GPU still loads on
    # a CPU-only host; load_state_dict then copies the values into the
    # parameters that already live on opt.device.
    state_dict = torch.load(ckpt_path, map_location='cpu')
    adapter['model'].load_state_dict(state_dict)

    return adapter
|
249 |
+
|
250 |
+
|
251 |
+
def diffusion_inference(opt, model, sampler, adapter_features, append_to_context=None):
    """Sample one image from ``model`` with ``sampler`` under adapter guidance.

    Reads ``opt.prompt``, ``opt.neg_prompt``, ``opt.scale``, ``opt.steps``,
    ``opt.C``, ``opt.f`` and ``opt.cond_tau``. If ``opt`` has no ``H``
    attribute, ``opt.H`` and ``opt.W`` are both set to 512 (``opt`` is mutated).

    Returns:
        The decoded samples, rescaled from [-1, 1] to [0, 1] and clamped.
    """
    # Text embeddings; the negative prompt is only encoded when guidance is active.
    cond = model.get_learned_conditioning([opt.prompt])
    uncond = model.get_learned_conditioning([opt.neg_prompt]) if opt.scale != 1.0 else None
    cond, uncond = fix_cond_shapes(model, cond, uncond)

    if not hasattr(opt, 'H'):
        opt.H = 512
        opt.W = 512
    latent_shape = [opt.C, opt.H // opt.f, opt.W // opt.f]

    sample_kwargs = dict(
        S=opt.steps,
        conditioning=cond,
        batch_size=1,
        shape=latent_shape,
        verbose=False,
        unconditional_guidance_scale=opt.scale,
        unconditional_conditioning=uncond,
        x_T=None,
        features_adapter=adapter_features,
        append_to_context=append_to_context,
        cond_tau=opt.cond_tau,
    )
    samples_latents, _ = sampler.sample(**sample_kwargs)

    decoded = model.decode_first_stage(samples_latents)
    return torch.clamp((decoded + 1.0) / 2.0, min=0.0, max=1.0)
|
ldm/models/autoencoder.py
CHANGED
@@ -1,64 +1,65 @@
|
|
1 |
import torch
|
2 |
import pytorch_lightning as pl
|
3 |
import torch.nn.functional as F
|
|
|
4 |
from contextlib import contextmanager
|
5 |
|
6 |
-
from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
|
7 |
-
|
8 |
from ldm.modules.diffusionmodules.model import Encoder, Decoder
|
9 |
from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
|
10 |
|
11 |
from ldm.util import instantiate_from_config
|
|
|
12 |
|
13 |
|
14 |
-
class
|
15 |
def __init__(self,
|
16 |
ddconfig,
|
17 |
lossconfig,
|
18 |
-
n_embed,
|
19 |
embed_dim,
|
20 |
ckpt_path=None,
|
21 |
ignore_keys=[],
|
22 |
image_key="image",
|
23 |
colorize_nlabels=None,
|
24 |
monitor=None,
|
25 |
-
|
26 |
-
|
27 |
-
lr_g_factor=1.0,
|
28 |
-
remap=None,
|
29 |
-
sane_index_shape=False, # tell vector quantizer to return indices as bhw
|
30 |
-
use_ema=False
|
31 |
):
|
32 |
super().__init__()
|
33 |
-
self.
|
34 |
-
self.n_embed = n_embed
|
35 |
self.image_key = image_key
|
36 |
self.encoder = Encoder(**ddconfig)
|
37 |
self.decoder = Decoder(**ddconfig)
|
38 |
self.loss = instantiate_from_config(lossconfig)
|
39 |
-
|
40 |
-
|
41 |
-
sane_index_shape=sane_index_shape)
|
42 |
-
self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
|
43 |
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
|
|
|
44 |
if colorize_nlabels is not None:
|
45 |
assert type(colorize_nlabels)==int
|
46 |
self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
|
47 |
if monitor is not None:
|
48 |
self.monitor = monitor
|
49 |
-
self.batch_resize_range = batch_resize_range
|
50 |
-
if self.batch_resize_range is not None:
|
51 |
-
print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.")
|
52 |
|
53 |
-
self.use_ema =
|
54 |
if self.use_ema:
|
55 |
-
self.
|
|
|
|
|
56 |
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
|
57 |
|
58 |
if ckpt_path is not None:
|
59 |
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
|
60 |
-
|
61 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
62 |
|
63 |
@contextmanager
|
64 |
def ema_scope(self, context=None):
|
@@ -75,252 +76,10 @@ class VQModel(pl.LightningModule):
|
|
75 |
if context is not None:
|
76 |
print(f"{context}: Restored training weights")
|
77 |
|
78 |
-
def init_from_ckpt(self, path, ignore_keys=list()):
    """Load weights from the ``"state_dict"`` entry of a checkpoint file.

    :param path: checkpoint path; the file is deserialized onto the CPU
    :param ignore_keys: key prefixes; every state-dict entry whose name starts
        with one of them is dropped before loading
    Loading is non-strict; missing and unexpected keys are reported on stdout.
    """
    sd = torch.load(path, map_location="cpu")["state_dict"]
    ignored_prefixes = tuple(ignore_keys)
    for k in list(sd.keys()):
        # One startswith over all prefixes deletes each key at most once. The
        # nested loop it replaces raised KeyError when a key matched two prefixes.
        if k.startswith(ignored_prefixes):
            print("Deleting key {} from state_dict.".format(k))
            del sd[k]
    missing, unexpected = self.load_state_dict(sd, strict=False)
    print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
    if len(missing) > 0:
        print(f"Missing Keys: {missing}")
    if len(unexpected) > 0:
        print(f"Unexpected Keys: {unexpected}")
|
91 |
-
|
92 |
def on_train_batch_end(self, *args, **kwargs):
    """Hook run after each training batch: update the EMA weights when EMA is enabled."""
    if self.use_ema:
        # model_ema is called with the module itself.
        self.model_ema(self)
|
95 |
|
96 |
-
def encode(self, x):
    """Encode ``x`` and quantize it.

    Returns the three values produced by ``self.quantize`` as a tuple
    ``(quant, emb_loss, info)``.
    """
    prequant = self.quant_conv(self.encoder(x))
    quantized, embedding_loss, details = self.quantize(prequant)
    return quantized, embedding_loss, details
|
101 |
-
|
102 |
-
def encode_to_prequant(self, x):
    """Encode ``x`` and apply ``quant_conv``, stopping before quantization."""
    return self.quant_conv(self.encoder(x))
|
106 |
-
|
107 |
-
def decode(self, quant):
    """Apply ``post_quant_conv`` to ``quant`` and run the decoder on the result."""
    return self.decoder(self.post_quant_conv(quant))
|
111 |
-
|
112 |
-
def decode_code(self, code_b):
    """Turn codes into embeddings via ``quantize.embed_code`` and decode them."""
    embedded = self.quantize.embed_code(code_b)
    return self.decode(embedded)
|
116 |
-
|
117 |
-
def forward(self, input, return_pred_indices=False):
    """Encode then decode ``input``.

    Returns ``(dec, diff)``, or ``(dec, diff, ind)`` when
    ``return_pred_indices`` is true. ``ind`` is the third element of the
    info tuple returned by ``encode``.
    """
    quantized, diff, (_, _, indices) = self.encode(input)
    reconstruction = self.decode(quantized)
    if not return_pred_indices:
        return reconstruction, diff
    return reconstruction, diff, indices
|
123 |
-
|
124 |
-
def get_input(self, batch, k):
    """Fetch ``batch[k]`` as a contiguous float tensor in channels-first layout.

    A 3-D entry gets a trailing channel axis first. The tensor is then
    permuted with ``(0, 3, 1, 2)``. When ``self.batch_resize_range`` is set,
    the batch is bicubically resized to a square size: the upper bound while
    ``global_step <= 4``, afterwards a random multiple-of-16 step between the
    bounds. The resized tensor is detached.
    """
    # Imported locally: the resize branch below used ``np`` although the
    # module's import block does not provide it, so it failed with NameError.
    import numpy as np

    x = batch[k]
    if len(x.shape) == 3:
        x = x[..., None]
    x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
    if self.batch_resize_range is not None:
        lower_size = self.batch_resize_range[0]
        upper_size = self.batch_resize_range[1]
        if self.global_step <= 4:
            # do the first few batches with max size to avoid later oom
            new_resize = upper_size
        else:
            # Cast to a plain int so F.interpolate never receives a numpy scalar.
            new_resize = int(np.random.choice(np.arange(lower_size, upper_size+16, 16)))
        if new_resize != x.shape[2]:
            x = F.interpolate(x, size=new_resize, mode="bicubic")
        x = x.detach()
    return x
|
141 |
-
|
142 |
-
def training_step(self, batch, batch_idx, optimizer_idx):
    """Compute and log the loss for the optimizer selected by ``optimizer_idx``.

    Index 0 yields the autoencoder loss, index 1 the discriminator loss.
    For any other index the method falls through and returns ``None``.
    """
    # https://github.com/pytorch/pytorch/issues/37142
    # try not to fool the heuristics
    x = self.get_input(batch, self.image_key)
    xrec, qloss, ind = self(x, return_pred_indices=True)

    if optimizer_idx == 0:
        # autoencode
        aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
                                        last_layer=self.get_last_layer(), split="train",
                                        predicted_indices=ind)

        self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
        return aeloss

    if optimizer_idx == 1:
        # discriminator
        discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
                                            last_layer=self.get_last_layer(), split="train")
        self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
        return discloss
|
163 |
-
|
164 |
-
def validation_step(self, batch, batch_idx):
    """Run validation once with the current weights and once inside ``ema_scope``.

    Returns the result of the first (non-EMA) pass only. The EMA pass is run
    for its logging side effects under the ``"_ema"`` suffix.
    """
    result = self._validation_step(batch, batch_idx)
    with self.ema_scope():
        self._validation_step(batch, batch_idx, suffix="_ema")
    return result
|
169 |
-
|
170 |
-
def _validation_step(self, batch, batch_idx, suffix=""):
    """Compute and log autoencoder and discriminator validation losses.

    ``suffix`` is appended to the ``"val"`` split name used for the loss and
    log keys (e.g. ``"_ema"``).
    """
    x = self.get_input(batch, self.image_key)
    xrec, qloss, ind = self(x, return_pred_indices=True)
    # The 4th positional argument selects the loss branch: 0 here, 1 below.
    aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0,
                                    self.global_step,
                                    last_layer=self.get_last_layer(),
                                    split="val"+suffix,
                                    predicted_indices=ind
                                    )

    discloss, log_dict_disc = self.loss(qloss, x, xrec, 1,
                                        self.global_step,
                                        last_layer=self.get_last_layer(),
                                        split="val"+suffix,
                                        predicted_indices=ind
                                        )
    rec_loss = log_dict_ae[f"val{suffix}/rec_loss"]
    self.log(f"val{suffix}/rec_loss", rec_loss,
               prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
    self.log(f"val{suffix}/aeloss", aeloss,
               prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
    # NOTE(review): `version` is not among the imports visible at the top of this
    # file - confirm it is imported, otherwise this line raises NameError.
    if version.parse(pl.__version__) >= version.parse('1.4.0'):
        # rec_loss was already logged above; drop it before logging the rest of the dict.
        del log_dict_ae[f"val{suffix}/rec_loss"]
    self.log_dict(log_dict_ae)
    self.log_dict(log_dict_disc)
    # NOTE(review): this returns the bound method `self.log_dict`, not a dict of
    # logged values - confirm whether a dict was intended.
    return self.log_dict
|
196 |
-
|
197 |
-
def configure_optimizers(self):
    """Create the autoencoder and discriminator Adam optimizers.

    The discriminator uses ``self.learning_rate``; the autoencoder parts
    (encoder, decoder, quantizer and both quant convs) use
    ``self.lr_g_factor * self.learning_rate``.

    Returns:
        ``([opt_ae, opt_disc], schedulers)`` where ``schedulers`` is empty
        unless ``self.scheduler_config`` is set, in which case it holds one
        per-step ``LambdaLR`` entry for each optimizer.
    """
    # Imported locally: the scheduler branch used ``LambdaLR`` although the
    # module's import block does not provide it, so it failed with NameError.
    from torch.optim.lr_scheduler import LambdaLR

    lr_d = self.learning_rate
    lr_g = self.lr_g_factor*self.learning_rate
    print("lr_d", lr_d)
    print("lr_g", lr_g)
    opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
                              list(self.decoder.parameters())+
                              list(self.quantize.parameters())+
                              list(self.quant_conv.parameters())+
                              list(self.post_quant_conv.parameters()),
                              lr=lr_g, betas=(0.5, 0.9))
    opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
                                lr=lr_d, betas=(0.5, 0.9))

    if self.scheduler_config is not None:
        scheduler = instantiate_from_config(self.scheduler_config)

        print("Setting up LambdaLR scheduler...")
        # Both optimizers share the same schedule function, stepped every batch.
        scheduler = [
            {
                'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule),
                'interval': 'step',
                'frequency': 1
            },
            {
                'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler.schedule),
                'interval': 'step',
                'frequency': 1
            },
        ]
        return [opt_ae, opt_disc], scheduler
    return [opt_ae, opt_disc], []
|
229 |
-
|
230 |
-
def get_last_layer(self):
    """Return the weight of the decoder's ``conv_out`` layer (passed to the loss as ``last_layer``)."""
    return self.decoder.conv_out.weight
|
232 |
-
|
233 |
-
def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs):
    """Collect input and reconstruction tensors for image logging.

    :param batch: batch dict; the entry under ``self.image_key`` is used
    :param only_inputs: return just ``{"inputs": x}`` without running the model
    :param plot_ema: also add ``"reconstructions_ema"`` computed inside ``ema_scope``
    :return: dict with ``"inputs"``, ``"reconstructions"`` and optionally
        ``"reconstructions_ema"``
    """
    log = dict()
    x = self.get_input(batch, self.image_key)
    x = x.to(self.device)
    if only_inputs:
        log["inputs"] = x
        return log
    xrec, _ = self(x)
    if x.shape[1] > 3:
        # colorize with random projection
        assert xrec.shape[1] > 3
        x = self.to_rgb(x)
        xrec = self.to_rgb(xrec)
    log["inputs"] = x
    log["reconstructions"] = xrec
    if plot_ema:
        with self.ema_scope():
            # NOTE: x may already have been colorized above, so it is the
            # colorized tensor that is fed to the model here.
            xrec_ema, _ = self(x)
            if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema)
            log["reconstructions_ema"] = xrec_ema
    return log
|
254 |
-
|
255 |
-
def to_rgb(self, x):
    """Project a multi-channel tensor to 3 channels and rescale it to [-1, 1].

    Only valid when ``self.image_key == "segmentation"`` (asserted).
    """
    assert self.image_key == "segmentation"
    if not hasattr(self, "colorize"):
        # Lazily create a random 1x1 projection matching x's channel count.
        self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
    x = F.conv2d(x, weight=self.colorize)
    # Min-max normalisation; NOTE(review): a constant x makes the denominator zero.
    x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
    return x
|
262 |
-
|
263 |
-
|
264 |
-
class VQModelInterface(VQModel):
|
265 |
-
def __init__(self, embed_dim, *args, **kwargs):
|
266 |
-
super().__init__(embed_dim=embed_dim, *args, **kwargs)
|
267 |
-
self.embed_dim = embed_dim
|
268 |
-
|
269 |
-
def encode(self, x):
|
270 |
-
h = self.encoder(x)
|
271 |
-
h = self.quant_conv(h)
|
272 |
-
return h
|
273 |
-
|
274 |
-
def decode(self, h, force_not_quantize=False):
|
275 |
-
# also go through quantization layer
|
276 |
-
if not force_not_quantize:
|
277 |
-
quant, emb_loss, info = self.quantize(h)
|
278 |
-
else:
|
279 |
-
quant = h
|
280 |
-
quant = self.post_quant_conv(quant)
|
281 |
-
dec = self.decoder(quant)
|
282 |
-
return dec
|
283 |
-
|
284 |
-
|
285 |
-
class AutoencoderKL(pl.LightningModule):
|
286 |
-
def __init__(self,
|
287 |
-
ddconfig,
|
288 |
-
lossconfig,
|
289 |
-
embed_dim,
|
290 |
-
ckpt_path=None,
|
291 |
-
ignore_keys=[],
|
292 |
-
image_key="image",
|
293 |
-
colorize_nlabels=None,
|
294 |
-
monitor=None,
|
295 |
-
):
|
296 |
-
super().__init__()
|
297 |
-
self.image_key = image_key
|
298 |
-
self.encoder = Encoder(**ddconfig)
|
299 |
-
self.decoder = Decoder(**ddconfig)
|
300 |
-
self.loss = instantiate_from_config(lossconfig)
|
301 |
-
assert ddconfig["double_z"]
|
302 |
-
self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
|
303 |
-
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
|
304 |
-
self.embed_dim = embed_dim
|
305 |
-
if colorize_nlabels is not None:
|
306 |
-
assert type(colorize_nlabels)==int
|
307 |
-
self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
|
308 |
-
if monitor is not None:
|
309 |
-
self.monitor = monitor
|
310 |
-
if ckpt_path is not None:
|
311 |
-
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
|
312 |
-
|
313 |
-
def init_from_ckpt(self, path, ignore_keys=list()):
|
314 |
-
sd = torch.load(path, map_location="cpu")["state_dict"]
|
315 |
-
keys = list(sd.keys())
|
316 |
-
for k in keys:
|
317 |
-
for ik in ignore_keys:
|
318 |
-
if k.startswith(ik):
|
319 |
-
print("Deleting key {} from state_dict.".format(k))
|
320 |
-
del sd[k]
|
321 |
-
self.load_state_dict(sd, strict=False)
|
322 |
-
print(f"Restored from {path}")
|
323 |
-
|
324 |
def encode(self, x):
|
325 |
h = self.encoder(x)
|
326 |
moments = self.quant_conv(h)
|
@@ -370,25 +129,33 @@ class AutoencoderKL(pl.LightningModule):
|
|
370 |
return discloss
|
371 |
|
372 |
def validation_step(self, batch, batch_idx):
|
|
|
|
|
|
|
|
|
|
|
|
|
373 |
inputs = self.get_input(batch, self.image_key)
|
374 |
reconstructions, posterior = self(inputs)
|
375 |
aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
|
376 |
-
last_layer=self.get_last_layer(), split="val")
|
377 |
|
378 |
discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
|
379 |
-
last_layer=self.get_last_layer(), split="val")
|
380 |
|
381 |
-
self.log("val/rec_loss", log_dict_ae["val/rec_loss"])
|
382 |
self.log_dict(log_dict_ae)
|
383 |
self.log_dict(log_dict_disc)
|
384 |
return self.log_dict
|
385 |
|
386 |
def configure_optimizers(self):
|
387 |
lr = self.learning_rate
|
388 |
-
|
389 |
-
|
390 |
-
|
391 |
-
|
|
|
|
|
392 |
lr=lr, betas=(0.5, 0.9))
|
393 |
opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
|
394 |
lr=lr, betas=(0.5, 0.9))
|
@@ -398,7 +165,7 @@ class AutoencoderKL(pl.LightningModule):
|
|
398 |
return self.decoder.conv_out.weight
|
399 |
|
400 |
@torch.no_grad()
|
401 |
-
def log_images(self, batch, only_inputs=False, **kwargs):
|
402 |
log = dict()
|
403 |
x = self.get_input(batch, self.image_key)
|
404 |
x = x.to(self.device)
|
@@ -423,9 +190,9 @@ class AutoencoderKL(pl.LightningModule):
|
|
423 |
return x
|
424 |
|
425 |
|
426 |
-
class IdentityFirstStage(
|
427 |
def __init__(self, *args, vq_interface=False, **kwargs):
|
428 |
-
self.vq_interface = vq_interface
|
429 |
super().__init__()
|
430 |
|
431 |
def encode(self, x, *args, **kwargs):
|
@@ -441,3 +208,4 @@ class IdentityFirstStage(torch.nn.Module):
|
|
441 |
|
442 |
def forward(self, x, *args, **kwargs):
|
443 |
return x
|
|
|
|
1 |
import torch
|
2 |
import pytorch_lightning as pl
|
3 |
import torch.nn.functional as F
|
4 |
+
import torch.nn as nn
|
5 |
from contextlib import contextmanager
|
6 |
|
|
|
|
|
7 |
from ldm.modules.diffusionmodules.model import Encoder, Decoder
|
8 |
from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
|
9 |
|
10 |
from ldm.util import instantiate_from_config
|
11 |
+
from ldm.modules.ema import LitEma
|
12 |
|
13 |
|
14 |
+
class AutoencoderKL(pl.LightningModule):
|
15 |
def __init__(self,
|
16 |
ddconfig,
|
17 |
lossconfig,
|
|
|
18 |
embed_dim,
|
19 |
ckpt_path=None,
|
20 |
ignore_keys=[],
|
21 |
image_key="image",
|
22 |
colorize_nlabels=None,
|
23 |
monitor=None,
|
24 |
+
ema_decay=None,
|
25 |
+
learn_logvar=False
|
|
|
|
|
|
|
|
|
26 |
):
|
27 |
super().__init__()
|
28 |
+
self.learn_logvar = learn_logvar
|
|
|
29 |
self.image_key = image_key
|
30 |
self.encoder = Encoder(**ddconfig)
|
31 |
self.decoder = Decoder(**ddconfig)
|
32 |
self.loss = instantiate_from_config(lossconfig)
|
33 |
+
assert ddconfig["double_z"]
|
34 |
+
self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
|
|
|
|
|
35 |
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
|
36 |
+
self.embed_dim = embed_dim
|
37 |
if colorize_nlabels is not None:
|
38 |
assert type(colorize_nlabels)==int
|
39 |
self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
|
40 |
if monitor is not None:
|
41 |
self.monitor = monitor
|
|
|
|
|
|
|
42 |
|
43 |
+
self.use_ema = ema_decay is not None
|
44 |
if self.use_ema:
|
45 |
+
self.ema_decay = ema_decay
|
46 |
+
assert 0. < ema_decay < 1.
|
47 |
+
self.model_ema = LitEma(self, decay=ema_decay)
|
48 |
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
|
49 |
|
50 |
if ckpt_path is not None:
|
51 |
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
|
52 |
+
|
53 |
+
def init_from_ckpt(self, path, ignore_keys=list()):
|
54 |
+
sd = torch.load(path, map_location="cpu")["state_dict"]
|
55 |
+
keys = list(sd.keys())
|
56 |
+
for k in keys:
|
57 |
+
for ik in ignore_keys:
|
58 |
+
if k.startswith(ik):
|
59 |
+
print("Deleting key {} from state_dict.".format(k))
|
60 |
+
del sd[k]
|
61 |
+
self.load_state_dict(sd, strict=False)
|
62 |
+
print(f"Restored from {path}")
|
63 |
|
64 |
@contextmanager
|
65 |
def ema_scope(self, context=None):
|
|
|
76 |
if context is not None:
|
77 |
print(f"{context}: Restored training weights")
|
78 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
79 |
def on_train_batch_end(self, *args, **kwargs):
|
80 |
if self.use_ema:
|
81 |
self.model_ema(self)
|
82 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
83 |
def encode(self, x):
|
84 |
h = self.encoder(x)
|
85 |
moments = self.quant_conv(h)
|
|
|
129 |
return discloss
|
130 |
|
131 |
def validation_step(self, batch, batch_idx):
|
132 |
+
log_dict = self._validation_step(batch, batch_idx)
|
133 |
+
with self.ema_scope():
|
134 |
+
log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema")
|
135 |
+
return log_dict
|
136 |
+
|
137 |
+
def _validation_step(self, batch, batch_idx, postfix=""):
|
138 |
inputs = self.get_input(batch, self.image_key)
|
139 |
reconstructions, posterior = self(inputs)
|
140 |
aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
|
141 |
+
last_layer=self.get_last_layer(), split="val"+postfix)
|
142 |
|
143 |
discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
|
144 |
+
last_layer=self.get_last_layer(), split="val"+postfix)
|
145 |
|
146 |
+
self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"])
|
147 |
self.log_dict(log_dict_ae)
|
148 |
self.log_dict(log_dict_disc)
|
149 |
return self.log_dict
|
150 |
|
151 |
def configure_optimizers(self):
|
152 |
lr = self.learning_rate
|
153 |
+
ae_params_list = list(self.encoder.parameters()) + list(self.decoder.parameters()) + list(
|
154 |
+
self.quant_conv.parameters()) + list(self.post_quant_conv.parameters())
|
155 |
+
if self.learn_logvar:
|
156 |
+
print(f"{self.__class__.__name__}: Learning logvar")
|
157 |
+
ae_params_list.append(self.loss.logvar)
|
158 |
+
opt_ae = torch.optim.Adam(ae_params_list,
|
159 |
lr=lr, betas=(0.5, 0.9))
|
160 |
opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
|
161 |
lr=lr, betas=(0.5, 0.9))
|
|
|
165 |
return self.decoder.conv_out.weight
|
166 |
|
167 |
@torch.no_grad()
|
168 |
+
def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs):
|
169 |
log = dict()
|
170 |
x = self.get_input(batch, self.image_key)
|
171 |
x = x.to(self.device)
|
|
|
190 |
return x
|
191 |
|
192 |
|
193 |
+
class IdentityFirstStage(nn.Module):
|
194 |
def __init__(self, *args, vq_interface=False, **kwargs):
|
195 |
+
self.vq_interface = vq_interface
|
196 |
super().__init__()
|
197 |
|
198 |
def encode(self, x, *args, **kwargs):
|
|
|
208 |
|
209 |
def forward(self, x, *args, **kwargs):
|
210 |
return x
|
211 |
+
|
ldm/models/diffusion/classifier.py
DELETED
@@ -1,267 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
import torch
|
3 |
-
import pytorch_lightning as pl
|
4 |
-
from omegaconf import OmegaConf
|
5 |
-
from torch.nn import functional as F
|
6 |
-
from torch.optim import AdamW
|
7 |
-
from torch.optim.lr_scheduler import LambdaLR
|
8 |
-
from copy import deepcopy
|
9 |
-
from einops import rearrange
|
10 |
-
from glob import glob
|
11 |
-
from natsort import natsorted
|
12 |
-
|
13 |
-
from ldm.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel
|
14 |
-
from ldm.util import log_txt_as_img, default, ismap, instantiate_from_config
|
15 |
-
|
16 |
-
__models__ = {
|
17 |
-
'class_label': EncoderUNetModel,
|
18 |
-
'segmentation': UNetModel
|
19 |
-
}
|
20 |
-
|
21 |
-
|
22 |
-
def disabled_train(self, mode=True):
|
23 |
-
"""Overwrite model.train with this function to make sure train/eval mode
|
24 |
-
does not change anymore."""
|
25 |
-
return self
|
26 |
-
|
27 |
-
|
28 |
-
class NoisyLatentImageClassifier(pl.LightningModule):
|
29 |
-
|
30 |
-
def __init__(self,
|
31 |
-
diffusion_path,
|
32 |
-
num_classes,
|
33 |
-
ckpt_path=None,
|
34 |
-
pool='attention',
|
35 |
-
label_key=None,
|
36 |
-
diffusion_ckpt_path=None,
|
37 |
-
scheduler_config=None,
|
38 |
-
weight_decay=1.e-2,
|
39 |
-
log_steps=10,
|
40 |
-
monitor='val/loss',
|
41 |
-
*args,
|
42 |
-
**kwargs):
|
43 |
-
super().__init__(*args, **kwargs)
|
44 |
-
self.num_classes = num_classes
|
45 |
-
# get latest config of diffusion model
|
46 |
-
diffusion_config = natsorted(glob(os.path.join(diffusion_path, 'configs', '*-project.yaml')))[-1]
|
47 |
-
self.diffusion_config = OmegaConf.load(diffusion_config).model
|
48 |
-
self.diffusion_config.params.ckpt_path = diffusion_ckpt_path
|
49 |
-
self.load_diffusion()
|
50 |
-
|
51 |
-
self.monitor = monitor
|
52 |
-
self.numd = self.diffusion_model.first_stage_model.encoder.num_resolutions - 1
|
53 |
-
self.log_time_interval = self.diffusion_model.num_timesteps // log_steps
|
54 |
-
self.log_steps = log_steps
|
55 |
-
|
56 |
-
self.label_key = label_key if not hasattr(self.diffusion_model, 'cond_stage_key') \
|
57 |
-
else self.diffusion_model.cond_stage_key
|
58 |
-
|
59 |
-
assert self.label_key is not None, 'label_key neither in diffusion model nor in model.params'
|
60 |
-
|
61 |
-
if self.label_key not in __models__:
|
62 |
-
raise NotImplementedError()
|
63 |
-
|
64 |
-
self.load_classifier(ckpt_path, pool)
|
65 |
-
|
66 |
-
self.scheduler_config = scheduler_config
|
67 |
-
self.use_scheduler = self.scheduler_config is not None
|
68 |
-
self.weight_decay = weight_decay
|
69 |
-
|
70 |
-
def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
|
71 |
-
sd = torch.load(path, map_location="cpu")
|
72 |
-
if "state_dict" in list(sd.keys()):
|
73 |
-
sd = sd["state_dict"]
|
74 |
-
keys = list(sd.keys())
|
75 |
-
for k in keys:
|
76 |
-
for ik in ignore_keys:
|
77 |
-
if k.startswith(ik):
|
78 |
-
print("Deleting key {} from state_dict.".format(k))
|
79 |
-
del sd[k]
|
80 |
-
missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
|
81 |
-
sd, strict=False)
|
82 |
-
print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
|
83 |
-
if len(missing) > 0:
|
84 |
-
print(f"Missing Keys: {missing}")
|
85 |
-
if len(unexpected) > 0:
|
86 |
-
print(f"Unexpected Keys: {unexpected}")
|
87 |
-
|
88 |
-
def load_diffusion(self):
|
89 |
-
model = instantiate_from_config(self.diffusion_config)
|
90 |
-
self.diffusion_model = model.eval()
|
91 |
-
self.diffusion_model.train = disabled_train
|
92 |
-
for param in self.diffusion_model.parameters():
|
93 |
-
param.requires_grad = False
|
94 |
-
|
95 |
-
def load_classifier(self, ckpt_path, pool):
|
96 |
-
model_config = deepcopy(self.diffusion_config.params.unet_config.params)
|
97 |
-
model_config.in_channels = self.diffusion_config.params.unet_config.params.out_channels
|
98 |
-
model_config.out_channels = self.num_classes
|
99 |
-
if self.label_key == 'class_label':
|
100 |
-
model_config.pool = pool
|
101 |
-
|
102 |
-
self.model = __models__[self.label_key](**model_config)
|
103 |
-
if ckpt_path is not None:
|
104 |
-
print('#####################################################################')
|
105 |
-
print(f'load from ckpt "{ckpt_path}"')
|
106 |
-
print('#####################################################################')
|
107 |
-
self.init_from_ckpt(ckpt_path)
|
108 |
-
|
109 |
-
@torch.no_grad()
|
110 |
-
def get_x_noisy(self, x, t, noise=None):
|
111 |
-
noise = default(noise, lambda: torch.randn_like(x))
|
112 |
-
continuous_sqrt_alpha_cumprod = None
|
113 |
-
if self.diffusion_model.use_continuous_noise:
|
114 |
-
continuous_sqrt_alpha_cumprod = self.diffusion_model.sample_continuous_noise_level(x.shape[0], t + 1)
|
115 |
-
# todo: make sure t+1 is correct here
|
116 |
-
|
117 |
-
return self.diffusion_model.q_sample(x_start=x, t=t, noise=noise,
|
118 |
-
continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod)
|
119 |
-
|
120 |
-
def forward(self, x_noisy, t, *args, **kwargs):
|
121 |
-
return self.model(x_noisy, t)
|
122 |
-
|
123 |
-
@torch.no_grad()
|
124 |
-
def get_input(self, batch, k):
|
125 |
-
x = batch[k]
|
126 |
-
if len(x.shape) == 3:
|
127 |
-
x = x[..., None]
|
128 |
-
x = rearrange(x, 'b h w c -> b c h w')
|
129 |
-
x = x.to(memory_format=torch.contiguous_format).float()
|
130 |
-
return x
|
131 |
-
|
132 |
-
@torch.no_grad()
|
133 |
-
def get_conditioning(self, batch, k=None):
|
134 |
-
if k is None:
|
135 |
-
k = self.label_key
|
136 |
-
assert k is not None, 'Needs to provide label key'
|
137 |
-
|
138 |
-
targets = batch[k].to(self.device)
|
139 |
-
|
140 |
-
if self.label_key == 'segmentation':
|
141 |
-
targets = rearrange(targets, 'b h w c -> b c h w')
|
142 |
-
for down in range(self.numd):
|
143 |
-
h, w = targets.shape[-2:]
|
144 |
-
targets = F.interpolate(targets, size=(h // 2, w // 2), mode='nearest')
|
145 |
-
|
146 |
-
# targets = rearrange(targets,'b c h w -> b h w c')
|
147 |
-
|
148 |
-
return targets
|
149 |
-
|
150 |
-
def compute_top_k(self, logits, labels, k, reduction="mean"):
|
151 |
-
_, top_ks = torch.topk(logits, k, dim=1)
|
152 |
-
if reduction == "mean":
|
153 |
-
return (top_ks == labels[:, None]).float().sum(dim=-1).mean().item()
|
154 |
-
elif reduction == "none":
|
155 |
-
return (top_ks == labels[:, None]).float().sum(dim=-1)
|
156 |
-
|
157 |
-
def on_train_epoch_start(self):
|
158 |
-
# save some memory
|
159 |
-
self.diffusion_model.model.to('cpu')
|
160 |
-
|
161 |
-
@torch.no_grad()
|
162 |
-
def write_logs(self, loss, logits, targets):
|
163 |
-
log_prefix = 'train' if self.training else 'val'
|
164 |
-
log = {}
|
165 |
-
log[f"{log_prefix}/loss"] = loss.mean()
|
166 |
-
log[f"{log_prefix}/acc@1"] = self.compute_top_k(
|
167 |
-
logits, targets, k=1, reduction="mean"
|
168 |
-
)
|
169 |
-
log[f"{log_prefix}/acc@5"] = self.compute_top_k(
|
170 |
-
logits, targets, k=5, reduction="mean"
|
171 |
-
)
|
172 |
-
|
173 |
-
self.log_dict(log, prog_bar=False, logger=True, on_step=self.training, on_epoch=True)
|
174 |
-
self.log('loss', log[f"{log_prefix}/loss"], prog_bar=True, logger=False)
|
175 |
-
self.log('global_step', self.global_step, logger=False, on_epoch=False, prog_bar=True)
|
176 |
-
lr = self.optimizers().param_groups[0]['lr']
|
177 |
-
self.log('lr_abs', lr, on_step=True, logger=True, on_epoch=False, prog_bar=True)
|
178 |
-
|
179 |
-
def shared_step(self, batch, t=None):
|
180 |
-
x, *_ = self.diffusion_model.get_input(batch, k=self.diffusion_model.first_stage_key)
|
181 |
-
targets = self.get_conditioning(batch)
|
182 |
-
if targets.dim() == 4:
|
183 |
-
targets = targets.argmax(dim=1)
|
184 |
-
if t is None:
|
185 |
-
t = torch.randint(0, self.diffusion_model.num_timesteps, (x.shape[0],), device=self.device).long()
|
186 |
-
else:
|
187 |
-
t = torch.full(size=(x.shape[0],), fill_value=t, device=self.device).long()
|
188 |
-
x_noisy = self.get_x_noisy(x, t)
|
189 |
-
logits = self(x_noisy, t)
|
190 |
-
|
191 |
-
loss = F.cross_entropy(logits, targets, reduction='none')
|
192 |
-
|
193 |
-
self.write_logs(loss.detach(), logits.detach(), targets.detach())
|
194 |
-
|
195 |
-
loss = loss.mean()
|
196 |
-
return loss, logits, x_noisy, targets
|
197 |
-
|
198 |
-
def training_step(self, batch, batch_idx):
|
199 |
-
loss, *_ = self.shared_step(batch)
|
200 |
-
return loss
|
201 |
-
|
202 |
-
def reset_noise_accs(self):
|
203 |
-
self.noisy_acc = {t: {'acc@1': [], 'acc@5': []} for t in
|
204 |
-
range(0, self.diffusion_model.num_timesteps, self.diffusion_model.log_every_t)}
|
205 |
-
|
206 |
-
def on_validation_start(self):
|
207 |
-
self.reset_noise_accs()
|
208 |
-
|
209 |
-
@torch.no_grad()
|
210 |
-
def validation_step(self, batch, batch_idx):
|
211 |
-
loss, *_ = self.shared_step(batch)
|
212 |
-
|
213 |
-
for t in self.noisy_acc:
|
214 |
-
_, logits, _, targets = self.shared_step(batch, t)
|
215 |
-
self.noisy_acc[t]['acc@1'].append(self.compute_top_k(logits, targets, k=1, reduction='mean'))
|
216 |
-
self.noisy_acc[t]['acc@5'].append(self.compute_top_k(logits, targets, k=5, reduction='mean'))
|
217 |
-
|
218 |
-
return loss
|
219 |
-
|
220 |
-
def configure_optimizers(self):
|
221 |
-
optimizer = AdamW(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)
|
222 |
-
|
223 |
-
if self.use_scheduler:
|
224 |
-
scheduler = instantiate_from_config(self.scheduler_config)
|
225 |
-
|
226 |
-
print("Setting up LambdaLR scheduler...")
|
227 |
-
scheduler = [
|
228 |
-
{
|
229 |
-
'scheduler': LambdaLR(optimizer, lr_lambda=scheduler.schedule),
|
230 |
-
'interval': 'step',
|
231 |
-
'frequency': 1
|
232 |
-
}]
|
233 |
-
return [optimizer], scheduler
|
234 |
-
|
235 |
-
return optimizer
|
236 |
-
|
237 |
-
@torch.no_grad()
|
238 |
-
def log_images(self, batch, N=8, *args, **kwargs):
|
239 |
-
log = dict()
|
240 |
-
x = self.get_input(batch, self.diffusion_model.first_stage_key)
|
241 |
-
log['inputs'] = x
|
242 |
-
|
243 |
-
y = self.get_conditioning(batch)
|
244 |
-
|
245 |
-
if self.label_key == 'class_label':
|
246 |
-
y = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"])
|
247 |
-
log['labels'] = y
|
248 |
-
|
249 |
-
if ismap(y):
|
250 |
-
log['labels'] = self.diffusion_model.to_rgb(y)
|
251 |
-
|
252 |
-
for step in range(self.log_steps):
|
253 |
-
current_time = step * self.log_time_interval
|
254 |
-
|
255 |
-
_, logits, x_noisy, _ = self.shared_step(batch, t=current_time)
|
256 |
-
|
257 |
-
log[f'inputs@t{current_time}'] = x_noisy
|
258 |
-
|
259 |
-
pred = F.one_hot(logits.argmax(dim=1), num_classes=self.num_classes)
|
260 |
-
pred = rearrange(pred, 'b h w c -> b c h w')
|
261 |
-
|
262 |
-
log[f'pred@t{current_time}'] = self.diffusion_model.to_rgb(pred)
|
263 |
-
|
264 |
-
for key in log:
|
265 |
-
log[key] = log[key][:N]
|
266 |
-
|
267 |
-
return log
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ldm/models/diffusion/ddim.py
CHANGED
@@ -3,7 +3,6 @@
|
|
3 |
import torch
|
4 |
import numpy as np
|
5 |
from tqdm import tqdm
|
6 |
-
from functools import partial
|
7 |
|
8 |
from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, \
|
9 |
extract_into_tensor
|
@@ -24,7 +23,7 @@ class DDIMSampler(object):
|
|
24 |
|
25 |
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
|
26 |
self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
|
27 |
-
num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
|
28 |
alphas_cumprod = self.model.alphas_cumprod
|
29 |
assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
|
30 |
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
|
@@ -43,14 +42,14 @@ class DDIMSampler(object):
|
|
43 |
# ddim sampling parameters
|
44 |
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
|
45 |
ddim_timesteps=self.ddim_timesteps,
|
46 |
-
eta=ddim_eta,verbose=verbose)
|
47 |
self.register_buffer('ddim_sigmas', ddim_sigmas)
|
48 |
self.register_buffer('ddim_alphas', ddim_alphas)
|
49 |
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
|
50 |
self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
|
51 |
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
|
52 |
(1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
|
53 |
-
|
54 |
self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
|
55 |
|
56 |
@torch.no_grad()
|
@@ -75,6 +74,9 @@ class DDIMSampler(object):
|
|
75 |
log_every_t=100,
|
76 |
unconditional_guidance_scale=1.,
|
77 |
unconditional_conditioning=None,
|
|
|
|
|
|
|
78 |
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
|
79 |
**kwargs
|
80 |
):
|
@@ -107,6 +109,9 @@ class DDIMSampler(object):
|
|
107 |
log_every_t=log_every_t,
|
108 |
unconditional_guidance_scale=unconditional_guidance_scale,
|
109 |
unconditional_conditioning=unconditional_conditioning,
|
|
|
|
|
|
|
110 |
)
|
111 |
return samples, intermediates
|
112 |
|
@@ -116,7 +121,8 @@ class DDIMSampler(object):
|
|
116 |
callback=None, timesteps=None, quantize_denoised=False,
|
117 |
mask=None, x0=None, img_callback=None, log_every_t=100,
|
118 |
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
|
119 |
-
unconditional_guidance_scale=1., unconditional_conditioning=None,
|
|
|
120 |
device = self.model.betas.device
|
121 |
b = shape[0]
|
122 |
if x_T is None:
|
@@ -131,7 +137,7 @@ class DDIMSampler(object):
|
|
131 |
timesteps = self.ddim_timesteps[:subset_end]
|
132 |
|
133 |
intermediates = {'x_inter': [img], 'pred_x0': [img]}
|
134 |
-
time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)
|
135 |
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
|
136 |
print(f"Running DDIM Sampling with {total_steps} timesteps")
|
137 |
|
@@ -151,7 +157,13 @@ class DDIMSampler(object):
|
|
151 |
noise_dropout=noise_dropout, score_corrector=score_corrector,
|
152 |
corrector_kwargs=corrector_kwargs,
|
153 |
unconditional_guidance_scale=unconditional_guidance_scale,
|
154 |
-
unconditional_conditioning=unconditional_conditioning
|
|
|
|
|
|
|
|
|
|
|
|
|
155 |
img, pred_x0 = outs
|
156 |
if callback: callback(i)
|
157 |
if img_callback: img_callback(pred_x0, i)
|
@@ -165,20 +177,55 @@ class DDIMSampler(object):
|
|
165 |
@torch.no_grad()
|
166 |
def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
|
167 |
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
|
168 |
-
unconditional_guidance_scale=1., unconditional_conditioning=None
|
|
|
169 |
b, *_, device = *x.shape, x.device
|
170 |
|
171 |
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
|
172 |
-
|
|
|
|
|
|
|
|
|
173 |
else:
|
174 |
x_in = torch.cat([x] * 2)
|
175 |
t_in = torch.cat([t] * 2)
|
176 |
-
|
177 |
-
|
178 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
179 |
|
180 |
if score_corrector is not None:
|
181 |
-
assert self.model.parameterization == "eps"
|
182 |
e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
|
183 |
|
184 |
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
|
@@ -189,14 +236,18 @@ class DDIMSampler(object):
|
|
189 |
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
|
190 |
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
|
191 |
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
|
192 |
-
sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
|
193 |
|
194 |
# current prediction for x_0
|
195 |
-
|
|
|
|
|
|
|
|
|
196 |
if quantize_denoised:
|
197 |
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
|
198 |
# direction pointing to x_t
|
199 |
-
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
|
200 |
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
|
201 |
if noise_dropout > 0.:
|
202 |
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
|
@@ -238,4 +289,4 @@ class DDIMSampler(object):
|
|
238 |
x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
|
239 |
unconditional_guidance_scale=unconditional_guidance_scale,
|
240 |
unconditional_conditioning=unconditional_conditioning)
|
241 |
-
return x_dec
|
|
|
3 |
import torch
|
4 |
import numpy as np
|
5 |
from tqdm import tqdm
|
|
|
6 |
|
7 |
from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, \
|
8 |
extract_into_tensor
|
|
|
23 |
|
24 |
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
|
25 |
self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
|
26 |
+
num_ddpm_timesteps=self.ddpm_num_timesteps, verbose=verbose)
|
27 |
alphas_cumprod = self.model.alphas_cumprod
|
28 |
assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
|
29 |
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
|
|
|
42 |
# ddim sampling parameters
|
43 |
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
|
44 |
ddim_timesteps=self.ddim_timesteps,
|
45 |
+
eta=ddim_eta, verbose=verbose)
|
46 |
self.register_buffer('ddim_sigmas', ddim_sigmas)
|
47 |
self.register_buffer('ddim_alphas', ddim_alphas)
|
48 |
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
|
49 |
self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
|
50 |
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
|
51 |
(1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
|
52 |
+
1 - self.alphas_cumprod / self.alphas_cumprod_prev))
|
53 |
self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
|
54 |
|
55 |
@torch.no_grad()
|
|
|
74 |
log_every_t=100,
|
75 |
unconditional_guidance_scale=1.,
|
76 |
unconditional_conditioning=None,
|
77 |
+
features_adapter=None,
|
78 |
+
append_to_context=None,
|
79 |
+
cond_tau=0.4,
|
80 |
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
|
81 |
**kwargs
|
82 |
):
|
|
|
109 |
log_every_t=log_every_t,
|
110 |
unconditional_guidance_scale=unconditional_guidance_scale,
|
111 |
unconditional_conditioning=unconditional_conditioning,
|
112 |
+
features_adapter=features_adapter,
|
113 |
+
append_to_context=append_to_context,
|
114 |
+
cond_tau=cond_tau,
|
115 |
)
|
116 |
return samples, intermediates
|
117 |
|
|
|
121 |
callback=None, timesteps=None, quantize_denoised=False,
|
122 |
mask=None, x0=None, img_callback=None, log_every_t=100,
|
123 |
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
|
124 |
+
unconditional_guidance_scale=1., unconditional_conditioning=None, features_adapter=None,
|
125 |
+
append_to_context=None, cond_tau=0.4):
|
126 |
device = self.model.betas.device
|
127 |
b = shape[0]
|
128 |
if x_T is None:
|
|
|
137 |
timesteps = self.ddim_timesteps[:subset_end]
|
138 |
|
139 |
intermediates = {'x_inter': [img], 'pred_x0': [img]}
|
140 |
+
time_range = reversed(range(0, timesteps)) if ddim_use_original_steps else np.flip(timesteps)
|
141 |
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
|
142 |
print(f"Running DDIM Sampling with {total_steps} timesteps")
|
143 |
|
|
|
157 |
noise_dropout=noise_dropout, score_corrector=score_corrector,
|
158 |
corrector_kwargs=corrector_kwargs,
|
159 |
unconditional_guidance_scale=unconditional_guidance_scale,
|
160 |
+
unconditional_conditioning=unconditional_conditioning,
|
161 |
+
features_adapter=None if index < int(
|
162 |
+
(1 - cond_tau) * total_steps) else features_adapter,
|
163 |
+
# TODO support style_cond_tau
|
164 |
+
append_to_context=None if index < int(
|
165 |
+
0.5 * total_steps) else append_to_context,
|
166 |
+
)
|
167 |
img, pred_x0 = outs
|
168 |
if callback: callback(i)
|
169 |
if img_callback: img_callback(pred_x0, i)
|
|
|
177 |
@torch.no_grad()
|
178 |
def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
|
179 |
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
|
180 |
+
unconditional_guidance_scale=1., unconditional_conditioning=None, features_adapter=None,
|
181 |
+
append_to_context=None):
|
182 |
b, *_, device = *x.shape, x.device
|
183 |
|
184 |
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
|
185 |
+
if append_to_context is not None:
|
186 |
+
model_output = self.model.apply_model(x, t, torch.cat([c, append_to_context], dim=1),
|
187 |
+
features_adapter=features_adapter)
|
188 |
+
else:
|
189 |
+
model_output = self.model.apply_model(x, t, c, features_adapter=features_adapter)
|
190 |
else:
|
191 |
x_in = torch.cat([x] * 2)
|
192 |
t_in = torch.cat([t] * 2)
|
193 |
+
if isinstance(c, dict):
|
194 |
+
assert isinstance(unconditional_conditioning, dict)
|
195 |
+
c_in = dict()
|
196 |
+
for k in c:
|
197 |
+
if isinstance(c[k], list):
|
198 |
+
c_in[k] = [torch.cat([
|
199 |
+
unconditional_conditioning[k][i],
|
200 |
+
c[k][i]]) for i in range(len(c[k]))]
|
201 |
+
else:
|
202 |
+
c_in[k] = torch.cat([
|
203 |
+
unconditional_conditioning[k],
|
204 |
+
c[k]])
|
205 |
+
elif isinstance(c, list):
|
206 |
+
c_in = list()
|
207 |
+
assert isinstance(unconditional_conditioning, list)
|
208 |
+
for i in range(len(c)):
|
209 |
+
c_in.append(torch.cat([unconditional_conditioning[i], c[i]]))
|
210 |
+
else:
|
211 |
+
if append_to_context is not None:
|
212 |
+
pad_len = append_to_context.size(1)
|
213 |
+
new_unconditional_conditioning = torch.cat(
|
214 |
+
[unconditional_conditioning, unconditional_conditioning[:, -pad_len:, :]], dim=1)
|
215 |
+
new_c = torch.cat([c, append_to_context], dim=1)
|
216 |
+
c_in = torch.cat([new_unconditional_conditioning, new_c])
|
217 |
+
else:
|
218 |
+
c_in = torch.cat([unconditional_conditioning, c])
|
219 |
+
model_uncond, model_t = self.model.apply_model(x_in, t_in, c_in, features_adapter=features_adapter).chunk(2)
|
220 |
+
model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond)
|
221 |
+
|
222 |
+
if self.model.parameterization == "v":
|
223 |
+
e_t = self.model.predict_eps_from_z_and_v(x, t, model_output)
|
224 |
+
else:
|
225 |
+
e_t = model_output
|
226 |
|
227 |
if score_corrector is not None:
|
228 |
+
assert self.model.parameterization == "eps", 'not implemented'
|
229 |
e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
|
230 |
|
231 |
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
|
|
|
236 |
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
|
237 |
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
|
238 |
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
|
239 |
+
sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device)
|
240 |
|
241 |
# current prediction for x_0
|
242 |
+
if self.model.parameterization != "v":
|
243 |
+
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
244 |
+
else:
|
245 |
+
pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)
|
246 |
+
|
247 |
if quantize_denoised:
|
248 |
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
|
249 |
# direction pointing to x_t
|
250 |
+
dir_xt = (1. - a_prev - sigma_t ** 2).sqrt() * e_t
|
251 |
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
|
252 |
if noise_dropout > 0.:
|
253 |
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
|
|
|
289 |
x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
|
290 |
unconditional_guidance_scale=unconditional_guidance_scale,
|
291 |
unconditional_conditioning=unconditional_conditioning)
|
292 |
+
return x_dec
|
ldm/models/diffusion/ddpm.py
CHANGED
@@ -12,16 +12,18 @@ import numpy as np
|
|
12 |
import pytorch_lightning as pl
|
13 |
from torch.optim.lr_scheduler import LambdaLR
|
14 |
from einops import rearrange, repeat
|
15 |
-
from contextlib import contextmanager
|
16 |
from functools import partial
|
|
|
17 |
from tqdm import tqdm
|
18 |
from torchvision.utils import make_grid
|
19 |
from pytorch_lightning.utilities.distributed import rank_zero_only
|
|
|
20 |
|
21 |
from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
|
22 |
from ldm.modules.ema import LitEma
|
23 |
from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution
|
24 |
-
from ldm.models.autoencoder import
|
25 |
from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
|
26 |
from ldm.models.diffusion.ddim import DDIMSampler
|
27 |
|
@@ -71,9 +73,13 @@ class DDPM(pl.LightningModule):
|
|
71 |
use_positional_encodings=False,
|
72 |
learn_logvar=False,
|
73 |
logvar_init=0.,
|
|
|
|
|
|
|
|
|
74 |
):
|
75 |
super().__init__()
|
76 |
-
assert parameterization in ["eps", "x0"], 'currently only supporting "eps" and "x0"'
|
77 |
self.parameterization = parameterization
|
78 |
print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode")
|
79 |
self.cond_stage_model = None
|
@@ -100,8 +106,18 @@ class DDPM(pl.LightningModule):
|
|
100 |
|
101 |
if monitor is not None:
|
102 |
self.monitor = monitor
|
|
|
|
|
103 |
if ckpt_path is not None:
|
104 |
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
105 |
|
106 |
self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps,
|
107 |
linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
|
@@ -113,6 +129,9 @@ class DDPM(pl.LightningModule):
|
|
113 |
if self.learn_logvar:
|
114 |
self.logvar = nn.Parameter(self.logvar, requires_grad=True)
|
115 |
|
|
|
|
|
|
|
116 |
|
117 |
def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
|
118 |
linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
|
@@ -146,7 +165,7 @@ class DDPM(pl.LightningModule):
|
|
146 |
|
147 |
# calculations for posterior q(x_{t-1} | x_t, x_0)
|
148 |
posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / (
|
149 |
-
|
150 |
# above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
|
151 |
self.register_buffer('posterior_variance', to_torch(posterior_variance))
|
152 |
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
|
@@ -158,12 +177,14 @@ class DDPM(pl.LightningModule):
|
|
158 |
|
159 |
if self.parameterization == "eps":
|
160 |
lvlb_weights = self.betas ** 2 / (
|
161 |
-
|
162 |
elif self.parameterization == "x0":
|
163 |
lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod))
|
|
|
|
|
|
|
164 |
else:
|
165 |
raise NotImplementedError("mu not supported")
|
166 |
-
# TODO how to choose this term
|
167 |
lvlb_weights[0] = lvlb_weights[1]
|
168 |
self.register_buffer('lvlb_weights', lvlb_weights, persistent=False)
|
169 |
assert not torch.isnan(self.lvlb_weights).all()
|
@@ -183,6 +204,7 @@ class DDPM(pl.LightningModule):
|
|
183 |
if context is not None:
|
184 |
print(f"{context}: Restored training weights")
|
185 |
|
|
|
186 |
def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
|
187 |
sd = torch.load(path, map_location="cpu")
|
188 |
if "state_dict" in list(sd.keys()):
|
@@ -193,13 +215,57 @@ class DDPM(pl.LightningModule):
|
|
193 |
if k.startswith(ik):
|
194 |
print("Deleting key {} from state_dict.".format(k))
|
195 |
del sd[k]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
196 |
missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
|
197 |
sd, strict=False)
|
198 |
print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
|
199 |
if len(missing) > 0:
|
200 |
-
print(f"Missing Keys
|
201 |
if len(unexpected) > 0:
|
202 |
-
print(f"
|
203 |
|
204 |
def q_mean_variance(self, x_start, t):
|
205 |
"""
|
@@ -219,6 +285,20 @@ class DDPM(pl.LightningModule):
|
|
219 |
extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
|
220 |
)
|
221 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
222 |
def q_posterior(self, x_start, x_t, t):
|
223 |
posterior_mean = (
|
224 |
extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start +
|
@@ -276,6 +356,12 @@ class DDPM(pl.LightningModule):
|
|
276 |
return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
|
277 |
extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
|
278 |
|
|
|
|
|
|
|
|
|
|
|
|
|
279 |
def get_loss(self, pred, target, mean=True):
|
280 |
if self.loss_type == 'l1':
|
281 |
loss = (target - pred).abs()
|
@@ -301,6 +387,8 @@ class DDPM(pl.LightningModule):
|
|
301 |
target = noise
|
302 |
elif self.parameterization == "x0":
|
303 |
target = x_start
|
|
|
|
|
304 |
else:
|
305 |
raise NotImplementedError(f"Parameterization {self.parameterization} not yet supported")
|
306 |
|
@@ -328,10 +416,10 @@ class DDPM(pl.LightningModule):
|
|
328 |
|
329 |
def get_input(self, batch, k):
|
330 |
x = batch[k]
|
331 |
-
if len(x.shape) == 3:
|
332 |
-
|
333 |
-
x = rearrange(x, 'b h w c -> b c h w')
|
334 |
-
x = x.to(memory_format=torch.contiguous_format).float()
|
335 |
return x
|
336 |
|
337 |
def shared_step(self, batch):
|
@@ -421,41 +509,12 @@ class DDPM(pl.LightningModule):
|
|
421 |
return opt
|
422 |
|
423 |
|
424 |
-
class DiffusionWrapper(pl.LightningModule):
|
425 |
-
def __init__(self, diff_model_config, conditioning_key):
|
426 |
-
super().__init__()
|
427 |
-
self.diffusion_model = instantiate_from_config(diff_model_config)
|
428 |
-
self.conditioning_key = conditioning_key
|
429 |
-
assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm']
|
430 |
-
|
431 |
-
def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, features_adapter=None):
|
432 |
-
if self.conditioning_key is None:
|
433 |
-
out = self.diffusion_model(x, t, features_adapter=features_adapter)
|
434 |
-
elif self.conditioning_key == 'concat':
|
435 |
-
xc = torch.cat([x] + c_concat, dim=1)
|
436 |
-
out = self.diffusion_model(xc, t, features_adapter=features_adapter)
|
437 |
-
elif self.conditioning_key == 'crossattn':
|
438 |
-
cc = torch.cat(c_crossattn, 1)
|
439 |
-
out = self.diffusion_model(x, t, context=cc, features_adapter=features_adapter)
|
440 |
-
elif self.conditioning_key == 'hybrid':
|
441 |
-
xc = torch.cat([x] + c_concat, dim=1)
|
442 |
-
cc = torch.cat(c_crossattn, 1)
|
443 |
-
out = self.diffusion_model(xc, t, context=cc, features_adapter=features_adapter)
|
444 |
-
elif self.conditioning_key == 'adm':
|
445 |
-
cc = c_crossattn[0]
|
446 |
-
out = self.diffusion_model(x, t, y=cc, features_adapter=features_adapter)
|
447 |
-
else:
|
448 |
-
raise NotImplementedError()
|
449 |
-
|
450 |
-
return out
|
451 |
-
|
452 |
-
|
453 |
class LatentDiffusion(DDPM):
|
454 |
"""main class"""
|
|
|
455 |
def __init__(self,
|
456 |
first_stage_config,
|
457 |
cond_stage_config,
|
458 |
-
unet_config,
|
459 |
num_timesteps_cond=None,
|
460 |
cond_stage_key="image",
|
461 |
cond_stage_trainable=False,
|
@@ -474,9 +533,10 @@ class LatentDiffusion(DDPM):
|
|
474 |
if cond_stage_config == '__is_unconditional__':
|
475 |
conditioning_key = None
|
476 |
ckpt_path = kwargs.pop("ckpt_path", None)
|
|
|
|
|
477 |
ignore_keys = kwargs.pop("ignore_keys", [])
|
478 |
-
super().__init__(conditioning_key=conditioning_key,
|
479 |
-
self.model = DiffusionWrapper(unet_config, conditioning_key)
|
480 |
self.concat_mode = concat_mode
|
481 |
self.cond_stage_trainable = cond_stage_trainable
|
482 |
self.cond_stage_key = cond_stage_key
|
@@ -492,35 +552,27 @@ class LatentDiffusion(DDPM):
|
|
492 |
self.instantiate_cond_stage(cond_stage_config)
|
493 |
self.cond_stage_forward = cond_stage_forward
|
494 |
self.clip_denoised = False
|
495 |
-
self.bbox_tokenizer = None
|
496 |
|
497 |
self.restarted_from_ckpt = False
|
498 |
if ckpt_path is not None:
|
499 |
self.init_from_ckpt(ckpt_path, ignore_keys)
|
500 |
self.restarted_from_ckpt = True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
501 |
|
502 |
def make_cond_schedule(self, ):
|
503 |
self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long)
|
504 |
ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long()
|
505 |
self.cond_ids[:self.num_timesteps_cond] = ids
|
506 |
|
507 |
-
@rank_zero_only
|
508 |
-
@torch.no_grad()
|
509 |
-
def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
|
510 |
-
# only for very first batch
|
511 |
-
if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt:
|
512 |
-
assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously'
|
513 |
-
# set rescale weight to 1./std of encodings
|
514 |
-
print("### USING STD-RESCALING ###")
|
515 |
-
x = super().get_input(batch, self.first_stage_key)
|
516 |
-
x = x.to(self.device)
|
517 |
-
encoder_posterior = self.encode_first_stage(x)
|
518 |
-
z = self.get_first_stage_encoding(encoder_posterior).detach()
|
519 |
-
del self.scale_factor
|
520 |
-
self.register_buffer('scale_factor', 1. / z.flatten().std())
|
521 |
-
print(f"setting self.scale_factor to {self.scale_factor}")
|
522 |
-
print("### USING STD-RESCALING ###")
|
523 |
-
|
524 |
def register_schedule(self,
|
525 |
given_betas=None, beta_schedule="linear", timesteps=1000,
|
526 |
linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
|
@@ -562,7 +614,7 @@ class LatentDiffusion(DDPM):
|
|
562 |
denoise_row = []
|
563 |
for zd in tqdm(samples, desc=desc):
|
564 |
denoise_row.append(self.decode_first_stage(zd.to(self.device),
|
565 |
-
|
566 |
n_imgs_per_row = len(denoise_row)
|
567 |
denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W
|
568 |
denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w')
|
@@ -695,9 +747,9 @@ class LatentDiffusion(DDPM):
|
|
695 |
if cond_key is None:
|
696 |
cond_key = self.cond_stage_key
|
697 |
if cond_key != self.first_stage_key:
|
698 |
-
if cond_key in ['caption', 'coordinates_bbox']:
|
699 |
xc = batch[cond_key]
|
700 |
-
elif cond_key
|
701 |
xc = batch
|
702 |
else:
|
703 |
xc = super().get_input(batch, cond_key).to(self.device)
|
@@ -742,181 +794,28 @@ class LatentDiffusion(DDPM):
|
|
742 |
z = rearrange(z, 'b h w c -> b c h w').contiguous()
|
743 |
|
744 |
z = 1. / self.scale_factor * z
|
745 |
-
|
746 |
-
if hasattr(self, "split_input_params"):
|
747 |
-
if self.split_input_params["patch_distributed_vq"]:
|
748 |
-
ks = self.split_input_params["ks"] # eg. (128, 128)
|
749 |
-
stride = self.split_input_params["stride"] # eg. (64, 64)
|
750 |
-
uf = self.split_input_params["vqf"]
|
751 |
-
bs, nc, h, w = z.shape
|
752 |
-
if ks[0] > h or ks[1] > w:
|
753 |
-
ks = (min(ks[0], h), min(ks[1], w))
|
754 |
-
print("reducing Kernel")
|
755 |
-
|
756 |
-
if stride[0] > h or stride[1] > w:
|
757 |
-
stride = (min(stride[0], h), min(stride[1], w))
|
758 |
-
print("reducing stride")
|
759 |
-
|
760 |
-
fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf)
|
761 |
-
|
762 |
-
z = unfold(z) # (bn, nc * prod(**ks), L)
|
763 |
-
# 1. Reshape to img shape
|
764 |
-
z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
|
765 |
-
|
766 |
-
# 2. apply model loop over last dim
|
767 |
-
if isinstance(self.first_stage_model, VQModelInterface):
|
768 |
-
output_list = [self.first_stage_model.decode(z[:, :, :, :, i],
|
769 |
-
force_not_quantize=predict_cids or force_not_quantize)
|
770 |
-
for i in range(z.shape[-1])]
|
771 |
-
else:
|
772 |
-
|
773 |
-
output_list = [self.first_stage_model.decode(z[:, :, :, :, i])
|
774 |
-
for i in range(z.shape[-1])]
|
775 |
-
|
776 |
-
o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L)
|
777 |
-
o = o * weighting
|
778 |
-
# Reverse 1. reshape to img shape
|
779 |
-
o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
|
780 |
-
# stitch crops together
|
781 |
-
decoded = fold(o)
|
782 |
-
decoded = decoded / normalization # norm is shape (1, 1, h, w)
|
783 |
-
return decoded
|
784 |
-
else:
|
785 |
-
if isinstance(self.first_stage_model, VQModelInterface):
|
786 |
-
return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
|
787 |
-
else:
|
788 |
-
return self.first_stage_model.decode(z)
|
789 |
-
|
790 |
-
else:
|
791 |
-
if isinstance(self.first_stage_model, VQModelInterface):
|
792 |
-
return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
|
793 |
-
else:
|
794 |
-
return self.first_stage_model.decode(z)
|
795 |
-
|
796 |
-
# same as above but without decorator
|
797 |
-
def differentiable_decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
|
798 |
-
if predict_cids:
|
799 |
-
if z.dim() == 4:
|
800 |
-
z = torch.argmax(z.exp(), dim=1).long()
|
801 |
-
z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
|
802 |
-
z = rearrange(z, 'b h w c -> b c h w').contiguous()
|
803 |
-
|
804 |
-
z = 1. / self.scale_factor * z
|
805 |
-
|
806 |
-
if hasattr(self, "split_input_params"):
|
807 |
-
if self.split_input_params["patch_distributed_vq"]:
|
808 |
-
ks = self.split_input_params["ks"] # eg. (128, 128)
|
809 |
-
stride = self.split_input_params["stride"] # eg. (64, 64)
|
810 |
-
uf = self.split_input_params["vqf"]
|
811 |
-
bs, nc, h, w = z.shape
|
812 |
-
if ks[0] > h or ks[1] > w:
|
813 |
-
ks = (min(ks[0], h), min(ks[1], w))
|
814 |
-
print("reducing Kernel")
|
815 |
-
|
816 |
-
if stride[0] > h or stride[1] > w:
|
817 |
-
stride = (min(stride[0], h), min(stride[1], w))
|
818 |
-
print("reducing stride")
|
819 |
-
|
820 |
-
fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf)
|
821 |
-
|
822 |
-
z = unfold(z) # (bn, nc * prod(**ks), L)
|
823 |
-
# 1. Reshape to img shape
|
824 |
-
z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
|
825 |
-
|
826 |
-
# 2. apply model loop over last dim
|
827 |
-
if isinstance(self.first_stage_model, VQModelInterface):
|
828 |
-
output_list = [self.first_stage_model.decode(z[:, :, :, :, i],
|
829 |
-
force_not_quantize=predict_cids or force_not_quantize)
|
830 |
-
for i in range(z.shape[-1])]
|
831 |
-
else:
|
832 |
-
|
833 |
-
output_list = [self.first_stage_model.decode(z[:, :, :, :, i])
|
834 |
-
for i in range(z.shape[-1])]
|
835 |
-
|
836 |
-
o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L)
|
837 |
-
o = o * weighting
|
838 |
-
# Reverse 1. reshape to img shape
|
839 |
-
o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
|
840 |
-
# stitch crops together
|
841 |
-
decoded = fold(o)
|
842 |
-
decoded = decoded / normalization # norm is shape (1, 1, h, w)
|
843 |
-
return decoded
|
844 |
-
else:
|
845 |
-
if isinstance(self.first_stage_model, VQModelInterface):
|
846 |
-
return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
|
847 |
-
else:
|
848 |
-
return self.first_stage_model.decode(z)
|
849 |
-
|
850 |
-
else:
|
851 |
-
if isinstance(self.first_stage_model, VQModelInterface):
|
852 |
-
return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
|
853 |
-
else:
|
854 |
-
return self.first_stage_model.decode(z)
|
855 |
|
856 |
@torch.no_grad()
|
857 |
def encode_first_stage(self, x):
|
858 |
-
|
859 |
-
if self.split_input_params["patch_distributed_vq"]:
|
860 |
-
ks = self.split_input_params["ks"] # eg. (128, 128)
|
861 |
-
stride = self.split_input_params["stride"] # eg. (64, 64)
|
862 |
-
df = self.split_input_params["vqf"]
|
863 |
-
self.split_input_params['original_image_size'] = x.shape[-2:]
|
864 |
-
bs, nc, h, w = x.shape
|
865 |
-
if ks[0] > h or ks[1] > w:
|
866 |
-
ks = (min(ks[0], h), min(ks[1], w))
|
867 |
-
print("reducing Kernel")
|
868 |
-
|
869 |
-
if stride[0] > h or stride[1] > w:
|
870 |
-
stride = (min(stride[0], h), min(stride[1], w))
|
871 |
-
print("reducing stride")
|
872 |
-
|
873 |
-
fold, unfold, normalization, weighting = self.get_fold_unfold(x, ks, stride, df=df)
|
874 |
-
z = unfold(x) # (bn, nc * prod(**ks), L)
|
875 |
-
# Reshape to img shape
|
876 |
-
z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
|
877 |
-
|
878 |
-
output_list = [self.first_stage_model.encode(z[:, :, :, :, i])
|
879 |
-
for i in range(z.shape[-1])]
|
880 |
-
|
881 |
-
o = torch.stack(output_list, axis=-1)
|
882 |
-
o = o * weighting
|
883 |
-
|
884 |
-
# Reverse reshape to img shape
|
885 |
-
o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
|
886 |
-
# stitch crops together
|
887 |
-
decoded = fold(o)
|
888 |
-
decoded = decoded / normalization
|
889 |
-
return decoded
|
890 |
-
|
891 |
-
else:
|
892 |
-
return self.first_stage_model.encode(x)
|
893 |
-
else:
|
894 |
-
return self.first_stage_model.encode(x)
|
895 |
|
896 |
def shared_step(self, batch, **kwargs):
|
897 |
x, c = self.get_input(batch, self.first_stage_key)
|
898 |
-
loss = self(x, c)
|
899 |
return loss
|
900 |
|
901 |
-
def forward(self, x, c,
|
902 |
-
t
|
903 |
-
|
904 |
-
|
905 |
-
|
906 |
-
def _rescale_annotations(self, bboxes, crop_coordinates): # TODO: move to dataset
|
907 |
-
def rescale_bbox(bbox):
|
908 |
-
x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2])
|
909 |
-
y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3])
|
910 |
-
w = min(bbox[2] / crop_coordinates[2], 1 - x0)
|
911 |
-
h = min(bbox[3] / crop_coordinates[3], 1 - y0)
|
912 |
-
return x0, y0, w, h
|
913 |
-
|
914 |
-
return [rescale_bbox(b) for b in bboxes]
|
915 |
|
916 |
-
|
917 |
|
|
|
918 |
if isinstance(cond, dict):
|
919 |
-
# hybrid case, cond is
|
920 |
pass
|
921 |
else:
|
922 |
if not isinstance(cond, list):
|
@@ -924,98 +823,7 @@ class LatentDiffusion(DDPM):
|
|
924 |
key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn'
|
925 |
cond = {key: cond}
|
926 |
|
927 |
-
|
928 |
-
assert len(cond) == 1 # todo can only deal with one conditioning atm
|
929 |
-
assert not return_ids
|
930 |
-
ks = self.split_input_params["ks"] # eg. (128, 128)
|
931 |
-
stride = self.split_input_params["stride"] # eg. (64, 64)
|
932 |
-
|
933 |
-
h, w = x_noisy.shape[-2:]
|
934 |
-
|
935 |
-
fold, unfold, normalization, weighting = self.get_fold_unfold(x_noisy, ks, stride)
|
936 |
-
|
937 |
-
z = unfold(x_noisy) # (bn, nc * prod(**ks), L)
|
938 |
-
# Reshape to img shape
|
939 |
-
z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
|
940 |
-
z_list = [z[:, :, :, :, i] for i in range(z.shape[-1])]
|
941 |
-
|
942 |
-
if self.cond_stage_key in ["image", "LR_image", "segmentation",
|
943 |
-
'bbox_img'] and self.model.conditioning_key: # todo check for completeness
|
944 |
-
c_key = next(iter(cond.keys())) # get key
|
945 |
-
c = next(iter(cond.values())) # get value
|
946 |
-
assert (len(c) == 1) # todo extend to list with more than one elem
|
947 |
-
c = c[0] # get element
|
948 |
-
|
949 |
-
c = unfold(c)
|
950 |
-
c = c.view((c.shape[0], -1, ks[0], ks[1], c.shape[-1])) # (bn, nc, ks[0], ks[1], L )
|
951 |
-
|
952 |
-
cond_list = [{c_key: [c[:, :, :, :, i]]} for i in range(c.shape[-1])]
|
953 |
-
|
954 |
-
elif self.cond_stage_key == 'coordinates_bbox':
|
955 |
-
assert 'original_image_size' in self.split_input_params, 'BoudingBoxRescaling is missing original_image_size'
|
956 |
-
|
957 |
-
# assuming padding of unfold is always 0 and its dilation is always 1
|
958 |
-
n_patches_per_row = int((w - ks[0]) / stride[0] + 1)
|
959 |
-
full_img_h, full_img_w = self.split_input_params['original_image_size']
|
960 |
-
# as we are operating on latents, we need the factor from the original image size to the
|
961 |
-
# spatial latent size to properly rescale the crops for regenerating the bbox annotations
|
962 |
-
num_downs = self.first_stage_model.encoder.num_resolutions - 1
|
963 |
-
rescale_latent = 2 ** (num_downs)
|
964 |
-
|
965 |
-
# get top left postions of patches as conforming for the bbbox tokenizer, therefore we
|
966 |
-
# need to rescale the tl patch coordinates to be in between (0,1)
|
967 |
-
tl_patch_coordinates = [(rescale_latent * stride[0] * (patch_nr % n_patches_per_row) / full_img_w,
|
968 |
-
rescale_latent * stride[1] * (patch_nr // n_patches_per_row) / full_img_h)
|
969 |
-
for patch_nr in range(z.shape[-1])]
|
970 |
-
|
971 |
-
# patch_limits are tl_coord, width and height coordinates as (x_tl, y_tl, h, w)
|
972 |
-
patch_limits = [(x_tl, y_tl,
|
973 |
-
rescale_latent * ks[0] / full_img_w,
|
974 |
-
rescale_latent * ks[1] / full_img_h) for x_tl, y_tl in tl_patch_coordinates]
|
975 |
-
# patch_values = [(np.arange(x_tl,min(x_tl+ks, 1.)),np.arange(y_tl,min(y_tl+ks, 1.))) for x_tl, y_tl in tl_patch_coordinates]
|
976 |
-
|
977 |
-
# tokenize crop coordinates for the bounding boxes of the respective patches
|
978 |
-
patch_limits_tknzd = [torch.LongTensor(self.bbox_tokenizer._crop_encoder(bbox))[None].to(self.device)
|
979 |
-
for bbox in patch_limits] # list of length l with tensors of shape (1, 2)
|
980 |
-
print(patch_limits_tknzd[0].shape)
|
981 |
-
# cut tknzd crop position from conditioning
|
982 |
-
assert isinstance(cond, dict), 'cond must be dict to be fed into model'
|
983 |
-
cut_cond = cond['c_crossattn'][0][..., :-2].to(self.device)
|
984 |
-
print(cut_cond.shape)
|
985 |
-
|
986 |
-
adapted_cond = torch.stack([torch.cat([cut_cond, p], dim=1) for p in patch_limits_tknzd])
|
987 |
-
adapted_cond = rearrange(adapted_cond, 'l b n -> (l b) n')
|
988 |
-
print(adapted_cond.shape)
|
989 |
-
adapted_cond = self.get_learned_conditioning(adapted_cond)
|
990 |
-
print(adapted_cond.shape)
|
991 |
-
adapted_cond = rearrange(adapted_cond, '(l b) n d -> l b n d', l=z.shape[-1])
|
992 |
-
print(adapted_cond.shape)
|
993 |
-
|
994 |
-
cond_list = [{'c_crossattn': [e]} for e in adapted_cond]
|
995 |
-
|
996 |
-
else:
|
997 |
-
cond_list = [cond for i in range(z.shape[-1])] # Todo make this more efficient
|
998 |
-
|
999 |
-
# apply model by loop over crops
|
1000 |
-
if features_adapter is not None:
|
1001 |
-
output_list = [self.model(z_list[i], t, **cond_list[i], features_adapter=features_adapter) for i in range(z.shape[-1])]
|
1002 |
-
else:
|
1003 |
-
output_list = [self.model(z_list[i], t, **cond_list[i]) for i in range(z.shape[-1])]
|
1004 |
-
assert not isinstance(output_list[0],
|
1005 |
-
tuple) # todo cant deal with multiple model outputs check this never happens
|
1006 |
-
|
1007 |
-
o = torch.stack(output_list, axis=-1)
|
1008 |
-
o = o * weighting
|
1009 |
-
# Reverse reshape to img shape
|
1010 |
-
o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
|
1011 |
-
# stitch crops together
|
1012 |
-
x_recon = fold(o) / normalization
|
1013 |
-
|
1014 |
-
else:
|
1015 |
-
if features_adapter is not None:
|
1016 |
-
x_recon = self.model(x_noisy, t, **cond, features_adapter=features_adapter)
|
1017 |
-
else:
|
1018 |
-
x_recon = self.model(x_noisy, t, **cond)
|
1019 |
|
1020 |
if isinstance(x_recon, tuple) and not return_ids:
|
1021 |
return x_recon[0]
|
@@ -1040,10 +848,10 @@ class LatentDiffusion(DDPM):
|
|
1040 |
kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0)
|
1041 |
return mean_flat(kl_prior) / np.log(2.0)
|
1042 |
|
1043 |
-
def p_losses(self, x_start, cond, t,
|
1044 |
noise = default(noise, lambda: torch.randn_like(x_start))
|
1045 |
x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
|
1046 |
-
model_output = self.apply_model(x_noisy, t, cond,
|
1047 |
|
1048 |
loss_dict = {}
|
1049 |
prefix = 'train' if self.training else 'val'
|
@@ -1052,6 +860,8 @@ class LatentDiffusion(DDPM):
|
|
1052 |
target = x_start
|
1053 |
elif self.parameterization == "eps":
|
1054 |
target = noise
|
|
|
|
|
1055 |
else:
|
1056 |
raise NotImplementedError()
|
1057 |
|
@@ -1247,7 +1057,7 @@ class LatentDiffusion(DDPM):
|
|
1247 |
@torch.no_grad()
|
1248 |
def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None,
|
1249 |
verbose=True, timesteps=None, quantize_denoised=False,
|
1250 |
-
mask=None, x0=None, shape=None
|
1251 |
if shape is None:
|
1252 |
shape = (batch_size, self.channels, self.image_size, self.image_size)
|
1253 |
if cond is not None:
|
@@ -1263,26 +1073,51 @@ class LatentDiffusion(DDPM):
|
|
1263 |
mask=mask, x0=x0)
|
1264 |
|
1265 |
@torch.no_grad()
|
1266 |
-
def sample_log(self,cond,batch_size,ddim, ddim_steps
|
1267 |
-
|
1268 |
if ddim:
|
1269 |
ddim_sampler = DDIMSampler(self)
|
1270 |
shape = (self.channels, self.image_size, self.image_size)
|
1271 |
-
samples, intermediates =ddim_sampler.sample(ddim_steps,batch_size,
|
1272 |
-
|
1273 |
|
1274 |
else:
|
1275 |
samples, intermediates = self.sample(cond=cond, batch_size=batch_size,
|
1276 |
-
return_intermediates=True
|
1277 |
|
1278 |
return samples, intermediates
|
1279 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1280 |
|
1281 |
@torch.no_grad()
|
1282 |
-
def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=
|
1283 |
quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
|
1284 |
-
plot_diffusion_rows=True,
|
1285 |
-
|
|
|
|
|
1286 |
use_ddim = ddim_steps is not None
|
1287 |
|
1288 |
log = dict()
|
@@ -1299,12 +1134,16 @@ class LatentDiffusion(DDPM):
|
|
1299 |
if hasattr(self.cond_stage_model, "decode"):
|
1300 |
xc = self.cond_stage_model.decode(c)
|
1301 |
log["conditioning"] = xc
|
1302 |
-
elif self.cond_stage_key in ["caption"]:
|
1303 |
-
xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[
|
1304 |
log["conditioning"] = xc
|
1305 |
-
elif self.cond_stage_key
|
1306 |
-
|
1307 |
-
|
|
|
|
|
|
|
|
|
1308 |
elif isimage(xc):
|
1309 |
log["conditioning"] = xc
|
1310 |
if ismap(xc):
|
@@ -1330,9 +1169,9 @@ class LatentDiffusion(DDPM):
|
|
1330 |
|
1331 |
if sample:
|
1332 |
# get denoise row
|
1333 |
-
with
|
1334 |
-
samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim,
|
1335 |
-
ddim_steps=ddim_steps,eta=ddim_eta)
|
1336 |
# samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
|
1337 |
x_samples = self.decode_first_stage(samples)
|
1338 |
log["samples"] = x_samples
|
@@ -1343,39 +1182,52 @@ class LatentDiffusion(DDPM):
|
|
1343 |
if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance(
|
1344 |
self.first_stage_model, IdentityFirstStage):
|
1345 |
# also display when quantizing x0 while sampling
|
1346 |
-
with
|
1347 |
-
samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim,
|
1348 |
-
ddim_steps=ddim_steps,eta=ddim_eta,
|
1349 |
quantize_denoised=True)
|
1350 |
# samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True,
|
1351 |
# quantize_denoised=True)
|
1352 |
x_samples = self.decode_first_stage(samples.to(self.device))
|
1353 |
log["samples_x0_quantized"] = x_samples
|
1354 |
|
1355 |
-
|
1356 |
-
|
1357 |
-
|
1358 |
-
|
1359 |
-
|
1360 |
-
|
1361 |
-
|
1362 |
-
|
1363 |
-
|
1364 |
-
|
1365 |
-
|
1366 |
-
|
1367 |
-
|
1368 |
-
|
1369 |
-
|
1370 |
-
|
1371 |
-
|
1372 |
-
|
1373 |
-
|
1374 |
-
|
1375 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1376 |
|
1377 |
if plot_progressive_rows:
|
1378 |
-
with
|
1379 |
img, progressives = self.progressive_denoising(c,
|
1380 |
shape=(self.channels, self.image_size, self.image_size),
|
1381 |
batch_size=N)
|
@@ -1422,25 +1274,40 @@ class LatentDiffusion(DDPM):
|
|
1422 |
x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
|
1423 |
return x
|
1424 |
|
1425 |
-
class Layout2ImgDiffusion(LatentDiffusion):
|
1426 |
-
# TODO: move all layout-specific hacks to this class
|
1427 |
-
def __init__(self, cond_stage_key, *args, **kwargs):
|
1428 |
-
assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"'
|
1429 |
-
super().__init__(cond_stage_key=cond_stage_key, *args, **kwargs)
|
1430 |
-
|
1431 |
-
def log_images(self, batch, N=8, *args, **kwargs):
|
1432 |
-
logs = super().log_images(batch=batch, N=N, *args, **kwargs)
|
1433 |
|
1434 |
-
|
1435 |
-
|
1436 |
-
|
|
|
|
|
|
|
1437 |
|
1438 |
-
|
1439 |
-
|
1440 |
-
|
1441 |
-
|
1442 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1443 |
|
1444 |
-
|
1445 |
-
logs['bbox_image'] = cond_img
|
1446 |
-
return logs
|
|
|
12 |
import pytorch_lightning as pl
|
13 |
from torch.optim.lr_scheduler import LambdaLR
|
14 |
from einops import rearrange, repeat
|
15 |
+
from contextlib import contextmanager, nullcontext
|
16 |
from functools import partial
|
17 |
+
import itertools
|
18 |
from tqdm import tqdm
|
19 |
from torchvision.utils import make_grid
|
20 |
from pytorch_lightning.utilities.distributed import rank_zero_only
|
21 |
+
from omegaconf import ListConfig
|
22 |
|
23 |
from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
|
24 |
from ldm.modules.ema import LitEma
|
25 |
from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution
|
26 |
+
from ldm.models.autoencoder import IdentityFirstStage, AutoencoderKL
|
27 |
from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
|
28 |
from ldm.models.diffusion.ddim import DDIMSampler
|
29 |
|
|
|
73 |
use_positional_encodings=False,
|
74 |
learn_logvar=False,
|
75 |
logvar_init=0.,
|
76 |
+
make_it_fit=False,
|
77 |
+
ucg_training=None,
|
78 |
+
reset_ema=False,
|
79 |
+
reset_num_ema_updates=False,
|
80 |
):
|
81 |
super().__init__()
|
82 |
+
assert parameterization in ["eps", "x0", "v"], 'currently only supporting "eps" and "x0" and "v"'
|
83 |
self.parameterization = parameterization
|
84 |
print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode")
|
85 |
self.cond_stage_model = None
|
|
|
106 |
|
107 |
if monitor is not None:
|
108 |
self.monitor = monitor
|
109 |
+
self.make_it_fit = make_it_fit
|
110 |
+
if reset_ema: assert exists(ckpt_path)
|
111 |
if ckpt_path is not None:
|
112 |
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)
|
113 |
+
if reset_ema:
|
114 |
+
assert self.use_ema
|
115 |
+
print(f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint.")
|
116 |
+
self.model_ema = LitEma(self.model)
|
117 |
+
if reset_num_ema_updates:
|
118 |
+
print(" +++++++++++ WARNING: RESETTING NUM_EMA UPDATES TO ZERO +++++++++++ ")
|
119 |
+
assert self.use_ema
|
120 |
+
self.model_ema.reset_num_updates()
|
121 |
|
122 |
self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps,
|
123 |
linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
|
|
|
129 |
if self.learn_logvar:
|
130 |
self.logvar = nn.Parameter(self.logvar, requires_grad=True)
|
131 |
|
132 |
+
self.ucg_training = ucg_training or dict()
|
133 |
+
if self.ucg_training:
|
134 |
+
self.ucg_prng = np.random.RandomState()
|
135 |
|
136 |
def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
|
137 |
linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
|
|
|
165 |
|
166 |
# calculations for posterior q(x_{t-1} | x_t, x_0)
|
167 |
posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / (
|
168 |
+
1. - alphas_cumprod) + self.v_posterior * betas
|
169 |
# above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
|
170 |
self.register_buffer('posterior_variance', to_torch(posterior_variance))
|
171 |
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
|
|
|
177 |
|
178 |
if self.parameterization == "eps":
|
179 |
lvlb_weights = self.betas ** 2 / (
|
180 |
+
2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))
|
181 |
elif self.parameterization == "x0":
|
182 |
lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod))
|
183 |
+
elif self.parameterization == "v":
|
184 |
+
lvlb_weights = torch.ones_like(self.betas ** 2 / (
|
185 |
+
2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod)))
|
186 |
else:
|
187 |
raise NotImplementedError("mu not supported")
|
|
|
188 |
lvlb_weights[0] = lvlb_weights[1]
|
189 |
self.register_buffer('lvlb_weights', lvlb_weights, persistent=False)
|
190 |
assert not torch.isnan(self.lvlb_weights).all()
|
|
|
204 |
if context is not None:
|
205 |
print(f"{context}: Restored training weights")
|
206 |
|
207 |
+
@torch.no_grad()
|
208 |
def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
|
209 |
sd = torch.load(path, map_location="cpu")
|
210 |
if "state_dict" in list(sd.keys()):
|
|
|
215 |
if k.startswith(ik):
|
216 |
print("Deleting key {} from state_dict.".format(k))
|
217 |
del sd[k]
|
218 |
+
if self.make_it_fit:
|
219 |
+
n_params = len([name for name, _ in
|
220 |
+
itertools.chain(self.named_parameters(),
|
221 |
+
self.named_buffers())])
|
222 |
+
for name, param in tqdm(
|
223 |
+
itertools.chain(self.named_parameters(),
|
224 |
+
self.named_buffers()),
|
225 |
+
desc="Fitting old weights to new weights",
|
226 |
+
total=n_params
|
227 |
+
):
|
228 |
+
if not name in sd:
|
229 |
+
continue
|
230 |
+
old_shape = sd[name].shape
|
231 |
+
new_shape = param.shape
|
232 |
+
assert len(old_shape) == len(new_shape)
|
233 |
+
if len(new_shape) > 2:
|
234 |
+
# we only modify first two axes
|
235 |
+
assert new_shape[2:] == old_shape[2:]
|
236 |
+
# assumes first axis corresponds to output dim
|
237 |
+
if not new_shape == old_shape:
|
238 |
+
new_param = param.clone()
|
239 |
+
old_param = sd[name]
|
240 |
+
if len(new_shape) == 1:
|
241 |
+
for i in range(new_param.shape[0]):
|
242 |
+
new_param[i] = old_param[i % old_shape[0]]
|
243 |
+
elif len(new_shape) >= 2:
|
244 |
+
for i in range(new_param.shape[0]):
|
245 |
+
for j in range(new_param.shape[1]):
|
246 |
+
new_param[i, j] = old_param[i % old_shape[0], j % old_shape[1]]
|
247 |
+
|
248 |
+
n_used_old = torch.ones(old_shape[1])
|
249 |
+
for j in range(new_param.shape[1]):
|
250 |
+
n_used_old[j % old_shape[1]] += 1
|
251 |
+
n_used_new = torch.zeros(new_shape[1])
|
252 |
+
for j in range(new_param.shape[1]):
|
253 |
+
n_used_new[j] = n_used_old[j % old_shape[1]]
|
254 |
+
|
255 |
+
n_used_new = n_used_new[None, :]
|
256 |
+
while len(n_used_new.shape) < len(new_shape):
|
257 |
+
n_used_new = n_used_new.unsqueeze(-1)
|
258 |
+
new_param /= n_used_new
|
259 |
+
|
260 |
+
sd[name] = new_param
|
261 |
+
|
262 |
missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
|
263 |
sd, strict=False)
|
264 |
print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
|
265 |
if len(missing) > 0:
|
266 |
+
print(f"Missing Keys:\n {missing}")
|
267 |
if len(unexpected) > 0:
|
268 |
+
print(f"\nUnexpected Keys:\n {unexpected}")
|
269 |
|
270 |
def q_mean_variance(self, x_start, t):
|
271 |
"""
|
|
|
285 |
extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
|
286 |
)
|
287 |
|
288 |
+
def predict_start_from_z_and_v(self, x_t, t, v):
    """Predict the start sample from a noisy sample ``x_t`` and a v-prediction ``v``.

    Returns ``sqrt_alphas_cumprod[t] * x_t - sqrt_one_minus_alphas_cumprod[t] * v``,
    where both schedule buffers are gathered at timestep ``t`` with
    ``extract_into_tensor`` using the shape of ``x_t``.
    """
    # The two buffers are registered elsewhere as
    #   sqrt_alphas_cumprod           = sqrt(alphas_cumprod)
    #   sqrt_one_minus_alphas_cumprod = sqrt(1. - alphas_cumprod)
    return (
        extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t -
        extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
    )
|
295 |
+
|
296 |
+
def predict_eps_from_z_and_v(self, x_t, t, v):
    """Predict the noise (eps) from a noisy sample ``x_t`` and a v-prediction ``v``.

    Returns ``sqrt_alphas_cumprod[t] * v + sqrt_one_minus_alphas_cumprod[t] * x_t``,
    with both schedule buffers gathered at timestep ``t`` via
    ``extract_into_tensor`` using the shape of ``x_t``.
    """
    return (
        extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * v +
        extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * x_t
    )
|
301 |
+
|
302 |
def q_posterior(self, x_start, x_t, t):
|
303 |
posterior_mean = (
|
304 |
extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start +
|
|
|
356 |
return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
|
357 |
extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
|
358 |
|
359 |
+
def get_v(self, x, noise, t):
    """Compute the v target used when ``parameterization == "v"``.

    Returns ``sqrt_alphas_cumprod[t] * noise - sqrt_one_minus_alphas_cumprod[t] * x``,
    with both schedule buffers gathered at timestep ``t`` via
    ``extract_into_tensor`` using the shape of ``x``.
    """
    return (
        extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise -
        extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x
    )
|
364 |
+
|
365 |
def get_loss(self, pred, target, mean=True):
|
366 |
if self.loss_type == 'l1':
|
367 |
loss = (target - pred).abs()
|
|
|
387 |
target = noise
|
388 |
elif self.parameterization == "x0":
|
389 |
target = x_start
|
390 |
+
elif self.parameterization == "v":
|
391 |
+
target = self.get_v(x_start, noise, t)
|
392 |
else:
|
393 |
raise NotImplementedError(f"Parameterization {self.parameterization} not yet supported")
|
394 |
|
|
|
416 |
|
417 |
def get_input(self, batch, k):
    """Return ``batch[k]`` exactly as stored.

    No layout or dtype conversion is applied here: the trailing-channel
    append, the 'b h w c -> b c h w' rearrange and the contiguous float cast
    are disabled (they were left commented out), so the caller receives the
    batch entry untouched.
    """
    return batch[k]
|
424 |
|
425 |
def shared_step(self, batch):
|
|
|
509 |
return opt
|
510 |
|
511 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
512 |
class LatentDiffusion(DDPM):
|
513 |
"""main class"""
|
514 |
+
|
515 |
def __init__(self,
|
516 |
first_stage_config,
|
517 |
cond_stage_config,
|
|
|
518 |
num_timesteps_cond=None,
|
519 |
cond_stage_key="image",
|
520 |
cond_stage_trainable=False,
|
|
|
533 |
if cond_stage_config == '__is_unconditional__':
|
534 |
conditioning_key = None
|
535 |
ckpt_path = kwargs.pop("ckpt_path", None)
|
536 |
+
reset_ema = kwargs.pop("reset_ema", False)
|
537 |
+
reset_num_ema_updates = kwargs.pop("reset_num_ema_updates", False)
|
538 |
ignore_keys = kwargs.pop("ignore_keys", [])
|
539 |
+
super().__init__(conditioning_key=conditioning_key, *args, **kwargs)
|
|
|
540 |
self.concat_mode = concat_mode
|
541 |
self.cond_stage_trainable = cond_stage_trainable
|
542 |
self.cond_stage_key = cond_stage_key
|
|
|
552 |
self.instantiate_cond_stage(cond_stage_config)
|
553 |
self.cond_stage_forward = cond_stage_forward
|
554 |
self.clip_denoised = False
|
555 |
+
self.bbox_tokenizer = None
|
556 |
|
557 |
self.restarted_from_ckpt = False
|
558 |
if ckpt_path is not None:
|
559 |
self.init_from_ckpt(ckpt_path, ignore_keys)
|
560 |
self.restarted_from_ckpt = True
|
561 |
+
if reset_ema:
|
562 |
+
assert self.use_ema
|
563 |
+
print(
|
564 |
+
f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint.")
|
565 |
+
self.model_ema = LitEma(self.model)
|
566 |
+
if reset_num_ema_updates:
|
567 |
+
print(" +++++++++++ WARNING: RESETTING NUM_EMA UPDATES TO ZERO +++++++++++ ")
|
568 |
+
assert self.use_ema
|
569 |
+
self.model_ema.reset_num_updates()
|
570 |
|
571 |
def make_cond_schedule(self, ):
    """Build ``self.cond_ids``, a long tensor of length ``num_timesteps``.

    Every entry starts as ``num_timesteps - 1``; the first
    ``num_timesteps_cond`` entries are then overwritten with evenly spaced,
    rounded indices running from 0 to ``num_timesteps - 1``.
    """
    self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long)
    ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long()
    self.cond_ids[:self.num_timesteps_cond] = ids
|
575 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
576 |
def register_schedule(self,
|
577 |
given_betas=None, beta_schedule="linear", timesteps=1000,
|
578 |
linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
|
|
|
614 |
denoise_row = []
|
615 |
for zd in tqdm(samples, desc=desc):
|
616 |
denoise_row.append(self.decode_first_stage(zd.to(self.device),
|
617 |
+
force_not_quantize=force_no_decoder_quantization))
|
618 |
n_imgs_per_row = len(denoise_row)
|
619 |
denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W
|
620 |
denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w')
|
|
|
747 |
if cond_key is None:
|
748 |
cond_key = self.cond_stage_key
|
749 |
if cond_key != self.first_stage_key:
|
750 |
+
if cond_key in ['caption', 'coordinates_bbox', "txt"]:
|
751 |
xc = batch[cond_key]
|
752 |
+
elif cond_key in ['class_label', 'cls']:
|
753 |
xc = batch
|
754 |
else:
|
755 |
xc = super().get_input(batch, cond_key).to(self.device)
|
|
|
794 |
z = rearrange(z, 'b h w c -> b c h w').contiguous()
|
795 |
|
796 |
z = 1. / self.scale_factor * z
|
797 |
+
return self.first_stage_model.decode(z)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
798 |
|
799 |
@torch.no_grad()
def encode_first_stage(self, x):
    """Encode ``x`` with the first-stage model, with gradient tracking disabled.

    Returns whatever ``self.first_stage_model.encode(x)`` returns; no further
    processing is applied here.
    """
    return self.first_stage_model.encode(x)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
802 |
|
803 |
def shared_step(self, batch, **kwargs):
    """Fetch ``(x, c)`` for ``self.first_stage_key`` from ``batch`` and return ``self(x, c, **kwargs)``.

    ``self.get_input`` is expected to return a pair; any extra keyword
    arguments are passed through unchanged to the module call.
    """
    x, c = self.get_input(batch, self.first_stage_key)
    loss = self(x, c, **kwargs)
    return loss
|
807 |
|
808 |
+
def forward(self, x, c, *args, **kwargs):
    """Pick a timestep per batch element and return ``self.p_losses(x, c, t, ...)``.

    If the caller passes ``t`` as a keyword argument it is removed from
    ``kwargs`` and used as-is; otherwise one integer timestep in
    ``[0, self.num_timesteps)`` is drawn uniformly for each of the
    ``x.shape[0]`` elements on ``self.device``. Remaining positional and
    keyword arguments are forwarded to ``p_losses``.
    """
    if 't' not in kwargs:
        t = torch.randint(0, self.num_timesteps, (x.shape[0], ), device=self.device).long()
    else:
        # pop so that 't' is not passed twice to p_losses
        t = kwargs.pop('t')

    return self.p_losses(x, c, t, *args, **kwargs)
|
815 |
|
816 |
+
def apply_model(self, x_noisy, t, cond, return_ids=False, **kwargs):
|
817 |
if isinstance(cond, dict):
|
818 |
+
# hybrid case, cond is expected to be a dict
|
819 |
pass
|
820 |
else:
|
821 |
if not isinstance(cond, list):
|
|
|
823 |
key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn'
|
824 |
cond = {key: cond}
|
825 |
|
826 |
+
x_recon = self.model(x_noisy, t, **cond, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
827 |
|
828 |
if isinstance(x_recon, tuple) and not return_ids:
|
829 |
return x_recon[0]
|
|
|
848 |
kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0)
|
849 |
return mean_flat(kl_prior) / np.log(2.0)
|
850 |
|
851 |
+
def p_losses(self, x_start, cond, t, noise=None, **kwargs):
|
852 |
noise = default(noise, lambda: torch.randn_like(x_start))
|
853 |
x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
|
854 |
+
model_output = self.apply_model(x_noisy, t, cond, **kwargs)
|
855 |
|
856 |
loss_dict = {}
|
857 |
prefix = 'train' if self.training else 'val'
|
|
|
860 |
target = x_start
|
861 |
elif self.parameterization == "eps":
|
862 |
target = noise
|
863 |
+
elif self.parameterization == "v":
|
864 |
+
target = self.get_v(x_start, noise, t)
|
865 |
else:
|
866 |
raise NotImplementedError()
|
867 |
|
|
|
1057 |
@torch.no_grad()
|
1058 |
def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None,
|
1059 |
verbose=True, timesteps=None, quantize_denoised=False,
|
1060 |
+
mask=None, x0=None, shape=None, **kwargs):
|
1061 |
if shape is None:
|
1062 |
shape = (batch_size, self.channels, self.image_size, self.image_size)
|
1063 |
if cond is not None:
|
|
|
1073 |
mask=mask, x0=x0)
|
1074 |
|
1075 |
@torch.no_grad()
def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs):
    """Draw samples for logging and return ``(samples, intermediates)``.

    When ``ddim`` is truthy a ``DDIMSampler`` built on this model is run for
    ``ddim_steps`` steps with a per-sample shape of
    ``(channels, image_size, image_size)``; otherwise ``self.sample`` is
    called with ``return_intermediates=True`` (``ddim_steps`` is not used on
    that path). Extra keyword arguments are forwarded to whichever sampler
    is used. Runs with gradient tracking disabled.
    """
    if ddim:
        ddim_sampler = DDIMSampler(self)
        shape = (self.channels, self.image_size, self.image_size)
        samples, intermediates = ddim_sampler.sample(ddim_steps, batch_size,
                                                     shape, cond, verbose=False, **kwargs)

    else:
        samples, intermediates = self.sample(cond=cond, batch_size=batch_size,
                                             return_intermediates=True, **kwargs)

    return samples, intermediates
|
1088 |
|
1089 |
+
@torch.no_grad()
def get_unconditional_conditioning(self, batch_size, null_label=None):
    """Build the unconditional conditioning used for classifier-free guidance.

    With ``null_label`` given, it is encoded via
    ``self.get_learned_conditioning`` (an omegaconf ``ListConfig`` is first
    turned into a plain list; objects with a ``.to`` method are moved to
    ``self.device`` first) and the result is repeated along the leading axis
    to ``batch_size``. The repeat pattern ``'1 ... -> b ...'`` requires that
    leading axis to have size 1.

    With ``null_label`` omitted, only the ``"class_label"`` / ``"cls"``
    conditioning keys are supported: the conditioning model supplies its own
    unconditional input for ``batch_size`` and the encoded result is returned
    directly, without the repeat step.

    Raises:
        NotImplementedError: if ``null_label`` is None and
            ``self.cond_stage_key`` is not ``"class_label"`` or ``"cls"``.
    """
    if null_label is not None:
        xc = null_label
        if isinstance(xc, ListConfig):
            xc = list(xc)
        if isinstance(xc, (dict, list)):
            c = self.get_learned_conditioning(xc)
        else:
            if hasattr(xc, "to"):
                xc = xc.to(self.device)
            c = self.get_learned_conditioning(xc)
    else:
        if self.cond_stage_key in ["class_label", "cls"]:
            xc = self.cond_stage_model.get_unconditional_conditioning(batch_size, device=self.device)
            # already batch-sized, so skip the repeat below
            return self.get_learned_conditioning(xc)
        else:
            raise NotImplementedError("todo")
    if isinstance(c, list):  # in case the encoder gives us a list
        # overwrite entries in place, as the original loop did
        for i, entry in enumerate(c):
            c[i] = repeat(entry, '1 ... -> b ...', b=batch_size).to(self.device)
    else:
        c = repeat(c, '1 ... -> b ...', b=batch_size).to(self.device)
    return c
|
1113 |
|
1114 |
@torch.no_grad()
|
1115 |
+
def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=50, ddim_eta=0., return_keys=None,
|
1116 |
quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
|
1117 |
+
plot_diffusion_rows=True, unconditional_guidance_scale=1., unconditional_guidance_label=None,
|
1118 |
+
use_ema_scope=True,
|
1119 |
+
**kwargs):
|
1120 |
+
ema_scope = self.ema_scope if use_ema_scope else nullcontext
|
1121 |
use_ddim = ddim_steps is not None
|
1122 |
|
1123 |
log = dict()
|
|
|
1134 |
if hasattr(self.cond_stage_model, "decode"):
|
1135 |
xc = self.cond_stage_model.decode(c)
|
1136 |
log["conditioning"] = xc
|
1137 |
+
elif self.cond_stage_key in ["caption", "txt"]:
|
1138 |
+
xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25)
|
1139 |
log["conditioning"] = xc
|
1140 |
+
elif self.cond_stage_key in ['class_label', "cls"]:
|
1141 |
+
try:
|
1142 |
+
xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25)
|
1143 |
+
log['conditioning'] = xc
|
1144 |
+
except KeyError:
|
1145 |
+
# probably no "human_label" in batch
|
1146 |
+
pass
|
1147 |
elif isimage(xc):
|
1148 |
log["conditioning"] = xc
|
1149 |
if ismap(xc):
|
|
|
1169 |
|
1170 |
if sample:
|
1171 |
# get denoise row
|
1172 |
+
with ema_scope("Sampling"):
|
1173 |
+
samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
|
1174 |
+
ddim_steps=ddim_steps, eta=ddim_eta)
|
1175 |
# samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
|
1176 |
x_samples = self.decode_first_stage(samples)
|
1177 |
log["samples"] = x_samples
|
|
|
1182 |
if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance(
|
1183 |
self.first_stage_model, IdentityFirstStage):
|
1184 |
# also display when quantizing x0 while sampling
|
1185 |
+
with ema_scope("Plotting Quantized Denoised"):
|
1186 |
+
samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
|
1187 |
+
ddim_steps=ddim_steps, eta=ddim_eta,
|
1188 |
quantize_denoised=True)
|
1189 |
# samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True,
|
1190 |
# quantize_denoised=True)
|
1191 |
x_samples = self.decode_first_stage(samples.to(self.device))
|
1192 |
log["samples_x0_quantized"] = x_samples
|
1193 |
|
1194 |
+
if unconditional_guidance_scale > 1.0:
|
1195 |
+
uc = self.get_unconditional_conditioning(N, unconditional_guidance_label)
|
1196 |
+
if self.model.conditioning_key == "crossattn-adm":
|
1197 |
+
uc = {"c_crossattn": [uc], "c_adm": c["c_adm"]}
|
1198 |
+
with ema_scope("Sampling with classifier-free guidance"):
|
1199 |
+
samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
|
1200 |
+
ddim_steps=ddim_steps, eta=ddim_eta,
|
1201 |
+
unconditional_guidance_scale=unconditional_guidance_scale,
|
1202 |
+
unconditional_conditioning=uc,
|
1203 |
+
)
|
1204 |
+
x_samples_cfg = self.decode_first_stage(samples_cfg)
|
1205 |
+
log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
|
1206 |
+
|
1207 |
+
if inpaint:
|
1208 |
+
# make a simple center square
|
1209 |
+
b, h, w = z.shape[0], z.shape[2], z.shape[3]
|
1210 |
+
mask = torch.ones(N, h, w).to(self.device)
|
1211 |
+
# zeros will be filled in
|
1212 |
+
mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
|
1213 |
+
mask = mask[:, None, ...]
|
1214 |
+
with ema_scope("Plotting Inpaint"):
|
1215 |
+
samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, eta=ddim_eta,
|
1216 |
+
ddim_steps=ddim_steps, x0=z[:N], mask=mask)
|
1217 |
+
x_samples = self.decode_first_stage(samples.to(self.device))
|
1218 |
+
log["samples_inpainting"] = x_samples
|
1219 |
+
log["mask"] = mask
|
1220 |
+
|
1221 |
+
# outpaint
|
1222 |
+
mask = 1. - mask
|
1223 |
+
with ema_scope("Plotting Outpaint"):
|
1224 |
+
samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, eta=ddim_eta,
|
1225 |
+
ddim_steps=ddim_steps, x0=z[:N], mask=mask)
|
1226 |
+
x_samples = self.decode_first_stage(samples.to(self.device))
|
1227 |
+
log["samples_outpainting"] = x_samples
|
1228 |
|
1229 |
if plot_progressive_rows:
|
1230 |
+
with ema_scope("Plotting Progressives"):
|
1231 |
img, progressives = self.progressive_denoising(c,
|
1232 |
shape=(self.channels, self.image_size, self.image_size),
|
1233 |
batch_size=N)
|
|
|
1274 |
x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
|
1275 |
return x
|
1276 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1277 |
|
1278 |
+
class DiffusionWrapper(pl.LightningModule):
    """Wrap the diffusion network and route conditioning inputs to it.

    ``conditioning_key`` selects how ``c_concat`` / ``c_crossattn`` / ``c_adm``
    are passed to ``self.diffusion_model`` in ``forward``.
    """

    def __init__(self, diff_model_config, conditioning_key):
        """Instantiate the wrapped model from ``diff_model_config`` and validate ``conditioning_key``."""
        super().__init__()
        self.diffusion_model = instantiate_from_config(diff_model_config)
        self.conditioning_key = conditioning_key
        assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm', 'hybrid-adm', 'crossattn-adm']

    def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None, **kwargs):
        """Call the wrapped model with conditioning arranged per ``self.conditioning_key``.

        - ``None``: no conditioning is passed.
        - ``'concat'``: ``c_concat`` entries are concatenated to ``x`` along dim 1.
        - ``'crossattn'``: ``c_crossattn`` entries are concatenated along dim 1
          and passed as ``context``.
        - ``'hybrid'``: both of the above.
        - ``'hybrid-adm'``: as ``'hybrid'``, plus ``c_adm`` passed as ``y``.
        - ``'crossattn-adm'``: as ``'crossattn'``, plus ``c_adm`` passed as ``y``.
        - ``'adm'``: the first ``c_crossattn`` entry is passed as ``y``.

        Extra keyword arguments are forwarded to the wrapped model.

        Raises:
            NotImplementedError: for any other ``conditioning_key``.
        """
        if self.conditioning_key is None:
            out = self.diffusion_model(x, t, **kwargs)
        elif self.conditioning_key == 'concat':
            xc = torch.cat([x] + c_concat, dim=1)
            out = self.diffusion_model(xc, t, **kwargs)
        elif self.conditioning_key == 'crossattn':
            cc = torch.cat(c_crossattn, 1)
            out = self.diffusion_model(x, t, context=cc, **kwargs)
        elif self.conditioning_key == 'hybrid':
            xc = torch.cat([x] + c_concat, dim=1)
            cc = torch.cat(c_crossattn, 1)
            out = self.diffusion_model(xc, t, context=cc, **kwargs)
        elif self.conditioning_key == 'hybrid-adm':
            assert c_adm is not None
            xc = torch.cat([x] + c_concat, dim=1)
            cc = torch.cat(c_crossattn, 1)
            out = self.diffusion_model(xc, t, context=cc, y=c_adm, **kwargs)
        elif self.conditioning_key == 'crossattn-adm':
            assert c_adm is not None
            cc = torch.cat(c_crossattn, 1)
            out = self.diffusion_model(x, t, context=cc, y=c_adm, **kwargs)
        elif self.conditioning_key == 'adm':
            # the 'adm' conditioning arrives through c_crossattn, not c_adm
            cc = c_crossattn[0]
            out = self.diffusion_model(x, t, y=cc, **kwargs)
        else:
            raise NotImplementedError()

        return out
|
|
|
|
ldm/models/diffusion/dpm_solver/dpm_solver.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1 |
import torch
|
2 |
import torch.nn.functional as F
|
3 |
import math
|
|
|
4 |
|
5 |
|
6 |
class NoiseScheduleVP:
|
@@ -11,7 +12,7 @@ class NoiseScheduleVP:
|
|
11 |
alphas_cumprod=None,
|
12 |
continuous_beta_0=0.1,
|
13 |
continuous_beta_1=20.,
|
14 |
-
|
15 |
"""Create a wrapper class for the forward SDE (VP type).
|
16 |
|
17 |
***
|
@@ -93,7 +94,9 @@ class NoiseScheduleVP:
|
|
93 |
"""
|
94 |
|
95 |
if schedule not in ['discrete', 'linear', 'cosine']:
|
96 |
-
raise ValueError(
|
|
|
|
|
97 |
|
98 |
self.schedule = schedule
|
99 |
if schedule == 'discrete':
|
@@ -112,7 +115,8 @@ class NoiseScheduleVP:
|
|
112 |
self.beta_1 = continuous_beta_1
|
113 |
self.cosine_s = 0.008
|
114 |
self.cosine_beta_max = 999.
|
115 |
-
self.cosine_t_max = math.atan(self.cosine_beta_max * (1. + self.cosine_s) / math.pi) * 2. * (
|
|
|
116 |
self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1. + self.cosine_s) * math.pi / 2.))
|
117 |
self.schedule = schedule
|
118 |
if schedule == 'cosine':
|
@@ -127,12 +131,13 @@ class NoiseScheduleVP:
|
|
127 |
Compute log(alpha_t) of a given continuous-time label t in [0, T].
|
128 |
"""
|
129 |
if self.schedule == 'discrete':
|
130 |
-
return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device),
|
|
|
131 |
elif self.schedule == 'linear':
|
132 |
return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
|
133 |
elif self.schedule == 'cosine':
|
134 |
log_alpha_fn = lambda s: torch.log(torch.cos((s + self.cosine_s) / (1. + self.cosine_s) * math.pi / 2.))
|
135 |
-
log_alpha_t =
|
136 |
return log_alpha_t
|
137 |
|
138 |
def marginal_alpha(self, t):
|
@@ -161,30 +166,32 @@ class NoiseScheduleVP:
|
|
161 |
"""
|
162 |
if self.schedule == 'linear':
|
163 |
tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
|
164 |
-
Delta = self.beta_0**2 + tmp
|
165 |
return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0)
|
166 |
elif self.schedule == 'discrete':
|
167 |
log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2. * lamb)
|
168 |
-
t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]),
|
|
|
169 |
return t.reshape((-1,))
|
170 |
else:
|
171 |
log_alpha = -0.5 * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
|
172 |
-
t_fn = lambda log_alpha_t: torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)) * 2. * (
|
|
|
173 |
t = t_fn(log_alpha)
|
174 |
return t
|
175 |
|
176 |
|
177 |
def model_wrapper(
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
):
|
189 |
"""Create a wrapper function for the noise prediction model.
|
190 |
|
@@ -392,7 +399,7 @@ class DPM_Solver:
|
|
392 |
alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t)
|
393 |
x0 = (x - expand_dims(sigma_t, dims) * noise) / expand_dims(alpha_t, dims)
|
394 |
if self.thresholding:
|
395 |
-
p = 0.995
|
396 |
s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
|
397 |
s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims)
|
398 |
x0 = torch.clamp(x0, -s, s) / s
|
@@ -431,10 +438,11 @@ class DPM_Solver:
|
|
431 |
return torch.linspace(t_T, t_0, N + 1).to(device)
|
432 |
elif skip_type == 'time_quadratic':
|
433 |
t_order = 2
|
434 |
-
t = torch.linspace(t_T**(1. / t_order), t_0**(1. / t_order), N + 1).pow(t_order).to(device)
|
435 |
return t
|
436 |
else:
|
437 |
-
raise ValueError(
|
|
|
438 |
|
439 |
def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device):
|
440 |
"""
|
@@ -471,28 +479,29 @@ class DPM_Solver:
|
|
471 |
if order == 3:
|
472 |
K = steps // 3 + 1
|
473 |
if steps % 3 == 0:
|
474 |
-
orders = [3,] * (K - 2) + [2, 1]
|
475 |
elif steps % 3 == 1:
|
476 |
-
orders = [3,] * (K - 1) + [1]
|
477 |
else:
|
478 |
-
orders = [3,] * (K - 1) + [2]
|
479 |
elif order == 2:
|
480 |
if steps % 2 == 0:
|
481 |
K = steps // 2
|
482 |
-
orders = [2,] * K
|
483 |
else:
|
484 |
K = steps // 2 + 1
|
485 |
-
orders = [2,] * (K - 1) + [1]
|
486 |
elif order == 1:
|
487 |
K = 1
|
488 |
-
orders = [1,] * steps
|
489 |
else:
|
490 |
raise ValueError("'order' must be '1' or '2' or '3'.")
|
491 |
if skip_type == 'logSNR':
|
492 |
# To reproduce the results in DPM-Solver paper
|
493 |
timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device)
|
494 |
else:
|
495 |
-
timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[
|
|
|
496 |
return timesteps_outer, orders
|
497 |
|
498 |
def denoise_to_zero_fn(self, x, s):
|
@@ -528,8 +537,8 @@ class DPM_Solver:
|
|
528 |
if model_s is None:
|
529 |
model_s = self.model_fn(x, s)
|
530 |
x_t = (
|
531 |
-
|
532 |
-
|
533 |
)
|
534 |
if return_intermediate:
|
535 |
return x_t, {'model_s': model_s}
|
@@ -540,15 +549,16 @@ class DPM_Solver:
|
|
540 |
if model_s is None:
|
541 |
model_s = self.model_fn(x, s)
|
542 |
x_t = (
|
543 |
-
|
544 |
-
|
545 |
)
|
546 |
if return_intermediate:
|
547 |
return x_t, {'model_s': model_s}
|
548 |
else:
|
549 |
return x_t
|
550 |
|
551 |
-
def singlestep_dpm_solver_second_update(self, x, s, t, r1=0.5, model_s=None, return_intermediate=False,
|
|
|
552 |
"""
|
553 |
Singlestep solver DPM-Solver-2 from time `s` to time `t`.
|
554 |
|
@@ -575,7 +585,8 @@ class DPM_Solver:
|
|
575 |
h = lambda_t - lambda_s
|
576 |
lambda_s1 = lambda_s + r1 * h
|
577 |
s1 = ns.inverse_lambda(lambda_s1)
|
578 |
-
log_alpha_s, log_alpha_s1, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(
|
|
|
579 |
sigma_s, sigma_s1, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(t)
|
580 |
alpha_s1, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_t)
|
581 |
|
@@ -586,21 +597,22 @@ class DPM_Solver:
|
|
586 |
if model_s is None:
|
587 |
model_s = self.model_fn(x, s)
|
588 |
x_s1 = (
|
589 |
-
|
590 |
-
|
591 |
)
|
592 |
model_s1 = self.model_fn(x_s1, s1)
|
593 |
if solver_type == 'dpm_solver':
|
594 |
x_t = (
|
595 |
-
|
596 |
-
|
597 |
-
|
598 |
)
|
599 |
elif solver_type == 'taylor':
|
600 |
x_t = (
|
601 |
-
|
602 |
-
|
603 |
-
|
|
|
604 |
)
|
605 |
else:
|
606 |
phi_11 = torch.expm1(r1 * h)
|
@@ -609,28 +621,29 @@ class DPM_Solver:
|
|
609 |
if model_s is None:
|
610 |
model_s = self.model_fn(x, s)
|
611 |
x_s1 = (
|
612 |
-
|
613 |
-
|
614 |
)
|
615 |
model_s1 = self.model_fn(x_s1, s1)
|
616 |
if solver_type == 'dpm_solver':
|
617 |
x_t = (
|
618 |
-
|
619 |
-
|
620 |
-
|
621 |
)
|
622 |
elif solver_type == 'taylor':
|
623 |
x_t = (
|
624 |
-
|
625 |
-
|
626 |
-
|
627 |
)
|
628 |
if return_intermediate:
|
629 |
return x_t, {'model_s': model_s, 'model_s1': model_s1}
|
630 |
else:
|
631 |
return x_t
|
632 |
|
633 |
-
def singlestep_dpm_solver_third_update(self, x, s, t, r1=1
|
|
|
634 |
"""
|
635 |
Singlestep solver DPM-Solver-3 from time `s` to time `t`.
|
636 |
|
@@ -664,8 +677,10 @@ class DPM_Solver:
|
|
664 |
lambda_s2 = lambda_s + r2 * h
|
665 |
s1 = ns.inverse_lambda(lambda_s1)
|
666 |
s2 = ns.inverse_lambda(lambda_s2)
|
667 |
-
log_alpha_s, log_alpha_s1, log_alpha_s2, log_alpha_t = ns.marginal_log_mean_coeff(
|
668 |
-
|
|
|
|
|
669 |
alpha_s1, alpha_s2, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_s2), torch.exp(log_alpha_t)
|
670 |
|
671 |
if self.predict_x0:
|
@@ -680,21 +695,21 @@ class DPM_Solver:
|
|
680 |
model_s = self.model_fn(x, s)
|
681 |
if model_s1 is None:
|
682 |
x_s1 = (
|
683 |
-
|
684 |
-
|
685 |
)
|
686 |
model_s1 = self.model_fn(x_s1, s1)
|
687 |
x_s2 = (
|
688 |
-
|
689 |
-
|
690 |
-
|
691 |
)
|
692 |
model_s2 = self.model_fn(x_s2, s2)
|
693 |
if solver_type == 'dpm_solver':
|
694 |
x_t = (
|
695 |
-
|
696 |
-
|
697 |
-
|
698 |
)
|
699 |
elif solver_type == 'taylor':
|
700 |
D1_0 = (1. / r1) * (model_s1 - model_s)
|
@@ -702,10 +717,10 @@ class DPM_Solver:
|
|
702 |
D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
|
703 |
D2 = 2. * (D1_1 - D1_0) / (r2 - r1)
|
704 |
x_t = (
|
705 |
-
|
706 |
-
|
707 |
-
|
708 |
-
|
709 |
)
|
710 |
else:
|
711 |
phi_11 = torch.expm1(r1 * h)
|
@@ -719,21 +734,21 @@ class DPM_Solver:
|
|
719 |
model_s = self.model_fn(x, s)
|
720 |
if model_s1 is None:
|
721 |
x_s1 = (
|
722 |
-
|
723 |
-
|
724 |
)
|
725 |
model_s1 = self.model_fn(x_s1, s1)
|
726 |
x_s2 = (
|
727 |
-
|
728 |
-
|
729 |
-
|
730 |
)
|
731 |
model_s2 = self.model_fn(x_s2, s2)
|
732 |
if solver_type == 'dpm_solver':
|
733 |
x_t = (
|
734 |
-
|
735 |
-
|
736 |
-
|
737 |
)
|
738 |
elif solver_type == 'taylor':
|
739 |
D1_0 = (1. / r1) * (model_s1 - model_s)
|
@@ -741,10 +756,10 @@ class DPM_Solver:
|
|
741 |
D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
|
742 |
D2 = 2. * (D1_1 - D1_0) / (r2 - r1)
|
743 |
x_t = (
|
744 |
-
|
745 |
-
|
746 |
-
|
747 |
-
|
748 |
)
|
749 |
|
750 |
if return_intermediate:
|
@@ -772,7 +787,8 @@ class DPM_Solver:
|
|
772 |
dims = x.dim()
|
773 |
model_prev_1, model_prev_0 = model_prev_list
|
774 |
t_prev_1, t_prev_0 = t_prev_list
|
775 |
-
lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_1), ns.marginal_lambda(
|
|
|
776 |
log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
|
777 |
sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
|
778 |
alpha_t = torch.exp(log_alpha_t)
|
@@ -784,28 +800,28 @@ class DPM_Solver:
|
|
784 |
if self.predict_x0:
|
785 |
if solver_type == 'dpm_solver':
|
786 |
x_t = (
|
787 |
-
|
788 |
-
|
789 |
-
|
790 |
)
|
791 |
elif solver_type == 'taylor':
|
792 |
x_t = (
|
793 |
-
|
794 |
-
|
795 |
-
|
796 |
)
|
797 |
else:
|
798 |
if solver_type == 'dpm_solver':
|
799 |
x_t = (
|
800 |
-
|
801 |
-
|
802 |
-
|
803 |
)
|
804 |
elif solver_type == 'taylor':
|
805 |
x_t = (
|
806 |
-
|
807 |
-
|
808 |
-
|
809 |
)
|
810 |
return x_t
|
811 |
|
@@ -827,7 +843,8 @@ class DPM_Solver:
|
|
827 |
dims = x.dim()
|
828 |
model_prev_2, model_prev_1, model_prev_0 = model_prev_list
|
829 |
t_prev_2, t_prev_1, t_prev_0 = t_prev_list
|
830 |
-
lambda_prev_2, lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_2), ns.marginal_lambda(
|
|
|
831 |
log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
|
832 |
sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
|
833 |
alpha_t = torch.exp(log_alpha_t)
|
@@ -842,21 +859,22 @@ class DPM_Solver:
|
|
842 |
D2 = expand_dims(1. / (r0 + r1), dims) * (D1_0 - D1_1)
|
843 |
if self.predict_x0:
|
844 |
x_t = (
|
845 |
-
|
846 |
-
|
847 |
-
|
848 |
-
|
849 |
)
|
850 |
else:
|
851 |
x_t = (
|
852 |
-
|
853 |
-
|
854 |
-
|
855 |
-
|
856 |
)
|
857 |
return x_t
|
858 |
|
859 |
-
def singlestep_dpm_solver_update(self, x, s, t, order, return_intermediate=False, solver_type='dpm_solver', r1=None,
|
|
|
860 |
"""
|
861 |
Singlestep DPM-Solver with the order `order` from time `s` to time `t`.
|
862 |
|
@@ -876,9 +894,11 @@ class DPM_Solver:
|
|
876 |
if order == 1:
|
877 |
return self.dpm_solver_first_update(x, s, t, return_intermediate=return_intermediate)
|
878 |
elif order == 2:
|
879 |
-
return self.singlestep_dpm_solver_second_update(x, s, t, return_intermediate=return_intermediate,
|
|
|
880 |
elif order == 3:
|
881 |
-
return self.singlestep_dpm_solver_third_update(x, s, t, return_intermediate=return_intermediate,
|
|
|
882 |
else:
|
883 |
raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order))
|
884 |
|
@@ -906,7 +926,8 @@ class DPM_Solver:
|
|
906 |
else:
|
907 |
raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order))
|
908 |
|
909 |
-
def dpm_solver_adaptive(self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol=0.05, theta=0.9, t_err=1e-5,
|
|
|
910 |
"""
|
911 |
The adaptive step size solver based on singlestep DPM-Solver.
|
912 |
|
@@ -938,11 +959,17 @@ class DPM_Solver:
|
|
938 |
if order == 2:
|
939 |
r1 = 0.5
|
940 |
lower_update = lambda x, s, t: self.dpm_solver_first_update(x, s, t, return_intermediate=True)
|
941 |
-
higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1,
|
|
|
|
|
942 |
elif order == 3:
|
943 |
r1, r2 = 1. / 3., 2. / 3.
|
944 |
-
lower_update = lambda x, s, t: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1,
|
945 |
-
|
|
|
|
|
|
|
|
|
946 |
else:
|
947 |
raise ValueError("For adaptive step size solver, order must be 2 or 3, got {}".format(order))
|
948 |
while torch.abs((s - t_0)).mean() > t_err:
|
@@ -963,9 +990,9 @@ class DPM_Solver:
|
|
963 |
return x
|
964 |
|
965 |
def sample(self, x, steps=20, t_start=None, t_end=None, order=3, skip_type='time_uniform',
|
966 |
-
|
967 |
-
|
968 |
-
|
969 |
"""
|
970 |
Compute the sample at time `t_end` by DPM-Solver, given the initial `x` at time `t_start`.
|
971 |
|
@@ -1073,7 +1100,8 @@ class DPM_Solver:
|
|
1073 |
device = x.device
|
1074 |
if method == 'adaptive':
|
1075 |
with torch.no_grad():
|
1076 |
-
x = self.dpm_solver_adaptive(x, order=order, t_T=t_T, t_0=t_0, atol=atol, rtol=rtol,
|
|
|
1077 |
elif method == 'multistep':
|
1078 |
assert steps >= order
|
1079 |
timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
|
@@ -1083,19 +1111,21 @@ class DPM_Solver:
|
|
1083 |
model_prev_list = [self.model_fn(x, vec_t)]
|
1084 |
t_prev_list = [vec_t]
|
1085 |
# Init the first `order` values by lower order multistep DPM-Solver.
|
1086 |
-
for init_order in range(1, order):
|
1087 |
vec_t = timesteps[init_order].expand(x.shape[0])
|
1088 |
-
x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, init_order,
|
|
|
1089 |
model_prev_list.append(self.model_fn(x, vec_t))
|
1090 |
t_prev_list.append(vec_t)
|
1091 |
# Compute the remaining values by `order`-th order multistep DPM-Solver.
|
1092 |
-
for step in range(order, steps + 1):
|
1093 |
vec_t = timesteps[step].expand(x.shape[0])
|
1094 |
if lower_order_final and steps < 15:
|
1095 |
step_order = min(order, steps + 1 - step)
|
1096 |
else:
|
1097 |
step_order = order
|
1098 |
-
x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, step_order,
|
|
|
1099 |
for i in range(order - 1):
|
1100 |
t_prev_list[i] = t_prev_list[i + 1]
|
1101 |
model_prev_list[i] = model_prev_list[i + 1]
|
@@ -1105,14 +1135,18 @@ class DPM_Solver:
|
|
1105 |
model_prev_list[-1] = self.model_fn(x, vec_t)
|
1106 |
elif method in ['singlestep', 'singlestep_fixed']:
|
1107 |
if method == 'singlestep':
|
1108 |
-
timesteps_outer, orders = self.get_orders_and_timesteps_for_singlestep_solver(steps=steps, order=order,
|
|
|
|
|
|
|
1109 |
elif method == 'singlestep_fixed':
|
1110 |
K = steps // order
|
1111 |
-
orders = [order,] * K
|
1112 |
timesteps_outer = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=K, device=device)
|
1113 |
for i, order in enumerate(orders):
|
1114 |
t_T_inner, t_0_inner = timesteps_outer[i], timesteps_outer[i + 1]
|
1115 |
-
timesteps_inner = self.get_time_steps(skip_type=skip_type, t_T=t_T_inner.item(), t_0=t_0_inner.item(),
|
|
|
1116 |
lambda_inner = self.noise_schedule.marginal_lambda(timesteps_inner)
|
1117 |
vec_s, vec_t = t_T_inner.tile(x.shape[0]), t_0_inner.tile(x.shape[0])
|
1118 |
h = lambda_inner[-1] - lambda_inner[0]
|
@@ -1124,7 +1158,6 @@ class DPM_Solver:
|
|
1124 |
return x
|
1125 |
|
1126 |
|
1127 |
-
|
1128 |
#############################################################
|
1129 |
# other utility functions
|
1130 |
#############################################################
|
@@ -1181,4 +1214,4 @@ def expand_dims(v, dims):
|
|
1181 |
Returns:
|
1182 |
a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`.
|
1183 |
"""
|
1184 |
-
return v[(...,) + (None,)*(dims - 1)]
|
|
|
1 |
import torch
|
2 |
import torch.nn.functional as F
|
3 |
import math
|
4 |
+
from tqdm import tqdm
|
5 |
|
6 |
|
7 |
class NoiseScheduleVP:
|
|
|
12 |
alphas_cumprod=None,
|
13 |
continuous_beta_0=0.1,
|
14 |
continuous_beta_1=20.,
|
15 |
+
):
|
16 |
"""Create a wrapper class for the forward SDE (VP type).
|
17 |
|
18 |
***
|
|
|
94 |
"""
|
95 |
|
96 |
if schedule not in ['discrete', 'linear', 'cosine']:
|
97 |
+
raise ValueError(
|
98 |
+
"Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format(
|
99 |
+
schedule))
|
100 |
|
101 |
self.schedule = schedule
|
102 |
if schedule == 'discrete':
|
|
|
115 |
self.beta_1 = continuous_beta_1
|
116 |
self.cosine_s = 0.008
|
117 |
self.cosine_beta_max = 999.
|
118 |
+
self.cosine_t_max = math.atan(self.cosine_beta_max * (1. + self.cosine_s) / math.pi) * 2. * (
|
119 |
+
1. + self.cosine_s) / math.pi - self.cosine_s
|
120 |
self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1. + self.cosine_s) * math.pi / 2.))
|
121 |
self.schedule = schedule
|
122 |
if schedule == 'cosine':
|
|
|
131 |
Compute log(alpha_t) of a given continuous-time label t in [0, T].
|
132 |
"""
|
133 |
if self.schedule == 'discrete':
|
134 |
+
return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device),
|
135 |
+
self.log_alpha_array.to(t.device)).reshape((-1))
|
136 |
elif self.schedule == 'linear':
|
137 |
return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
|
138 |
elif self.schedule == 'cosine':
|
139 |
log_alpha_fn = lambda s: torch.log(torch.cos((s + self.cosine_s) / (1. + self.cosine_s) * math.pi / 2.))
|
140 |
+
log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0
|
141 |
return log_alpha_t
|
142 |
|
143 |
def marginal_alpha(self, t):
|
|
|
166 |
"""
|
167 |
if self.schedule == 'linear':
|
168 |
tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
|
169 |
+
Delta = self.beta_0 ** 2 + tmp
|
170 |
return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0)
|
171 |
elif self.schedule == 'discrete':
|
172 |
log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2. * lamb)
|
173 |
+
t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]),
|
174 |
+
torch.flip(self.t_array.to(lamb.device), [1]))
|
175 |
return t.reshape((-1,))
|
176 |
else:
|
177 |
log_alpha = -0.5 * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
|
178 |
+
t_fn = lambda log_alpha_t: torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)) * 2. * (
|
179 |
+
1. + self.cosine_s) / math.pi - self.cosine_s
|
180 |
t = t_fn(log_alpha)
|
181 |
return t
|
182 |
|
183 |
|
184 |
def model_wrapper(
|
185 |
+
model,
|
186 |
+
noise_schedule,
|
187 |
+
model_type="noise",
|
188 |
+
model_kwargs={},
|
189 |
+
guidance_type="uncond",
|
190 |
+
condition=None,
|
191 |
+
unconditional_condition=None,
|
192 |
+
guidance_scale=1.,
|
193 |
+
classifier_fn=None,
|
194 |
+
classifier_kwargs={},
|
195 |
):
|
196 |
"""Create a wrapper function for the noise prediction model.
|
197 |
|
|
|
399 |
alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t)
|
400 |
x0 = (x - expand_dims(sigma_t, dims) * noise) / expand_dims(alpha_t, dims)
|
401 |
if self.thresholding:
|
402 |
+
p = 0.995 # A hyperparameter in the paper of "Imagen" [1].
|
403 |
s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
|
404 |
s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims)
|
405 |
x0 = torch.clamp(x0, -s, s) / s
|
|
|
438 |
return torch.linspace(t_T, t_0, N + 1).to(device)
|
439 |
elif skip_type == 'time_quadratic':
|
440 |
t_order = 2
|
441 |
+
t = torch.linspace(t_T ** (1. / t_order), t_0 ** (1. / t_order), N + 1).pow(t_order).to(device)
|
442 |
return t
|
443 |
else:
|
444 |
+
raise ValueError(
|
445 |
+
"Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type))
|
446 |
|
447 |
def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device):
|
448 |
"""
|
|
|
479 |
if order == 3:
|
480 |
K = steps // 3 + 1
|
481 |
if steps % 3 == 0:
|
482 |
+
orders = [3, ] * (K - 2) + [2, 1]
|
483 |
elif steps % 3 == 1:
|
484 |
+
orders = [3, ] * (K - 1) + [1]
|
485 |
else:
|
486 |
+
orders = [3, ] * (K - 1) + [2]
|
487 |
elif order == 2:
|
488 |
if steps % 2 == 0:
|
489 |
K = steps // 2
|
490 |
+
orders = [2, ] * K
|
491 |
else:
|
492 |
K = steps // 2 + 1
|
493 |
+
orders = [2, ] * (K - 1) + [1]
|
494 |
elif order == 1:
|
495 |
K = 1
|
496 |
+
orders = [1, ] * steps
|
497 |
else:
|
498 |
raise ValueError("'order' must be '1' or '2' or '3'.")
|
499 |
if skip_type == 'logSNR':
|
500 |
# To reproduce the results in DPM-Solver paper
|
501 |
timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device)
|
502 |
else:
|
503 |
+
timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[
|
504 |
+
torch.cumsum(torch.tensor([0, ] + orders)).to(device)]
|
505 |
return timesteps_outer, orders
|
506 |
|
507 |
def denoise_to_zero_fn(self, x, s):
|
|
|
537 |
if model_s is None:
|
538 |
model_s = self.model_fn(x, s)
|
539 |
x_t = (
|
540 |
+
expand_dims(sigma_t / sigma_s, dims) * x
|
541 |
+
- expand_dims(alpha_t * phi_1, dims) * model_s
|
542 |
)
|
543 |
if return_intermediate:
|
544 |
return x_t, {'model_s': model_s}
|
|
|
549 |
if model_s is None:
|
550 |
model_s = self.model_fn(x, s)
|
551 |
x_t = (
|
552 |
+
expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
|
553 |
+
- expand_dims(sigma_t * phi_1, dims) * model_s
|
554 |
)
|
555 |
if return_intermediate:
|
556 |
return x_t, {'model_s': model_s}
|
557 |
else:
|
558 |
return x_t
|
559 |
|
560 |
+
def singlestep_dpm_solver_second_update(self, x, s, t, r1=0.5, model_s=None, return_intermediate=False,
|
561 |
+
solver_type='dpm_solver'):
|
562 |
"""
|
563 |
Singlestep solver DPM-Solver-2 from time `s` to time `t`.
|
564 |
|
|
|
585 |
h = lambda_t - lambda_s
|
586 |
lambda_s1 = lambda_s + r1 * h
|
587 |
s1 = ns.inverse_lambda(lambda_s1)
|
588 |
+
log_alpha_s, log_alpha_s1, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(
|
589 |
+
s1), ns.marginal_log_mean_coeff(t)
|
590 |
sigma_s, sigma_s1, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(t)
|
591 |
alpha_s1, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_t)
|
592 |
|
|
|
597 |
if model_s is None:
|
598 |
model_s = self.model_fn(x, s)
|
599 |
x_s1 = (
|
600 |
+
expand_dims(sigma_s1 / sigma_s, dims) * x
|
601 |
+
- expand_dims(alpha_s1 * phi_11, dims) * model_s
|
602 |
)
|
603 |
model_s1 = self.model_fn(x_s1, s1)
|
604 |
if solver_type == 'dpm_solver':
|
605 |
x_t = (
|
606 |
+
expand_dims(sigma_t / sigma_s, dims) * x
|
607 |
+
- expand_dims(alpha_t * phi_1, dims) * model_s
|
608 |
+
- (0.5 / r1) * expand_dims(alpha_t * phi_1, dims) * (model_s1 - model_s)
|
609 |
)
|
610 |
elif solver_type == 'taylor':
|
611 |
x_t = (
|
612 |
+
expand_dims(sigma_t / sigma_s, dims) * x
|
613 |
+
- expand_dims(alpha_t * phi_1, dims) * model_s
|
614 |
+
+ (1. / r1) * expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * (
|
615 |
+
model_s1 - model_s)
|
616 |
)
|
617 |
else:
|
618 |
phi_11 = torch.expm1(r1 * h)
|
|
|
621 |
if model_s is None:
|
622 |
model_s = self.model_fn(x, s)
|
623 |
x_s1 = (
|
624 |
+
expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x
|
625 |
+
- expand_dims(sigma_s1 * phi_11, dims) * model_s
|
626 |
)
|
627 |
model_s1 = self.model_fn(x_s1, s1)
|
628 |
if solver_type == 'dpm_solver':
|
629 |
x_t = (
|
630 |
+
expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
|
631 |
+
- expand_dims(sigma_t * phi_1, dims) * model_s
|
632 |
+
- (0.5 / r1) * expand_dims(sigma_t * phi_1, dims) * (model_s1 - model_s)
|
633 |
)
|
634 |
elif solver_type == 'taylor':
|
635 |
x_t = (
|
636 |
+
expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
|
637 |
+
- expand_dims(sigma_t * phi_1, dims) * model_s
|
638 |
+
- (1. / r1) * expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * (model_s1 - model_s)
|
639 |
)
|
640 |
if return_intermediate:
|
641 |
return x_t, {'model_s': model_s, 'model_s1': model_s1}
|
642 |
else:
|
643 |
return x_t
|
644 |
|
645 |
+
def singlestep_dpm_solver_third_update(self, x, s, t, r1=1. / 3., r2=2. / 3., model_s=None, model_s1=None,
|
646 |
+
return_intermediate=False, solver_type='dpm_solver'):
|
647 |
"""
|
648 |
Singlestep solver DPM-Solver-3 from time `s` to time `t`.
|
649 |
|
|
|
677 |
lambda_s2 = lambda_s + r2 * h
|
678 |
s1 = ns.inverse_lambda(lambda_s1)
|
679 |
s2 = ns.inverse_lambda(lambda_s2)
|
680 |
+
log_alpha_s, log_alpha_s1, log_alpha_s2, log_alpha_t = ns.marginal_log_mean_coeff(
|
681 |
+
s), ns.marginal_log_mean_coeff(s1), ns.marginal_log_mean_coeff(s2), ns.marginal_log_mean_coeff(t)
|
682 |
+
sigma_s, sigma_s1, sigma_s2, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(
|
683 |
+
s2), ns.marginal_std(t)
|
684 |
alpha_s1, alpha_s2, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_s2), torch.exp(log_alpha_t)
|
685 |
|
686 |
if self.predict_x0:
|
|
|
695 |
model_s = self.model_fn(x, s)
|
696 |
if model_s1 is None:
|
697 |
x_s1 = (
|
698 |
+
expand_dims(sigma_s1 / sigma_s, dims) * x
|
699 |
+
- expand_dims(alpha_s1 * phi_11, dims) * model_s
|
700 |
)
|
701 |
model_s1 = self.model_fn(x_s1, s1)
|
702 |
x_s2 = (
|
703 |
+
expand_dims(sigma_s2 / sigma_s, dims) * x
|
704 |
+
- expand_dims(alpha_s2 * phi_12, dims) * model_s
|
705 |
+
+ r2 / r1 * expand_dims(alpha_s2 * phi_22, dims) * (model_s1 - model_s)
|
706 |
)
|
707 |
model_s2 = self.model_fn(x_s2, s2)
|
708 |
if solver_type == 'dpm_solver':
|
709 |
x_t = (
|
710 |
+
expand_dims(sigma_t / sigma_s, dims) * x
|
711 |
+
- expand_dims(alpha_t * phi_1, dims) * model_s
|
712 |
+
+ (1. / r2) * expand_dims(alpha_t * phi_2, dims) * (model_s2 - model_s)
|
713 |
)
|
714 |
elif solver_type == 'taylor':
|
715 |
D1_0 = (1. / r1) * (model_s1 - model_s)
|
|
|
717 |
D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
|
718 |
D2 = 2. * (D1_1 - D1_0) / (r2 - r1)
|
719 |
x_t = (
|
720 |
+
expand_dims(sigma_t / sigma_s, dims) * x
|
721 |
+
- expand_dims(alpha_t * phi_1, dims) * model_s
|
722 |
+
+ expand_dims(alpha_t * phi_2, dims) * D1
|
723 |
+
- expand_dims(alpha_t * phi_3, dims) * D2
|
724 |
)
|
725 |
else:
|
726 |
phi_11 = torch.expm1(r1 * h)
|
|
|
734 |
model_s = self.model_fn(x, s)
|
735 |
if model_s1 is None:
|
736 |
x_s1 = (
|
737 |
+
expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x
|
738 |
+
- expand_dims(sigma_s1 * phi_11, dims) * model_s
|
739 |
)
|
740 |
model_s1 = self.model_fn(x_s1, s1)
|
741 |
x_s2 = (
|
742 |
+
expand_dims(torch.exp(log_alpha_s2 - log_alpha_s), dims) * x
|
743 |
+
- expand_dims(sigma_s2 * phi_12, dims) * model_s
|
744 |
+
- r2 / r1 * expand_dims(sigma_s2 * phi_22, dims) * (model_s1 - model_s)
|
745 |
)
|
746 |
model_s2 = self.model_fn(x_s2, s2)
|
747 |
if solver_type == 'dpm_solver':
|
748 |
x_t = (
|
749 |
+
expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
|
750 |
+
- expand_dims(sigma_t * phi_1, dims) * model_s
|
751 |
+
- (1. / r2) * expand_dims(sigma_t * phi_2, dims) * (model_s2 - model_s)
|
752 |
)
|
753 |
elif solver_type == 'taylor':
|
754 |
D1_0 = (1. / r1) * (model_s1 - model_s)
|
|
|
756 |
D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
|
757 |
D2 = 2. * (D1_1 - D1_0) / (r2 - r1)
|
758 |
x_t = (
|
759 |
+
expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
|
760 |
+
- expand_dims(sigma_t * phi_1, dims) * model_s
|
761 |
+
- expand_dims(sigma_t * phi_2, dims) * D1
|
762 |
+
- expand_dims(sigma_t * phi_3, dims) * D2
|
763 |
)
|
764 |
|
765 |
if return_intermediate:
|
|
|
787 |
dims = x.dim()
|
788 |
model_prev_1, model_prev_0 = model_prev_list
|
789 |
t_prev_1, t_prev_0 = t_prev_list
|
790 |
+
lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_1), ns.marginal_lambda(
|
791 |
+
t_prev_0), ns.marginal_lambda(t)
|
792 |
log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
|
793 |
sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
|
794 |
alpha_t = torch.exp(log_alpha_t)
|
|
|
800 |
if self.predict_x0:
|
801 |
if solver_type == 'dpm_solver':
|
802 |
x_t = (
|
803 |
+
expand_dims(sigma_t / sigma_prev_0, dims) * x
|
804 |
+
- expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0
|
805 |
+
- 0.5 * expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * D1_0
|
806 |
)
|
807 |
elif solver_type == 'taylor':
|
808 |
x_t = (
|
809 |
+
expand_dims(sigma_t / sigma_prev_0, dims) * x
|
810 |
+
- expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0
|
811 |
+
+ expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * D1_0
|
812 |
)
|
813 |
else:
|
814 |
if solver_type == 'dpm_solver':
|
815 |
x_t = (
|
816 |
+
expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
|
817 |
+
- expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0
|
818 |
+
- 0.5 * expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * D1_0
|
819 |
)
|
820 |
elif solver_type == 'taylor':
|
821 |
x_t = (
|
822 |
+
expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
|
823 |
+
- expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0
|
824 |
+
- expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * D1_0
|
825 |
)
|
826 |
return x_t
|
827 |
|
|
|
843 |
dims = x.dim()
|
844 |
model_prev_2, model_prev_1, model_prev_0 = model_prev_list
|
845 |
t_prev_2, t_prev_1, t_prev_0 = t_prev_list
|
846 |
+
lambda_prev_2, lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_2), ns.marginal_lambda(
|
847 |
+
t_prev_1), ns.marginal_lambda(t_prev_0), ns.marginal_lambda(t)
|
848 |
log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
|
849 |
sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
|
850 |
alpha_t = torch.exp(log_alpha_t)
|
|
|
859 |
D2 = expand_dims(1. / (r0 + r1), dims) * (D1_0 - D1_1)
|
860 |
if self.predict_x0:
|
861 |
x_t = (
|
862 |
+
expand_dims(sigma_t / sigma_prev_0, dims) * x
|
863 |
+
- expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0
|
864 |
+
+ expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * D1
|
865 |
+
- expand_dims(alpha_t * ((torch.exp(-h) - 1. + h) / h ** 2 - 0.5), dims) * D2
|
866 |
)
|
867 |
else:
|
868 |
x_t = (
|
869 |
+
expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
|
870 |
+
- expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0
|
871 |
+
- expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * D1
|
872 |
+
- expand_dims(sigma_t * ((torch.exp(h) - 1. - h) / h ** 2 - 0.5), dims) * D2
|
873 |
)
|
874 |
return x_t
|
875 |
|
876 |
+
def singlestep_dpm_solver_update(self, x, s, t, order, return_intermediate=False, solver_type='dpm_solver', r1=None,
|
877 |
+
r2=None):
|
878 |
"""
|
879 |
Singlestep DPM-Solver with the order `order` from time `s` to time `t`.
|
880 |
|
|
|
894 |
if order == 1:
|
895 |
return self.dpm_solver_first_update(x, s, t, return_intermediate=return_intermediate)
|
896 |
elif order == 2:
|
897 |
+
return self.singlestep_dpm_solver_second_update(x, s, t, return_intermediate=return_intermediate,
|
898 |
+
solver_type=solver_type, r1=r1)
|
899 |
elif order == 3:
|
900 |
+
return self.singlestep_dpm_solver_third_update(x, s, t, return_intermediate=return_intermediate,
|
901 |
+
solver_type=solver_type, r1=r1, r2=r2)
|
902 |
else:
|
903 |
raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order))
|
904 |
|
|
|
926 |
else:
|
927 |
raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order))
|
928 |
|
929 |
+
def dpm_solver_adaptive(self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol=0.05, theta=0.9, t_err=1e-5,
|
930 |
+
solver_type='dpm_solver'):
|
931 |
"""
|
932 |
The adaptive step size solver based on singlestep DPM-Solver.
|
933 |
|
|
|
959 |
if order == 2:
|
960 |
r1 = 0.5
|
961 |
lower_update = lambda x, s, t: self.dpm_solver_first_update(x, s, t, return_intermediate=True)
|
962 |
+
higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1,
|
963 |
+
solver_type=solver_type,
|
964 |
+
**kwargs)
|
965 |
elif order == 3:
|
966 |
r1, r2 = 1. / 3., 2. / 3.
|
967 |
+
lower_update = lambda x, s, t: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1,
|
968 |
+
return_intermediate=True,
|
969 |
+
solver_type=solver_type)
|
970 |
+
higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_third_update(x, s, t, r1=r1, r2=r2,
|
971 |
+
solver_type=solver_type,
|
972 |
+
**kwargs)
|
973 |
else:
|
974 |
raise ValueError("For adaptive step size solver, order must be 2 or 3, got {}".format(order))
|
975 |
while torch.abs((s - t_0)).mean() > t_err:
|
|
|
990 |
return x
|
991 |
|
992 |
def sample(self, x, steps=20, t_start=None, t_end=None, order=3, skip_type='time_uniform',
|
993 |
+
method='singlestep', lower_order_final=True, denoise_to_zero=False, solver_type='dpm_solver',
|
994 |
+
atol=0.0078, rtol=0.05,
|
995 |
+
):
|
996 |
"""
|
997 |
Compute the sample at time `t_end` by DPM-Solver, given the initial `x` at time `t_start`.
|
998 |
|
|
|
1100 |
device = x.device
|
1101 |
if method == 'adaptive':
|
1102 |
with torch.no_grad():
|
1103 |
+
x = self.dpm_solver_adaptive(x, order=order, t_T=t_T, t_0=t_0, atol=atol, rtol=rtol,
|
1104 |
+
solver_type=solver_type)
|
1105 |
elif method == 'multistep':
|
1106 |
assert steps >= order
|
1107 |
timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
|
|
|
1111 |
model_prev_list = [self.model_fn(x, vec_t)]
|
1112 |
t_prev_list = [vec_t]
|
1113 |
# Init the first `order` values by lower order multistep DPM-Solver.
|
1114 |
+
for init_order in tqdm(range(1, order), desc="DPM init order"):
|
1115 |
vec_t = timesteps[init_order].expand(x.shape[0])
|
1116 |
+
x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, init_order,
|
1117 |
+
solver_type=solver_type)
|
1118 |
model_prev_list.append(self.model_fn(x, vec_t))
|
1119 |
t_prev_list.append(vec_t)
|
1120 |
# Compute the remaining values by `order`-th order multistep DPM-Solver.
|
1121 |
+
for step in tqdm(range(order, steps + 1), desc="DPM multistep"):
|
1122 |
vec_t = timesteps[step].expand(x.shape[0])
|
1123 |
if lower_order_final and steps < 15:
|
1124 |
step_order = min(order, steps + 1 - step)
|
1125 |
else:
|
1126 |
step_order = order
|
1127 |
+
x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, step_order,
|
1128 |
+
solver_type=solver_type)
|
1129 |
for i in range(order - 1):
|
1130 |
t_prev_list[i] = t_prev_list[i + 1]
|
1131 |
model_prev_list[i] = model_prev_list[i + 1]
|
|
|
1135 |
model_prev_list[-1] = self.model_fn(x, vec_t)
|
1136 |
elif method in ['singlestep', 'singlestep_fixed']:
|
1137 |
if method == 'singlestep':
|
1138 |
+
timesteps_outer, orders = self.get_orders_and_timesteps_for_singlestep_solver(steps=steps, order=order,
|
1139 |
+
skip_type=skip_type,
|
1140 |
+
t_T=t_T, t_0=t_0,
|
1141 |
+
device=device)
|
1142 |
elif method == 'singlestep_fixed':
|
1143 |
K = steps // order
|
1144 |
+
orders = [order, ] * K
|
1145 |
timesteps_outer = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=K, device=device)
|
1146 |
for i, order in enumerate(orders):
|
1147 |
t_T_inner, t_0_inner = timesteps_outer[i], timesteps_outer[i + 1]
|
1148 |
+
timesteps_inner = self.get_time_steps(skip_type=skip_type, t_T=t_T_inner.item(), t_0=t_0_inner.item(),
|
1149 |
+
N=order, device=device)
|
1150 |
lambda_inner = self.noise_schedule.marginal_lambda(timesteps_inner)
|
1151 |
vec_s, vec_t = t_T_inner.tile(x.shape[0]), t_0_inner.tile(x.shape[0])
|
1152 |
h = lambda_inner[-1] - lambda_inner[0]
|
|
|
1158 |
return x
|
1159 |
|
1160 |
|
|
|
1161 |
#############################################################
|
1162 |
# other utility functions
|
1163 |
#############################################################
|
|
|
1214 |
Returns:
|
1215 |
a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`.
|
1216 |
"""
|
1217 |
+
return v[(...,) + (None,) * (dims - 1)]
|
ldm/models/diffusion/dpm_solver/sampler.py
CHANGED
@@ -1,10 +1,15 @@
|
|
1 |
"""SAMPLING ONLY."""
|
2 |
-
|
3 |
import torch
|
4 |
|
5 |
from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver
|
6 |
|
7 |
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
class DPMSolverSampler(object):
|
9 |
def __init__(self, model, **kwargs):
|
10 |
super().__init__()
|
@@ -56,7 +61,7 @@ class DPMSolverSampler(object):
|
|
56 |
C, H, W = shape
|
57 |
size = (batch_size, C, H, W)
|
58 |
|
59 |
-
|
60 |
|
61 |
device = self.model.betas.device
|
62 |
if x_T is None:
|
@@ -69,7 +74,7 @@ class DPMSolverSampler(object):
|
|
69 |
model_fn = model_wrapper(
|
70 |
lambda x, t, c: self.model.apply_model(x, t, c),
|
71 |
ns,
|
72 |
-
model_type=
|
73 |
guidance_type="classifier-free",
|
74 |
condition=conditioning,
|
75 |
unconditional_condition=unconditional_conditioning,
|
|
|
1 |
"""SAMPLING ONLY."""
|
|
|
2 |
import torch
|
3 |
|
4 |
from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver
|
5 |
|
6 |
|
7 |
+
MODEL_TYPES = {
|
8 |
+
"eps": "noise",
|
9 |
+
"v": "v"
|
10 |
+
}
|
11 |
+
|
12 |
+
|
13 |
class DPMSolverSampler(object):
|
14 |
def __init__(self, model, **kwargs):
|
15 |
super().__init__()
|
|
|
61 |
C, H, W = shape
|
62 |
size = (batch_size, C, H, W)
|
63 |
|
64 |
+
print(f'Data shape for DPM-Solver sampling is {size}, sampling steps {S}')
|
65 |
|
66 |
device = self.model.betas.device
|
67 |
if x_T is None:
|
|
|
74 |
model_fn = model_wrapper(
|
75 |
lambda x, t, c: self.model.apply_model(x, t, c),
|
76 |
ns,
|
77 |
+
model_type=MODEL_TYPES[self.model.parameterization],
|
78 |
guidance_type="classifier-free",
|
79 |
condition=conditioning,
|
80 |
unconditional_condition=unconditional_conditioning,
|
ldm/models/diffusion/plms.py
CHANGED
@@ -3,10 +3,9 @@
|
|
3 |
import torch
|
4 |
import numpy as np
|
5 |
from tqdm import tqdm
|
6 |
-
from functools import partial
|
7 |
-
import copy
|
8 |
from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
|
9 |
|
|
|
10 |
class PLMSSampler(object):
|
11 |
def __init__(self, model, schedule="linear", **kwargs):
|
12 |
super().__init__()
|
@@ -24,7 +23,7 @@ class PLMSSampler(object):
|
|
24 |
if ddim_eta != 0:
|
25 |
raise ValueError('ddim_eta must be 0 for PLMS')
|
26 |
self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
|
27 |
-
num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
|
28 |
alphas_cumprod = self.model.alphas_cumprod
|
29 |
assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
|
30 |
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
|
@@ -43,14 +42,14 @@ class PLMSSampler(object):
|
|
43 |
# ddim sampling parameters
|
44 |
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
|
45 |
ddim_timesteps=self.ddim_timesteps,
|
46 |
-
eta=ddim_eta,verbose=verbose)
|
47 |
self.register_buffer('ddim_sigmas', ddim_sigmas)
|
48 |
self.register_buffer('ddim_alphas', ddim_alphas)
|
49 |
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
|
50 |
self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
|
51 |
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
|
52 |
(1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
|
53 |
-
|
54 |
self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
|
55 |
|
56 |
@torch.no_grad()
|
@@ -75,11 +74,8 @@ class PLMSSampler(object):
|
|
75 |
log_every_t=100,
|
76 |
unconditional_guidance_scale=1.,
|
77 |
unconditional_conditioning=None,
|
78 |
-
|
79 |
-
|
80 |
-
mode = 'sketch',
|
81 |
-
con_strength=30,
|
82 |
-
style_feature=None,
|
83 |
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
|
84 |
**kwargs
|
85 |
):
|
@@ -113,11 +109,8 @@ class PLMSSampler(object):
|
|
113 |
log_every_t=log_every_t,
|
114 |
unconditional_guidance_scale=unconditional_guidance_scale,
|
115 |
unconditional_conditioning=unconditional_conditioning,
|
116 |
-
|
117 |
-
|
118 |
-
mode = mode,
|
119 |
-
con_strength = con_strength,
|
120 |
-
style_feature=style_feature#.clone()
|
121 |
)
|
122 |
return samples, intermediates
|
123 |
|
@@ -127,7 +120,8 @@ class PLMSSampler(object):
|
|
127 |
callback=None, timesteps=None, quantize_denoised=False,
|
128 |
mask=None, x0=None, img_callback=None, log_every_t=100,
|
129 |
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
|
130 |
-
unconditional_guidance_scale=1., unconditional_conditioning=None,
|
|
|
131 |
device = self.model.betas.device
|
132 |
b = shape[0]
|
133 |
if x_T is None:
|
@@ -141,7 +135,7 @@ class PLMSSampler(object):
|
|
141 |
timesteps = self.ddim_timesteps[:subset_end]
|
142 |
|
143 |
intermediates = {'x_inter': [img], 'pred_x0': [img]}
|
144 |
-
time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps)
|
145 |
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
|
146 |
print(f"Running PLMS Sampling with {total_steps} timesteps")
|
147 |
|
@@ -152,41 +146,21 @@ class PLMSSampler(object):
|
|
152 |
index = total_steps - i - 1
|
153 |
ts = torch.full((b,), step, device=device, dtype=torch.long)
|
154 |
ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long)
|
155 |
-
cond_in = cond
|
156 |
-
unconditional_conditioning_in = unconditional_conditioning
|
157 |
|
158 |
-
if mask is not None
|
159 |
assert x0 is not None
|
160 |
img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
|
161 |
img = img_orig * mask + (1. - mask) * img
|
162 |
|
163 |
-
|
164 |
-
if index<con_strength:
|
165 |
-
features_adapter = None
|
166 |
-
else:
|
167 |
-
features_adapter = features_adapter1
|
168 |
-
elif mode == 'style':
|
169 |
-
if index<con_strength:
|
170 |
-
features_adapter = None
|
171 |
-
else:
|
172 |
-
features_adapter = features_adapter1
|
173 |
-
|
174 |
-
if index>25:
|
175 |
-
cond_in = torch.cat([cond, style_feature.clone()], dim=1)
|
176 |
-
unconditional_conditioning_in = torch.cat(
|
177 |
-
[unconditional_conditioning, unconditional_conditioning[:, -8:, :]], dim=1)
|
178 |
-
elif mode == 'mul':
|
179 |
-
features_adapter = [a1i*0.5 + a2i for a1i, a2i in zip(features_adapter1, features_adapter2)]
|
180 |
-
else:
|
181 |
-
features_adapter = features_adapter1
|
182 |
-
|
183 |
-
outs = self.p_sample_plms(img, cond_in, ts, index=index, use_original_steps=ddim_use_original_steps,
|
184 |
quantize_denoised=quantize_denoised, temperature=temperature,
|
185 |
noise_dropout=noise_dropout, score_corrector=score_corrector,
|
186 |
corrector_kwargs=corrector_kwargs,
|
187 |
unconditional_guidance_scale=unconditional_guidance_scale,
|
188 |
-
unconditional_conditioning=
|
189 |
-
old_eps=old_eps, t_next=ts_next,
|
|
|
|
|
190 |
|
191 |
img, pred_x0, e_t = outs
|
192 |
old_eps.append(e_t)
|
@@ -204,17 +178,18 @@ class PLMSSampler(object):
|
|
204 |
@torch.no_grad()
|
205 |
def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
|
206 |
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
|
207 |
-
unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None,
|
|
|
208 |
b, *_, device = *x.shape, x.device
|
209 |
|
210 |
def get_model_output(x, t):
|
211 |
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
|
212 |
-
e_t = self.model.apply_model(x, t, c,
|
213 |
else:
|
214 |
x_in = torch.cat([x] * 2)
|
215 |
t_in = torch.cat([t] * 2)
|
216 |
c_in = torch.cat([unconditional_conditioning, c])
|
217 |
-
e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in,
|
218 |
e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
|
219 |
|
220 |
if score_corrector is not None:
|
@@ -233,14 +208,14 @@ class PLMSSampler(object):
|
|
233 |
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
|
234 |
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
|
235 |
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
|
236 |
-
sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
|
237 |
|
238 |
# current prediction for x_0
|
239 |
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
240 |
if quantize_denoised:
|
241 |
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
|
242 |
# direction pointing to x_t
|
243 |
-
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
|
244 |
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
|
245 |
if noise_dropout > 0.:
|
246 |
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
|
|
|
3 |
import torch
|
4 |
import numpy as np
|
5 |
from tqdm import tqdm
|
|
|
|
|
6 |
from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
|
7 |
|
8 |
+
|
9 |
class PLMSSampler(object):
|
10 |
def __init__(self, model, schedule="linear", **kwargs):
|
11 |
super().__init__()
|
|
|
23 |
if ddim_eta != 0:
|
24 |
raise ValueError('ddim_eta must be 0 for PLMS')
|
25 |
self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
|
26 |
+
num_ddpm_timesteps=self.ddpm_num_timesteps, verbose=verbose)
|
27 |
alphas_cumprod = self.model.alphas_cumprod
|
28 |
assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
|
29 |
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
|
|
|
42 |
# ddim sampling parameters
|
43 |
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
|
44 |
ddim_timesteps=self.ddim_timesteps,
|
45 |
+
eta=ddim_eta, verbose=verbose)
|
46 |
self.register_buffer('ddim_sigmas', ddim_sigmas)
|
47 |
self.register_buffer('ddim_alphas', ddim_alphas)
|
48 |
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
|
49 |
self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
|
50 |
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
|
51 |
(1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
|
52 |
+
1 - self.alphas_cumprod / self.alphas_cumprod_prev))
|
53 |
self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
|
54 |
|
55 |
@torch.no_grad()
|
|
|
74 |
log_every_t=100,
|
75 |
unconditional_guidance_scale=1.,
|
76 |
unconditional_conditioning=None,
|
77 |
+
features_adapter=None,
|
78 |
+
cond_tau=0.4,
|
|
|
|
|
|
|
79 |
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
|
80 |
**kwargs
|
81 |
):
|
|
|
109 |
log_every_t=log_every_t,
|
110 |
unconditional_guidance_scale=unconditional_guidance_scale,
|
111 |
unconditional_conditioning=unconditional_conditioning,
|
112 |
+
features_adapter=features_adapter,
|
113 |
+
cond_tau=cond_tau
|
|
|
|
|
|
|
114 |
)
|
115 |
return samples, intermediates
|
116 |
|
|
|
120 |
callback=None, timesteps=None, quantize_denoised=False,
|
121 |
mask=None, x0=None, img_callback=None, log_every_t=100,
|
122 |
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
|
123 |
+
unconditional_guidance_scale=1., unconditional_conditioning=None, features_adapter=None,
|
124 |
+
cond_tau=0.4):
|
125 |
device = self.model.betas.device
|
126 |
b = shape[0]
|
127 |
if x_T is None:
|
|
|
135 |
timesteps = self.ddim_timesteps[:subset_end]
|
136 |
|
137 |
intermediates = {'x_inter': [img], 'pred_x0': [img]}
|
138 |
+
time_range = list(reversed(range(0, timesteps))) if ddim_use_original_steps else np.flip(timesteps)
|
139 |
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
|
140 |
print(f"Running PLMS Sampling with {total_steps} timesteps")
|
141 |
|
|
|
146 |
index = total_steps - i - 1
|
147 |
ts = torch.full((b,), step, device=device, dtype=torch.long)
|
148 |
ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long)
|
|
|
|
|
149 |
|
150 |
+
if mask is not None: # and index>=10:
|
151 |
assert x0 is not None
|
152 |
img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
|
153 |
img = img_orig * mask + (1. - mask) * img
|
154 |
|
155 |
+
outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
156 |
quantize_denoised=quantize_denoised, temperature=temperature,
|
157 |
noise_dropout=noise_dropout, score_corrector=score_corrector,
|
158 |
corrector_kwargs=corrector_kwargs,
|
159 |
unconditional_guidance_scale=unconditional_guidance_scale,
|
160 |
+
unconditional_conditioning=unconditional_conditioning,
|
161 |
+
old_eps=old_eps, t_next=ts_next,
|
162 |
+
features_adapter=None if index < int(
|
163 |
+
(1 - cond_tau) * total_steps) else features_adapter)
|
164 |
|
165 |
img, pred_x0, e_t = outs
|
166 |
old_eps.append(e_t)
|
|
|
178 |
@torch.no_grad()
|
179 |
def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
|
180 |
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
|
181 |
+
unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None,
|
182 |
+
features_adapter=None):
|
183 |
b, *_, device = *x.shape, x.device
|
184 |
|
185 |
def get_model_output(x, t):
|
186 |
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
|
187 |
+
e_t = self.model.apply_model(x, t, c, features_adapter=features_adapter)
|
188 |
else:
|
189 |
x_in = torch.cat([x] * 2)
|
190 |
t_in = torch.cat([t] * 2)
|
191 |
c_in = torch.cat([unconditional_conditioning, c])
|
192 |
+
e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in, features_adapter=features_adapter).chunk(2)
|
193 |
e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
|
194 |
|
195 |
if score_corrector is not None:
|
|
|
208 |
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
|
209 |
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
|
210 |
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
|
211 |
+
sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device)
|
212 |
|
213 |
# current prediction for x_0
|
214 |
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
215 |
if quantize_denoised:
|
216 |
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
|
217 |
# direction pointing to x_t
|
218 |
+
dir_xt = (1. - a_prev - sigma_t ** 2).sqrt() * e_t
|
219 |
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
|
220 |
if noise_dropout > 0.:
|
221 |
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
|
ldm/modules/attention.py
CHANGED
@@ -20,6 +20,10 @@ except:
|
|
20 |
import os
|
21 |
_ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")
|
22 |
|
|
|
|
|
|
|
|
|
23 |
def exists(val):
|
24 |
return val is not None
|
25 |
|
|
|
20 |
import os
|
21 |
_ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")
|
22 |
|
23 |
+
if os.environ.get("DISABLE_XFORMERS", "false").lower() == 'true':
|
24 |
+
XFORMERS_IS_AVAILBLE = False
|
25 |
+
|
26 |
+
|
27 |
def exists(val):
|
28 |
return val is not None
|
29 |
|
ldm/modules/diffusionmodules/openaimodel.py
CHANGED
@@ -1,7 +1,6 @@
|
|
1 |
from abc import abstractmethod
|
2 |
-
from functools import partial
|
3 |
import math
|
4 |
-
|
5 |
|
6 |
import numpy as np
|
7 |
import torch as th
|
@@ -18,6 +17,7 @@ from ldm.modules.diffusionmodules.util import (
|
|
18 |
timestep_embedding,
|
19 |
)
|
20 |
from ldm.modules.attention import SpatialTransformer
|
|
|
21 |
|
22 |
|
23 |
# dummy replace
|
@@ -270,8 +270,6 @@ class ResBlock(TimestepBlock):
|
|
270 |
h = out_norm(h) * (1 + scale) + shift
|
271 |
h = out_rest(h)
|
272 |
else:
|
273 |
-
# print(h.shape, emb_out.shape)
|
274 |
-
# exit(0)
|
275 |
h = h + emb_out
|
276 |
h = self.out_layers(h)
|
277 |
return self.skip_connection(x) + h
|
@@ -468,16 +466,16 @@ class UNetModel(nn.Module):
|
|
468 |
context_dim=None, # custom transformer support
|
469 |
n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
|
470 |
legacy=True,
|
471 |
-
|
|
|
|
|
|
|
472 |
):
|
473 |
super().__init__()
|
474 |
-
|
475 |
-
# print('UNet', context_dim)
|
476 |
if use_spatial_transformer:
|
477 |
assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
|
478 |
|
479 |
if context_dim is not None:
|
480 |
-
# print('UNet not none', context_dim, context_dim is not None, context_dim != None, context_dim == "None")
|
481 |
assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
|
482 |
from omegaconf.listconfig import ListConfig
|
483 |
if type(context_dim) == ListConfig:
|
@@ -496,7 +494,24 @@ class UNetModel(nn.Module):
|
|
496 |
self.in_channels = in_channels
|
497 |
self.model_channels = model_channels
|
498 |
self.out_channels = out_channels
|
499 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
500 |
self.attention_resolutions = attention_resolutions
|
501 |
self.dropout = dropout
|
502 |
self.channel_mult = channel_mult
|
@@ -508,9 +523,6 @@ class UNetModel(nn.Module):
|
|
508 |
self.num_head_channels = num_head_channels
|
509 |
self.num_heads_upsample = num_heads_upsample
|
510 |
self.predict_codebook_ids = n_embed is not None
|
511 |
-
# self.l_cond = l_cond
|
512 |
-
# print(self.l_cond)
|
513 |
-
# exit(0)
|
514 |
|
515 |
time_embed_dim = model_channels * 4
|
516 |
self.time_embed = nn.Sequential(
|
@@ -520,7 +532,13 @@ class UNetModel(nn.Module):
|
|
520 |
)
|
521 |
|
522 |
if self.num_classes is not None:
|
523 |
-
self.
|
|
|
|
|
|
|
|
|
|
|
|
|
524 |
|
525 |
self.input_blocks = nn.ModuleList(
|
526 |
[
|
@@ -534,7 +552,7 @@ class UNetModel(nn.Module):
|
|
534 |
ch = model_channels
|
535 |
ds = 1
|
536 |
for level, mult in enumerate(channel_mult):
|
537 |
-
for
|
538 |
layers = [
|
539 |
ResBlock(
|
540 |
ch,
|
@@ -556,17 +574,25 @@ class UNetModel(nn.Module):
|
|
556 |
if legacy:
|
557 |
#num_heads = 1
|
558 |
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
|
559 |
-
|
560 |
-
|
561 |
-
|
562 |
-
|
563 |
-
|
564 |
-
|
565 |
-
|
566 |
-
|
567 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
568 |
)
|
569 |
-
)
|
570 |
self.input_blocks.append(TimestepEmbedSequential(*layers))
|
571 |
self._feature_size += ch
|
572 |
input_block_chans.append(ch)
|
@@ -618,8 +644,10 @@ class UNetModel(nn.Module):
|
|
618 |
num_heads=num_heads,
|
619 |
num_head_channels=dim_head,
|
620 |
use_new_attention_order=use_new_attention_order,
|
621 |
-
) if not use_spatial_transformer else SpatialTransformer(
|
622 |
-
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
|
|
|
|
|
623 |
),
|
624 |
ResBlock(
|
625 |
ch,
|
@@ -634,7 +662,7 @@ class UNetModel(nn.Module):
|
|
634 |
|
635 |
self.output_blocks = nn.ModuleList([])
|
636 |
for level, mult in list(enumerate(channel_mult))[::-1]:
|
637 |
-
for i in range(num_res_blocks + 1):
|
638 |
ich = input_block_chans.pop()
|
639 |
layers = [
|
640 |
ResBlock(
|
@@ -657,18 +685,26 @@ class UNetModel(nn.Module):
|
|
657 |
if legacy:
|
658 |
#num_heads = 1
|
659 |
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
|
660 |
-
|
661 |
-
|
662 |
-
|
663 |
-
|
664 |
-
|
665 |
-
|
666 |
-
|
667 |
-
|
668 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
669 |
)
|
670 |
-
|
671 |
-
if level and i == num_res_blocks:
|
672 |
out_ch = ch
|
673 |
layers.append(
|
674 |
ResBlock(
|
@@ -716,7 +752,7 @@ class UNetModel(nn.Module):
|
|
716 |
self.middle_block.apply(convert_module_to_f32)
|
717 |
self.output_blocks.apply(convert_module_to_f32)
|
718 |
|
719 |
-
def forward(self, x, timesteps=None, context=None, y=None, features_adapter=None,
|
720 |
"""
|
721 |
Apply the model to an input batch.
|
722 |
:param x: an [N x C x ...] Tensor of inputs.
|
@@ -733,21 +769,26 @@ class UNetModel(nn.Module):
|
|
733 |
emb = self.time_embed(t_emb)
|
734 |
|
735 |
if self.num_classes is not None:
|
736 |
-
assert y.shape ==
|
737 |
emb = emb + self.label_emb(y)
|
738 |
|
739 |
h = x.type(self.dtype)
|
740 |
|
|
|
|
|
|
|
|
|
741 |
for id, module in enumerate(self.input_blocks):
|
742 |
h = module(h, emb, context)
|
743 |
-
if ((id+1)%3 == 0) and features_adapter is not None
|
744 |
-
h = h + features_adapter
|
|
|
745 |
hs.append(h)
|
746 |
if features_adapter is not None:
|
747 |
-
assert len(features_adapter)==
|
748 |
|
749 |
h = self.middle_block(h, emb, context)
|
750 |
-
for
|
751 |
h = th.cat([h, hs.pop()], dim=1)
|
752 |
h = module(h, emb, context)
|
753 |
h = h.type(x.dtype)
|
@@ -755,222 +796,3 @@ class UNetModel(nn.Module):
|
|
755 |
return self.id_predictor(h)
|
756 |
else:
|
757 |
return self.out(h)
|
758 |
-
|
759 |
-
|
760 |
-
class EncoderUNetModel(nn.Module):
|
761 |
-
"""
|
762 |
-
The half UNet model with attention and timestep embedding.
|
763 |
-
For usage, see UNet.
|
764 |
-
"""
|
765 |
-
|
766 |
-
def __init__(
|
767 |
-
self,
|
768 |
-
image_size,
|
769 |
-
in_channels,
|
770 |
-
model_channels,
|
771 |
-
out_channels,
|
772 |
-
num_res_blocks,
|
773 |
-
attention_resolutions,
|
774 |
-
dropout=0,
|
775 |
-
channel_mult=(1, 2, 4, 8),
|
776 |
-
conv_resample=True,
|
777 |
-
dims=2,
|
778 |
-
use_checkpoint=False,
|
779 |
-
use_fp16=False,
|
780 |
-
num_heads=1,
|
781 |
-
num_head_channels=-1,
|
782 |
-
num_heads_upsample=-1,
|
783 |
-
use_scale_shift_norm=False,
|
784 |
-
resblock_updown=False,
|
785 |
-
use_new_attention_order=False,
|
786 |
-
pool="adaptive",
|
787 |
-
*args,
|
788 |
-
**kwargs
|
789 |
-
):
|
790 |
-
super().__init__()
|
791 |
-
|
792 |
-
if num_heads_upsample == -1:
|
793 |
-
num_heads_upsample = num_heads
|
794 |
-
|
795 |
-
self.in_channels = in_channels
|
796 |
-
self.model_channels = model_channels
|
797 |
-
self.out_channels = out_channels
|
798 |
-
self.num_res_blocks = num_res_blocks
|
799 |
-
self.attention_resolutions = attention_resolutions
|
800 |
-
self.dropout = dropout
|
801 |
-
self.channel_mult = channel_mult
|
802 |
-
self.conv_resample = conv_resample
|
803 |
-
self.use_checkpoint = use_checkpoint
|
804 |
-
self.dtype = th.float16 if use_fp16 else th.float32
|
805 |
-
self.num_heads = num_heads
|
806 |
-
self.num_head_channels = num_head_channels
|
807 |
-
self.num_heads_upsample = num_heads_upsample
|
808 |
-
|
809 |
-
time_embed_dim = model_channels * 4
|
810 |
-
self.time_embed = nn.Sequential(
|
811 |
-
linear(model_channels, time_embed_dim),
|
812 |
-
nn.SiLU(),
|
813 |
-
linear(time_embed_dim, time_embed_dim),
|
814 |
-
)
|
815 |
-
|
816 |
-
self.input_blocks = nn.ModuleList(
|
817 |
-
[
|
818 |
-
TimestepEmbedSequential(
|
819 |
-
conv_nd(dims, in_channels, model_channels, 3, padding=1)
|
820 |
-
)
|
821 |
-
]
|
822 |
-
)
|
823 |
-
self._feature_size = model_channels
|
824 |
-
input_block_chans = [model_channels]
|
825 |
-
ch = model_channels
|
826 |
-
ds = 1
|
827 |
-
for level, mult in enumerate(channel_mult):
|
828 |
-
for _ in range(num_res_blocks):
|
829 |
-
layers = [
|
830 |
-
ResBlock(
|
831 |
-
ch,
|
832 |
-
time_embed_dim,
|
833 |
-
dropout,
|
834 |
-
out_channels=mult * model_channels,
|
835 |
-
dims=dims,
|
836 |
-
use_checkpoint=use_checkpoint,
|
837 |
-
use_scale_shift_norm=use_scale_shift_norm,
|
838 |
-
)
|
839 |
-
]
|
840 |
-
ch = mult * model_channels
|
841 |
-
if ds in attention_resolutions:
|
842 |
-
layers.append(
|
843 |
-
AttentionBlock(
|
844 |
-
ch,
|
845 |
-
use_checkpoint=use_checkpoint,
|
846 |
-
num_heads=num_heads,
|
847 |
-
num_head_channels=num_head_channels,
|
848 |
-
use_new_attention_order=use_new_attention_order,
|
849 |
-
)
|
850 |
-
)
|
851 |
-
self.input_blocks.append(TimestepEmbedSequential(*layers))
|
852 |
-
self._feature_size += ch
|
853 |
-
input_block_chans.append(ch)
|
854 |
-
if level != len(channel_mult) - 1:
|
855 |
-
out_ch = ch
|
856 |
-
self.input_blocks.append(
|
857 |
-
TimestepEmbedSequential(
|
858 |
-
ResBlock(
|
859 |
-
ch,
|
860 |
-
time_embed_dim,
|
861 |
-
dropout,
|
862 |
-
out_channels=out_ch,
|
863 |
-
dims=dims,
|
864 |
-
use_checkpoint=use_checkpoint,
|
865 |
-
use_scale_shift_norm=use_scale_shift_norm,
|
866 |
-
down=True,
|
867 |
-
)
|
868 |
-
if resblock_updown
|
869 |
-
else Downsample(
|
870 |
-
ch, conv_resample, dims=dims, out_channels=out_ch
|
871 |
-
)
|
872 |
-
)
|
873 |
-
)
|
874 |
-
ch = out_ch
|
875 |
-
input_block_chans.append(ch)
|
876 |
-
ds *= 2
|
877 |
-
self._feature_size += ch
|
878 |
-
|
879 |
-
self.middle_block = TimestepEmbedSequential(
|
880 |
-
ResBlock(
|
881 |
-
ch,
|
882 |
-
time_embed_dim,
|
883 |
-
dropout,
|
884 |
-
dims=dims,
|
885 |
-
use_checkpoint=use_checkpoint,
|
886 |
-
use_scale_shift_norm=use_scale_shift_norm,
|
887 |
-
),
|
888 |
-
AttentionBlock(
|
889 |
-
ch,
|
890 |
-
use_checkpoint=use_checkpoint,
|
891 |
-
num_heads=num_heads,
|
892 |
-
num_head_channels=num_head_channels,
|
893 |
-
use_new_attention_order=use_new_attention_order,
|
894 |
-
),
|
895 |
-
ResBlock(
|
896 |
-
ch,
|
897 |
-
time_embed_dim,
|
898 |
-
dropout,
|
899 |
-
dims=dims,
|
900 |
-
use_checkpoint=use_checkpoint,
|
901 |
-
use_scale_shift_norm=use_scale_shift_norm,
|
902 |
-
),
|
903 |
-
)
|
904 |
-
self._feature_size += ch
|
905 |
-
self.pool = pool
|
906 |
-
if pool == "adaptive":
|
907 |
-
self.out = nn.Sequential(
|
908 |
-
normalization(ch),
|
909 |
-
nn.SiLU(),
|
910 |
-
nn.AdaptiveAvgPool2d((1, 1)),
|
911 |
-
zero_module(conv_nd(dims, ch, out_channels, 1)),
|
912 |
-
nn.Flatten(),
|
913 |
-
)
|
914 |
-
elif pool == "attention":
|
915 |
-
assert num_head_channels != -1
|
916 |
-
self.out = nn.Sequential(
|
917 |
-
normalization(ch),
|
918 |
-
nn.SiLU(),
|
919 |
-
AttentionPool2d(
|
920 |
-
(image_size // ds), ch, num_head_channels, out_channels
|
921 |
-
),
|
922 |
-
)
|
923 |
-
elif pool == "spatial":
|
924 |
-
self.out = nn.Sequential(
|
925 |
-
nn.Linear(self._feature_size, 2048),
|
926 |
-
nn.ReLU(),
|
927 |
-
nn.Linear(2048, self.out_channels),
|
928 |
-
)
|
929 |
-
elif pool == "spatial_v2":
|
930 |
-
self.out = nn.Sequential(
|
931 |
-
nn.Linear(self._feature_size, 2048),
|
932 |
-
normalization(2048),
|
933 |
-
nn.SiLU(),
|
934 |
-
nn.Linear(2048, self.out_channels),
|
935 |
-
)
|
936 |
-
else:
|
937 |
-
raise NotImplementedError(f"Unexpected {pool} pooling")
|
938 |
-
|
939 |
-
def convert_to_fp16(self):
|
940 |
-
"""
|
941 |
-
Convert the torso of the model to float16.
|
942 |
-
"""
|
943 |
-
self.input_blocks.apply(convert_module_to_f16)
|
944 |
-
self.middle_block.apply(convert_module_to_f16)
|
945 |
-
|
946 |
-
def convert_to_fp32(self):
|
947 |
-
"""
|
948 |
-
Convert the torso of the model to float32.
|
949 |
-
"""
|
950 |
-
self.input_blocks.apply(convert_module_to_f32)
|
951 |
-
self.middle_block.apply(convert_module_to_f32)
|
952 |
-
|
953 |
-
def forward(self, x, timesteps):
|
954 |
-
"""
|
955 |
-
Apply the model to an input batch.
|
956 |
-
:param x: an [N x C x ...] Tensor of inputs.
|
957 |
-
:param timesteps: a 1-D batch of timesteps.
|
958 |
-
:return: an [N x K] Tensor of outputs.
|
959 |
-
"""
|
960 |
-
emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
|
961 |
-
|
962 |
-
results = []
|
963 |
-
h = x.type(self.dtype)
|
964 |
-
for module in self.input_blocks:
|
965 |
-
h = module(h, emb)
|
966 |
-
if self.pool.startswith("spatial"):
|
967 |
-
results.append(h.type(x.dtype).mean(dim=(2, 3)))
|
968 |
-
h = self.middle_block(h, emb)
|
969 |
-
if self.pool.startswith("spatial"):
|
970 |
-
results.append(h.type(x.dtype).mean(dim=(2, 3)))
|
971 |
-
h = th.cat(results, axis=-1)
|
972 |
-
return self.out(h)
|
973 |
-
else:
|
974 |
-
h = h.type(x.dtype)
|
975 |
-
return self.out(h)
|
976 |
-
|
|
|
1 |
from abc import abstractmethod
|
|
|
2 |
import math
|
3 |
+
import torch
|
4 |
|
5 |
import numpy as np
|
6 |
import torch as th
|
|
|
17 |
timestep_embedding,
|
18 |
)
|
19 |
from ldm.modules.attention import SpatialTransformer
|
20 |
+
from ldm.util import exists
|
21 |
|
22 |
|
23 |
# dummy replace
|
|
|
270 |
h = out_norm(h) * (1 + scale) + shift
|
271 |
h = out_rest(h)
|
272 |
else:
|
|
|
|
|
273 |
h = h + emb_out
|
274 |
h = self.out_layers(h)
|
275 |
return self.skip_connection(x) + h
|
|
|
466 |
context_dim=None, # custom transformer support
|
467 |
n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
|
468 |
legacy=True,
|
469 |
+
disable_self_attentions=None,
|
470 |
+
num_attention_blocks=None,
|
471 |
+
disable_middle_self_attn=False,
|
472 |
+
use_linear_in_transformer=False,
|
473 |
):
|
474 |
super().__init__()
|
|
|
|
|
475 |
if use_spatial_transformer:
|
476 |
assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
|
477 |
|
478 |
if context_dim is not None:
|
|
|
479 |
assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
|
480 |
from omegaconf.listconfig import ListConfig
|
481 |
if type(context_dim) == ListConfig:
|
|
|
494 |
self.in_channels = in_channels
|
495 |
self.model_channels = model_channels
|
496 |
self.out_channels = out_channels
|
497 |
+
if isinstance(num_res_blocks, int):
|
498 |
+
self.num_res_blocks = len(channel_mult) * [num_res_blocks]
|
499 |
+
else:
|
500 |
+
if len(num_res_blocks) != len(channel_mult):
|
501 |
+
raise ValueError("provide num_res_blocks either as an int (globally constant) or "
|
502 |
+
"as a list/tuple (per-level) with the same length as channel_mult")
|
503 |
+
self.num_res_blocks = num_res_blocks
|
504 |
+
if disable_self_attentions is not None:
|
505 |
+
# should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
|
506 |
+
assert len(disable_self_attentions) == len(channel_mult)
|
507 |
+
if num_attention_blocks is not None:
|
508 |
+
assert len(num_attention_blocks) == len(self.num_res_blocks)
|
509 |
+
assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
|
510 |
+
print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
|
511 |
+
f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
|
512 |
+
f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
|
513 |
+
f"attention will still not be set.")
|
514 |
+
|
515 |
self.attention_resolutions = attention_resolutions
|
516 |
self.dropout = dropout
|
517 |
self.channel_mult = channel_mult
|
|
|
523 |
self.num_head_channels = num_head_channels
|
524 |
self.num_heads_upsample = num_heads_upsample
|
525 |
self.predict_codebook_ids = n_embed is not None
|
|
|
|
|
|
|
526 |
|
527 |
time_embed_dim = model_channels * 4
|
528 |
self.time_embed = nn.Sequential(
|
|
|
532 |
)
|
533 |
|
534 |
if self.num_classes is not None:
|
535 |
+
if isinstance(self.num_classes, int):
|
536 |
+
self.label_emb = nn.Embedding(num_classes, time_embed_dim)
|
537 |
+
elif self.num_classes == "continuous":
|
538 |
+
print("setting up linear c_adm embedding layer")
|
539 |
+
self.label_emb = nn.Linear(1, time_embed_dim)
|
540 |
+
else:
|
541 |
+
raise ValueError()
|
542 |
|
543 |
self.input_blocks = nn.ModuleList(
|
544 |
[
|
|
|
552 |
ch = model_channels
|
553 |
ds = 1
|
554 |
for level, mult in enumerate(channel_mult):
|
555 |
+
for nr in range(self.num_res_blocks[level]):
|
556 |
layers = [
|
557 |
ResBlock(
|
558 |
ch,
|
|
|
574 |
if legacy:
|
575 |
#num_heads = 1
|
576 |
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
|
577 |
+
if exists(disable_self_attentions):
|
578 |
+
disabled_sa = disable_self_attentions[level]
|
579 |
+
else:
|
580 |
+
disabled_sa = False
|
581 |
+
|
582 |
+
if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
|
583 |
+
layers.append(
|
584 |
+
AttentionBlock(
|
585 |
+
ch,
|
586 |
+
use_checkpoint=use_checkpoint,
|
587 |
+
num_heads=num_heads,
|
588 |
+
num_head_channels=dim_head,
|
589 |
+
use_new_attention_order=use_new_attention_order,
|
590 |
+
) if not use_spatial_transformer else SpatialTransformer(
|
591 |
+
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
|
592 |
+
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
|
593 |
+
use_checkpoint=use_checkpoint
|
594 |
+
)
|
595 |
)
|
|
|
596 |
self.input_blocks.append(TimestepEmbedSequential(*layers))
|
597 |
self._feature_size += ch
|
598 |
input_block_chans.append(ch)
|
|
|
644 |
num_heads=num_heads,
|
645 |
num_head_channels=dim_head,
|
646 |
use_new_attention_order=use_new_attention_order,
|
647 |
+
) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn
|
648 |
+
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
|
649 |
+
disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
|
650 |
+
use_checkpoint=use_checkpoint
|
651 |
),
|
652 |
ResBlock(
|
653 |
ch,
|
|
|
662 |
|
663 |
self.output_blocks = nn.ModuleList([])
|
664 |
for level, mult in list(enumerate(channel_mult))[::-1]:
|
665 |
+
for i in range(self.num_res_blocks[level] + 1):
|
666 |
ich = input_block_chans.pop()
|
667 |
layers = [
|
668 |
ResBlock(
|
|
|
685 |
if legacy:
|
686 |
#num_heads = 1
|
687 |
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
|
688 |
+
if exists(disable_self_attentions):
|
689 |
+
disabled_sa = disable_self_attentions[level]
|
690 |
+
else:
|
691 |
+
disabled_sa = False
|
692 |
+
|
693 |
+
if not exists(num_attention_blocks) or i < num_attention_blocks[level]:
|
694 |
+
layers.append(
|
695 |
+
AttentionBlock(
|
696 |
+
ch,
|
697 |
+
use_checkpoint=use_checkpoint,
|
698 |
+
num_heads=num_heads_upsample,
|
699 |
+
num_head_channels=dim_head,
|
700 |
+
use_new_attention_order=use_new_attention_order,
|
701 |
+
) if not use_spatial_transformer else SpatialTransformer(
|
702 |
+
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
|
703 |
+
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
|
704 |
+
use_checkpoint=use_checkpoint
|
705 |
+
)
|
706 |
)
|
707 |
+
if level and i == self.num_res_blocks[level]:
|
|
|
708 |
out_ch = ch
|
709 |
layers.append(
|
710 |
ResBlock(
|
|
|
752 |
self.middle_block.apply(convert_module_to_f32)
|
753 |
self.output_blocks.apply(convert_module_to_f32)
|
754 |
|
755 |
+
def forward(self, x, timesteps=None, context=None, y=None, features_adapter=None, append_to_context=None, **kwargs):
|
756 |
"""
|
757 |
Apply the model to an input batch.
|
758 |
:param x: an [N x C x ...] Tensor of inputs.
|
|
|
769 |
emb = self.time_embed(t_emb)
|
770 |
|
771 |
if self.num_classes is not None:
|
772 |
+
assert y.shape[0] == x.shape[0]
|
773 |
emb = emb + self.label_emb(y)
|
774 |
|
775 |
h = x.type(self.dtype)
|
776 |
|
777 |
+
if append_to_context is not None:
|
778 |
+
context = torch.cat([context, append_to_context], dim=1)
|
779 |
+
|
780 |
+
adapter_idx = 0
|
781 |
for id, module in enumerate(self.input_blocks):
|
782 |
h = module(h, emb, context)
|
783 |
+
if ((id+1)%3 == 0) and features_adapter is not None:
|
784 |
+
h = h + features_adapter[adapter_idx]
|
785 |
+
adapter_idx += 1
|
786 |
hs.append(h)
|
787 |
if features_adapter is not None:
|
788 |
+
assert len(features_adapter)==adapter_idx, 'Wrong features_adapter'
|
789 |
|
790 |
h = self.middle_block(h, emb, context)
|
791 |
+
for module in self.output_blocks:
|
792 |
h = th.cat([h, hs.pop()], dim=1)
|
793 |
h = module(h, emb, context)
|
794 |
h = h.type(x.dtype)
|
|
|
796 |
return self.id_predictor(h)
|
797 |
else:
|
798 |
return self.out(h)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ldm/modules/diffusionmodules/util.py
CHANGED
@@ -122,7 +122,9 @@ class CheckpointFunction(torch.autograd.Function):
|
|
122 |
ctx.run_function = run_function
|
123 |
ctx.input_tensors = list(args[:length])
|
124 |
ctx.input_params = list(args[length:])
|
125 |
-
|
|
|
|
|
126 |
with torch.no_grad():
|
127 |
output_tensors = ctx.run_function(*ctx.input_tensors)
|
128 |
return output_tensors
|
@@ -130,7 +132,8 @@ class CheckpointFunction(torch.autograd.Function):
|
|
130 |
@staticmethod
|
131 |
def backward(ctx, *output_grads):
|
132 |
ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
|
133 |
-
with torch.enable_grad()
|
|
|
134 |
# Fixes a bug where the first op in run_function modifies the
|
135 |
# Tensor storage in place, which is not allowed for detach()'d
|
136 |
# Tensors.
|
|
|
122 |
ctx.run_function = run_function
|
123 |
ctx.input_tensors = list(args[:length])
|
124 |
ctx.input_params = list(args[length:])
|
125 |
+
ctx.gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(),
|
126 |
+
"dtype": torch.get_autocast_gpu_dtype(),
|
127 |
+
"cache_enabled": torch.is_autocast_cache_enabled()}
|
128 |
with torch.no_grad():
|
129 |
output_tensors = ctx.run_function(*ctx.input_tensors)
|
130 |
return output_tensors
|
|
|
132 |
@staticmethod
|
133 |
def backward(ctx, *output_grads):
|
134 |
ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
|
135 |
+
with torch.enable_grad(), \
|
136 |
+
torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
|
137 |
# Fixes a bug where the first op in run_function modifies the
|
138 |
# Tensor storage in place, which is not allowed for detach()'d
|
139 |
# Tensors.
|
ldm/modules/ema.py
CHANGED
@@ -10,24 +10,28 @@ class LitEma(nn.Module):
|
|
10 |
|
11 |
self.m_name2s_name = {}
|
12 |
self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
|
13 |
-
self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates
|
14 |
-
|
15 |
|
16 |
for name, p in model.named_parameters():
|
17 |
if p.requires_grad:
|
18 |
-
#remove as '.'-character is not allowed in buffers
|
19 |
-
s_name = name.replace('.','')
|
20 |
-
self.m_name2s_name.update({name:s_name})
|
21 |
-
self.register_buffer(s_name,p.clone().detach().data)
|
22 |
|
23 |
self.collected_params = []
|
24 |
|
25 |
-
def
|
|
|
|
|
|
|
|
|
26 |
decay = self.decay
|
27 |
|
28 |
if self.num_updates >= 0:
|
29 |
self.num_updates += 1
|
30 |
-
decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates))
|
31 |
|
32 |
one_minus_decay = 1.0 - decay
|
33 |
|
|
|
10 |
|
11 |
self.m_name2s_name = {}
|
12 |
self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
|
13 |
+
self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int) if use_num_upates
|
14 |
+
else torch.tensor(-1, dtype=torch.int))
|
15 |
|
16 |
for name, p in model.named_parameters():
|
17 |
if p.requires_grad:
|
18 |
+
# remove as '.'-character is not allowed in buffers
|
19 |
+
s_name = name.replace('.', '')
|
20 |
+
self.m_name2s_name.update({name: s_name})
|
21 |
+
self.register_buffer(s_name, p.clone().detach().data)
|
22 |
|
23 |
self.collected_params = []
|
24 |
|
25 |
+
def reset_num_updates(self):
|
26 |
+
del self.num_updates
|
27 |
+
self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int))
|
28 |
+
|
29 |
+
def forward(self, model):
|
30 |
decay = self.decay
|
31 |
|
32 |
if self.num_updates >= 0:
|
33 |
self.num_updates += 1
|
34 |
+
decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))
|
35 |
|
36 |
one_minus_decay = 1.0 - decay
|
37 |
|
ldm/modules/encoders/adapter.py
CHANGED
@@ -1,9 +1,8 @@
|
|
1 |
import torch
|
2 |
import torch.nn as nn
|
3 |
-
import torch.nn.functional as F
|
4 |
-
from ldm.modules.attention import SpatialTransformer, BasicTransformerBlock
|
5 |
from collections import OrderedDict
|
6 |
|
|
|
7 |
def conv_nd(dims, *args, **kwargs):
|
8 |
"""
|
9 |
Create a 1D, 2D, or 3D convolution module.
|
@@ -16,6 +15,7 @@ def conv_nd(dims, *args, **kwargs):
|
|
16 |
return nn.Conv3d(*args, **kwargs)
|
17 |
raise ValueError(f"unsupported dimensions: {dims}")
|
18 |
|
|
|
19 |
def avg_pool_nd(dims, *args, **kwargs):
|
20 |
"""
|
21 |
Create a 1D, 2D, or 3D average pooling module.
|
@@ -28,6 +28,7 @@ def avg_pool_nd(dims, *args, **kwargs):
|
|
28 |
return nn.AvgPool3d(*args, **kwargs)
|
29 |
raise ValueError(f"unsupported dimensions: {dims}")
|
30 |
|
|
|
31 |
class Downsample(nn.Module):
|
32 |
"""
|
33 |
A downsampling layer with an optional convolution.
|
@@ -37,7 +38,7 @@ class Downsample(nn.Module):
|
|
37 |
downsampling occurs in the inner-two dimensions.
|
38 |
"""
|
39 |
|
40 |
-
def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1):
|
41 |
super().__init__()
|
42 |
self.channels = channels
|
43 |
self.out_channels = out_channels or channels
|
@@ -60,15 +61,16 @@ class Downsample(nn.Module):
|
|
60 |
class ResnetBlock(nn.Module):
|
61 |
def __init__(self, in_c, out_c, down, ksize=3, sk=False, use_conv=True):
|
62 |
super().__init__()
|
63 |
-
ps = ksize//2
|
64 |
-
if in_c != out_c or sk==False:
|
65 |
self.in_conv = nn.Conv2d(in_c, out_c, ksize, 1, ps)
|
66 |
else:
|
|
|
67 |
self.in_conv = None
|
68 |
self.block1 = nn.Conv2d(out_c, out_c, 3, 1, 1)
|
69 |
self.act = nn.ReLU()
|
70 |
self.block2 = nn.Conv2d(out_c, out_c, ksize, 1, ps)
|
71 |
-
if sk==False:
|
72 |
self.skep = nn.Conv2d(in_c, out_c, ksize, 1, ps)
|
73 |
else:
|
74 |
self.skep = None
|
@@ -80,7 +82,7 @@ class ResnetBlock(nn.Module):
|
|
80 |
def forward(self, x):
|
81 |
if self.down == True:
|
82 |
x = self.down_opt(x)
|
83 |
-
if self.in_conv is not None:
|
84 |
x = self.in_conv(x)
|
85 |
|
86 |
h = self.block1(x)
|
@@ -101,12 +103,14 @@ class Adapter(nn.Module):
|
|
101 |
self.body = []
|
102 |
for i in range(len(channels)):
|
103 |
for j in range(nums_rb):
|
104 |
-
if (i!=0) and (j==0):
|
105 |
-
self.body.append(
|
|
|
106 |
else:
|
107 |
-
self.body.append(
|
|
|
108 |
self.body = nn.ModuleList(self.body)
|
109 |
-
self.conv_in = nn.Conv2d(cin,channels[0], 3, 1, 1)
|
110 |
|
111 |
def forward(self, x):
|
112 |
# unshuffle
|
@@ -116,12 +120,79 @@ class Adapter(nn.Module):
|
|
116 |
x = self.conv_in(x)
|
117 |
for i in range(len(self.channels)):
|
118 |
for j in range(self.nums_rb):
|
119 |
-
idx = i*self.nums_rb +j
|
120 |
x = self.body[idx](x)
|
121 |
features.append(x)
|
122 |
|
123 |
return features
|
124 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
125 |
|
126 |
class ResnetBlock_light(nn.Module):
|
127 |
def __init__(self, in_c):
|
@@ -185,66 +256,3 @@ class Adapter_light(nn.Module):
|
|
185 |
features.append(x)
|
186 |
|
187 |
return features
|
188 |
-
|
189 |
-
class QuickGELU(nn.Module):
|
190 |
-
|
191 |
-
def forward(self, x: torch.Tensor):
|
192 |
-
return x * torch.sigmoid(1.702 * x)
|
193 |
-
|
194 |
-
class ResidualAttentionBlock(nn.Module):
|
195 |
-
|
196 |
-
def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
|
197 |
-
super().__init__()
|
198 |
-
|
199 |
-
self.attn = nn.MultiheadAttention(d_model, n_head)
|
200 |
-
self.ln_1 = LayerNorm(d_model)
|
201 |
-
self.mlp = nn.Sequential(
|
202 |
-
OrderedDict([("c_fc", nn.Linear(d_model, d_model * 4)), ("gelu", QuickGELU()),
|
203 |
-
("c_proj", nn.Linear(d_model * 4, d_model))]))
|
204 |
-
self.ln_2 = LayerNorm(d_model)
|
205 |
-
self.attn_mask = attn_mask
|
206 |
-
|
207 |
-
def attention(self, x: torch.Tensor):
|
208 |
-
self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
|
209 |
-
return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
|
210 |
-
|
211 |
-
def forward(self, x: torch.Tensor):
|
212 |
-
x = x + self.attention(self.ln_1(x))
|
213 |
-
x = x + self.mlp(self.ln_2(x))
|
214 |
-
return x
|
215 |
-
|
216 |
-
class LayerNorm(nn.LayerNorm):
|
217 |
-
"""Subclass torch's LayerNorm to handle fp16."""
|
218 |
-
|
219 |
-
def forward(self, x: torch.Tensor):
|
220 |
-
orig_type = x.dtype
|
221 |
-
ret = super().forward(x.type(torch.float32))
|
222 |
-
return ret.type(orig_type)
|
223 |
-
|
224 |
-
class StyleAdapter(nn.Module):
|
225 |
-
|
226 |
-
def __init__(self, width=1024, context_dim=768, num_head=8, n_layes=3, num_token=4):
|
227 |
-
super().__init__()
|
228 |
-
|
229 |
-
scale = width ** -0.5
|
230 |
-
self.transformer_layes = nn.Sequential(*[ResidualAttentionBlock(width, num_head) for _ in range(n_layes)])
|
231 |
-
self.num_token = num_token
|
232 |
-
self.style_embedding = nn.Parameter(torch.randn(1, num_token, width) * scale)
|
233 |
-
self.ln_post = LayerNorm(width)
|
234 |
-
self.ln_pre = LayerNorm(width)
|
235 |
-
self.proj = nn.Parameter(scale * torch.randn(width, context_dim))
|
236 |
-
|
237 |
-
def forward(self, x):
|
238 |
-
# x shape [N, HW+1, C]
|
239 |
-
style_embedding = self.style_embedding + torch.zeros(
|
240 |
-
(x.shape[0], self.num_token, self.style_embedding.shape[-1]), device=x.device)
|
241 |
-
x = torch.cat([x, style_embedding], dim=1)
|
242 |
-
x = self.ln_pre(x)
|
243 |
-
x = x.permute(1, 0, 2) # NLD -> LND
|
244 |
-
x = self.transformer_layes(x)
|
245 |
-
x = x.permute(1, 0, 2) # LND -> NLD
|
246 |
-
|
247 |
-
x = self.ln_post(x[:, -self.num_token:, :])
|
248 |
-
x = x @ self.proj
|
249 |
-
|
250 |
-
return x
|
|
|
1 |
import torch
|
2 |
import torch.nn as nn
|
|
|
|
|
3 |
from collections import OrderedDict
|
4 |
|
5 |
+
|
6 |
def conv_nd(dims, *args, **kwargs):
|
7 |
"""
|
8 |
Create a 1D, 2D, or 3D convolution module.
|
|
|
15 |
return nn.Conv3d(*args, **kwargs)
|
16 |
raise ValueError(f"unsupported dimensions: {dims}")
|
17 |
|
18 |
+
|
19 |
def avg_pool_nd(dims, *args, **kwargs):
|
20 |
"""
|
21 |
Create a 1D, 2D, or 3D average pooling module.
|
|
|
28 |
return nn.AvgPool3d(*args, **kwargs)
|
29 |
raise ValueError(f"unsupported dimensions: {dims}")
|
30 |
|
31 |
+
|
32 |
class Downsample(nn.Module):
|
33 |
"""
|
34 |
A downsampling layer with an optional convolution.
|
|
|
38 |
downsampling occurs in the inner-two dimensions.
|
39 |
"""
|
40 |
|
41 |
+
def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
|
42 |
super().__init__()
|
43 |
self.channels = channels
|
44 |
self.out_channels = out_channels or channels
|
|
|
61 |
class ResnetBlock(nn.Module):
|
62 |
def __init__(self, in_c, out_c, down, ksize=3, sk=False, use_conv=True):
|
63 |
super().__init__()
|
64 |
+
ps = ksize // 2
|
65 |
+
if in_c != out_c or sk == False:
|
66 |
self.in_conv = nn.Conv2d(in_c, out_c, ksize, 1, ps)
|
67 |
else:
|
68 |
+
# print('n_in')
|
69 |
self.in_conv = None
|
70 |
self.block1 = nn.Conv2d(out_c, out_c, 3, 1, 1)
|
71 |
self.act = nn.ReLU()
|
72 |
self.block2 = nn.Conv2d(out_c, out_c, ksize, 1, ps)
|
73 |
+
if sk == False:
|
74 |
self.skep = nn.Conv2d(in_c, out_c, ksize, 1, ps)
|
75 |
else:
|
76 |
self.skep = None
|
|
|
82 |
def forward(self, x):
|
83 |
if self.down == True:
|
84 |
x = self.down_opt(x)
|
85 |
+
if self.in_conv is not None: # edit
|
86 |
x = self.in_conv(x)
|
87 |
|
88 |
h = self.block1(x)
|
|
|
103 |
self.body = []
|
104 |
for i in range(len(channels)):
|
105 |
for j in range(nums_rb):
|
106 |
+
if (i != 0) and (j == 0):
|
107 |
+
self.body.append(
|
108 |
+
ResnetBlock(channels[i - 1], channels[i], down=True, ksize=ksize, sk=sk, use_conv=use_conv))
|
109 |
else:
|
110 |
+
self.body.append(
|
111 |
+
ResnetBlock(channels[i], channels[i], down=False, ksize=ksize, sk=sk, use_conv=use_conv))
|
112 |
self.body = nn.ModuleList(self.body)
|
113 |
+
self.conv_in = nn.Conv2d(cin, channels[0], 3, 1, 1)
|
114 |
|
115 |
def forward(self, x):
|
116 |
# unshuffle
|
|
|
120 |
x = self.conv_in(x)
|
121 |
for i in range(len(self.channels)):
|
122 |
for j in range(self.nums_rb):
|
123 |
+
idx = i * self.nums_rb + j
|
124 |
x = self.body[idx](x)
|
125 |
features.append(x)
|
126 |
|
127 |
return features
|
128 |
+
|
129 |
+
|
130 |
+
class LayerNorm(nn.LayerNorm):
|
131 |
+
"""Subclass torch's LayerNorm to handle fp16."""
|
132 |
+
|
133 |
+
def forward(self, x: torch.Tensor):
|
134 |
+
orig_type = x.dtype
|
135 |
+
ret = super().forward(x.type(torch.float32))
|
136 |
+
return ret.type(orig_type)
|
137 |
+
|
138 |
+
|
139 |
+
class QuickGELU(nn.Module):
|
140 |
+
|
141 |
+
def forward(self, x: torch.Tensor):
|
142 |
+
return x * torch.sigmoid(1.702 * x)
|
143 |
+
|
144 |
+
|
145 |
+
class ResidualAttentionBlock(nn.Module):
|
146 |
+
|
147 |
+
def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
|
148 |
+
super().__init__()
|
149 |
+
|
150 |
+
self.attn = nn.MultiheadAttention(d_model, n_head)
|
151 |
+
self.ln_1 = LayerNorm(d_model)
|
152 |
+
self.mlp = nn.Sequential(
|
153 |
+
OrderedDict([("c_fc", nn.Linear(d_model, d_model * 4)), ("gelu", QuickGELU()),
|
154 |
+
("c_proj", nn.Linear(d_model * 4, d_model))]))
|
155 |
+
self.ln_2 = LayerNorm(d_model)
|
156 |
+
self.attn_mask = attn_mask
|
157 |
+
|
158 |
+
def attention(self, x: torch.Tensor):
|
159 |
+
self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
|
160 |
+
return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
|
161 |
+
|
162 |
+
def forward(self, x: torch.Tensor):
|
163 |
+
x = x + self.attention(self.ln_1(x))
|
164 |
+
x = x + self.mlp(self.ln_2(x))
|
165 |
+
return x
|
166 |
+
|
167 |
+
|
168 |
+
class StyleAdapter(nn.Module):
|
169 |
+
|
170 |
+
def __init__(self, width=1024, context_dim=768, num_head=8, n_layes=3, num_token=4):
|
171 |
+
super().__init__()
|
172 |
+
|
173 |
+
scale = width ** -0.5
|
174 |
+
self.transformer_layes = nn.Sequential(*[ResidualAttentionBlock(width, num_head) for _ in range(n_layes)])
|
175 |
+
self.num_token = num_token
|
176 |
+
self.style_embedding = nn.Parameter(torch.randn(1, num_token, width) * scale)
|
177 |
+
self.ln_post = LayerNorm(width)
|
178 |
+
self.ln_pre = LayerNorm(width)
|
179 |
+
self.proj = nn.Parameter(scale * torch.randn(width, context_dim))
|
180 |
+
|
181 |
+
def forward(self, x):
|
182 |
+
# x shape [N, HW+1, C]
|
183 |
+
style_embedding = self.style_embedding + torch.zeros(
|
184 |
+
(x.shape[0], self.num_token, self.style_embedding.shape[-1]), device=x.device)
|
185 |
+
x = torch.cat([x, style_embedding], dim=1)
|
186 |
+
x = self.ln_pre(x)
|
187 |
+
x = x.permute(1, 0, 2) # NLD -> LND
|
188 |
+
x = self.transformer_layes(x)
|
189 |
+
x = x.permute(1, 0, 2) # LND -> NLD
|
190 |
+
|
191 |
+
x = self.ln_post(x[:, -self.num_token:, :])
|
192 |
+
x = x @ self.proj
|
193 |
+
|
194 |
+
return x
|
195 |
+
|
196 |
|
197 |
class ResnetBlock_light(nn.Module):
|
198 |
def __init__(self, in_c):
|
|
|
256 |
features.append(x)
|
257 |
|
258 |
return features
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ldm/modules/encoders/modules.py
CHANGED
@@ -1,12 +1,13 @@
|
|
1 |
import torch
|
2 |
import torch.nn as nn
|
3 |
-
|
4 |
-
import
|
5 |
-
from einops import rearrange, repeat
|
6 |
-
from transformers import CLIPTokenizer, CLIPTextModel
|
7 |
-
import kornia
|
8 |
|
9 |
-
from
|
|
|
|
|
|
|
|
|
10 |
|
11 |
|
12 |
class AbstractEncoder(nn.Module):
|
@@ -17,6 +18,11 @@ class AbstractEncoder(nn.Module):
|
|
17 |
raise NotImplementedError
|
18 |
|
19 |
|
|
|
|
|
|
|
|
|
|
|
20 |
|
21 |
class ClassEmbedder(nn.Module):
|
22 |
def __init__(self, embed_dim, n_classes=1000, key='class'):
|
@@ -33,116 +39,48 @@ class ClassEmbedder(nn.Module):
|
|
33 |
return c
|
34 |
|
35 |
|
36 |
-
class
|
37 |
-
"""
|
38 |
-
def __init__(self,
|
39 |
super().__init__()
|
|
|
|
|
40 |
self.device = device
|
41 |
-
self.
|
42 |
-
|
43 |
-
|
44 |
-
def forward(self, tokens):
|
45 |
-
tokens = tokens.to(self.device) # meh
|
46 |
-
z = self.transformer(tokens, return_embeddings=True)
|
47 |
-
return z
|
48 |
|
49 |
-
def
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
""" Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)"""
|
55 |
-
def __init__(self, device="cuda", vq_interface=True, max_length=77):
|
56 |
-
super().__init__()
|
57 |
-
from transformers import BertTokenizerFast # TODO: add to reuquirements
|
58 |
-
self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
|
59 |
-
self.device = device
|
60 |
-
self.vq_interface = vq_interface
|
61 |
-
self.max_length = max_length
|
62 |
|
63 |
def forward(self, text):
|
64 |
batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
|
65 |
return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
|
66 |
tokens = batch_encoding["input_ids"].to(self.device)
|
67 |
-
|
68 |
-
|
69 |
-
@torch.no_grad()
|
70 |
-
def encode(self, text):
|
71 |
-
tokens = self(text)
|
72 |
-
if not self.vq_interface:
|
73 |
-
return tokens
|
74 |
-
return None, None, [None, None, tokens]
|
75 |
-
|
76 |
-
def decode(self, text):
|
77 |
-
return text
|
78 |
-
|
79 |
-
|
80 |
-
class BERTEmbedder(AbstractEncoder):
|
81 |
-
"""Uses the BERT tokenizr model and add some transformer encoder layers"""
|
82 |
-
def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77,
|
83 |
-
device="cuda",use_tokenizer=True, embedding_dropout=0.0):
|
84 |
-
super().__init__()
|
85 |
-
self.use_tknz_fn = use_tokenizer
|
86 |
-
if self.use_tknz_fn:
|
87 |
-
self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len)
|
88 |
-
self.device = device
|
89 |
-
self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
|
90 |
-
attn_layers=Encoder(dim=n_embed, depth=n_layer),
|
91 |
-
emb_dropout=embedding_dropout)
|
92 |
|
93 |
-
|
94 |
-
if self.use_tknz_fn:
|
95 |
-
tokens = self.tknz_fn(text)#.to(self.device)
|
96 |
-
else:
|
97 |
-
tokens = text
|
98 |
-
z = self.transformer(tokens, return_embeddings=True)
|
99 |
return z
|
100 |
|
101 |
def encode(self, text):
|
102 |
-
# output of length 77
|
103 |
return self(text)
|
104 |
|
105 |
|
106 |
-
class SpatialRescaler(nn.Module):
|
107 |
-
def __init__(self,
|
108 |
-
n_stages=1,
|
109 |
-
method='bilinear',
|
110 |
-
multiplier=0.5,
|
111 |
-
in_channels=3,
|
112 |
-
out_channels=None,
|
113 |
-
bias=False):
|
114 |
-
super().__init__()
|
115 |
-
self.n_stages = n_stages
|
116 |
-
assert self.n_stages >= 0
|
117 |
-
assert method in ['nearest','linear','bilinear','trilinear','bicubic','area']
|
118 |
-
self.multiplier = multiplier
|
119 |
-
self.interpolator = partial(torch.nn.functional.interpolate, mode=method)
|
120 |
-
self.remap_output = out_channels is not None
|
121 |
-
if self.remap_output:
|
122 |
-
print(f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.')
|
123 |
-
self.channel_mapper = nn.Conv2d(in_channels,out_channels,1,bias=bias)
|
124 |
-
|
125 |
-
def forward(self,x):
|
126 |
-
for stage in range(self.n_stages):
|
127 |
-
x = self.interpolator(x, scale_factor=self.multiplier)
|
128 |
-
|
129 |
-
|
130 |
-
if self.remap_output:
|
131 |
-
x = self.channel_mapper(x)
|
132 |
-
return x
|
133 |
-
|
134 |
-
def encode(self, x):
|
135 |
-
return self(x)
|
136 |
-
|
137 |
class FrozenCLIPEmbedder(AbstractEncoder):
|
138 |
-
"""Uses the CLIP transformer encoder for text (from
|
139 |
-
def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77
|
|
|
140 |
super().__init__()
|
141 |
self.tokenizer = CLIPTokenizer.from_pretrained(version)
|
142 |
-
self.transformer =
|
143 |
self.device = device
|
144 |
self.max_length = max_length
|
145 |
-
|
|
|
|
|
146 |
|
147 |
def freeze(self):
|
148 |
self.transformer = self.transformer.eval()
|
@@ -153,26 +91,47 @@ class FrozenCLIPEmbedder(AbstractEncoder):
|
|
153 |
batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
|
154 |
return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
|
155 |
tokens = batch_encoding["input_ids"].to(self.device)
|
156 |
-
outputs = self.transformer(input_ids=tokens)
|
157 |
|
158 |
-
|
|
|
|
|
|
|
|
|
159 |
return z
|
160 |
|
161 |
def encode(self, text):
|
162 |
return self(text)
|
163 |
|
164 |
|
165 |
-
class
|
166 |
"""
|
167 |
-
Uses the
|
168 |
"""
|
169 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
170 |
super().__init__()
|
171 |
-
|
|
|
|
|
|
|
|
|
172 |
self.device = device
|
173 |
self.max_length = max_length
|
174 |
-
|
175 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
176 |
|
177 |
def freeze(self):
|
178 |
self.model = self.model.eval()
|
@@ -180,55 +139,303 @@ class FrozenCLIPTextEmbedder(nn.Module):
|
|
180 |
param.requires_grad = False
|
181 |
|
182 |
def forward(self, text):
|
183 |
-
tokens =
|
184 |
-
z = self.
|
185 |
-
if self.normalize:
|
186 |
-
z = z / torch.linalg.norm(z, dim=1, keepdim=True)
|
187 |
return z
|
188 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
189 |
def encode(self, text):
|
190 |
-
|
191 |
-
if z.ndim==2:
|
192 |
-
z = z[:, None, :]
|
193 |
-
z = repeat(z, 'b 1 d -> b k d', k=self.n_repeat)
|
194 |
-
return z
|
195 |
|
196 |
|
197 |
-
class
|
198 |
-
"""
|
199 |
-
|
200 |
-
"""
|
201 |
-
def __init__(
|
202 |
-
self,
|
203 |
-
model,
|
204 |
-
jit=False,
|
205 |
-
device='cuda' if torch.cuda.is_available() else 'cpu',
|
206 |
-
antialias=False,
|
207 |
-
):
|
208 |
super().__init__()
|
209 |
-
self.
|
|
|
|
|
|
|
210 |
|
211 |
-
|
|
|
212 |
|
213 |
-
|
214 |
-
self.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
215 |
|
216 |
-
|
217 |
-
|
218 |
-
|
219 |
-
|
220 |
-
|
221 |
-
|
222 |
-
|
223 |
-
|
224 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
225 |
|
226 |
-
def forward(self, x):
|
227 |
-
# x is assumed to be in range [-1,1]
|
228 |
-
return self.model.encode_image(self.preprocess(x))
|
229 |
|
230 |
|
231 |
if __name__ == "__main__":
|
232 |
-
from ldm.util import count_params
|
233 |
model = FrozenCLIPEmbedder()
|
234 |
-
count_params(model, verbose=True)
|
|
|
1 |
import torch
|
2 |
import torch.nn as nn
|
3 |
+
import math
|
4 |
+
from torch.utils.checkpoint import checkpoint
|
|
|
|
|
|
|
5 |
|
6 |
+
from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel, CLIPModel
|
7 |
+
|
8 |
+
import open_clip
|
9 |
+
import re
|
10 |
+
from ldm.util import default, count_params
|
11 |
|
12 |
|
13 |
class AbstractEncoder(nn.Module):
|
|
|
18 |
raise NotImplementedError
|
19 |
|
20 |
|
21 |
+
class IdentityEncoder(AbstractEncoder):
|
22 |
+
|
23 |
+
def encode(self, x):
|
24 |
+
return x
|
25 |
+
|
26 |
|
27 |
class ClassEmbedder(nn.Module):
|
28 |
def __init__(self, embed_dim, n_classes=1000, key='class'):
|
|
|
39 |
return c
|
40 |
|
41 |
|
42 |
+
class FrozenT5Embedder(AbstractEncoder):
|
43 |
+
"""Uses the T5 transformer encoder for text"""
|
44 |
+
def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77, freeze=True): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
|
45 |
super().__init__()
|
46 |
+
self.tokenizer = T5Tokenizer.from_pretrained(version)
|
47 |
+
self.transformer = T5EncoderModel.from_pretrained(version)
|
48 |
self.device = device
|
49 |
+
self.max_length = max_length # TODO: typical value?
|
50 |
+
if freeze:
|
51 |
+
self.freeze()
|
|
|
|
|
|
|
|
|
52 |
|
53 |
+
def freeze(self):
|
54 |
+
self.transformer = self.transformer.eval()
|
55 |
+
#self.train = disabled_train
|
56 |
+
for param in self.parameters():
|
57 |
+
param.requires_grad = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
58 |
|
59 |
def forward(self, text):
|
60 |
batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
|
61 |
return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
|
62 |
tokens = batch_encoding["input_ids"].to(self.device)
|
63 |
+
outputs = self.transformer(input_ids=tokens)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
64 |
|
65 |
+
z = outputs.last_hidden_state
|
|
|
|
|
|
|
|
|
|
|
66 |
return z
|
67 |
|
68 |
def encode(self, text):
|
|
|
69 |
return self(text)
|
70 |
|
71 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
72 |
class FrozenCLIPEmbedder(AbstractEncoder):
|
73 |
+
"""Uses the CLIP transformer encoder for text (from huggingface)"""
|
74 |
+
def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77,
|
75 |
+
freeze=True, layer="last"): # clip-vit-base-patch32
|
76 |
super().__init__()
|
77 |
self.tokenizer = CLIPTokenizer.from_pretrained(version)
|
78 |
+
self.transformer = CLIPModel.from_pretrained(version).text_model
|
79 |
self.device = device
|
80 |
self.max_length = max_length
|
81 |
+
if freeze:
|
82 |
+
self.freeze()
|
83 |
+
self.layer = layer
|
84 |
|
85 |
def freeze(self):
|
86 |
self.transformer = self.transformer.eval()
|
|
|
91 |
batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
|
92 |
return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
|
93 |
tokens = batch_encoding["input_ids"].to(self.device)
|
94 |
+
outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer != 'last')
|
95 |
|
96 |
+
if self.layer == 'penultimate':
|
97 |
+
z = outputs.hidden_states[-2]
|
98 |
+
z = self.transformer.final_layer_norm(z)
|
99 |
+
else:
|
100 |
+
z = outputs.last_hidden_state
|
101 |
return z
|
102 |
|
103 |
def encode(self, text):
|
104 |
return self(text)
|
105 |
|
106 |
|
107 |
+
class FrozenOpenCLIPEmbedder(AbstractEncoder):
|
108 |
"""
|
109 |
+
Uses the OpenCLIP transformer encoder for text
|
110 |
"""
|
111 |
+
LAYERS = [
|
112 |
+
#"pooled",
|
113 |
+
"last",
|
114 |
+
"penultimate"
|
115 |
+
]
|
116 |
+
def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77,
|
117 |
+
freeze=True, layer="last"):
|
118 |
super().__init__()
|
119 |
+
assert layer in self.LAYERS
|
120 |
+
model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=version)
|
121 |
+
del model.visual
|
122 |
+
self.model = model
|
123 |
+
|
124 |
self.device = device
|
125 |
self.max_length = max_length
|
126 |
+
if freeze:
|
127 |
+
self.freeze()
|
128 |
+
self.layer = layer
|
129 |
+
if self.layer == "last":
|
130 |
+
self.layer_idx = 0
|
131 |
+
elif self.layer == "penultimate":
|
132 |
+
self.layer_idx = 1
|
133 |
+
else:
|
134 |
+
raise NotImplementedError()
|
135 |
|
136 |
def freeze(self):
|
137 |
self.model = self.model.eval()
|
|
|
139 |
param.requires_grad = False
|
140 |
|
141 |
def forward(self, text):
|
142 |
+
tokens = open_clip.tokenize(text)
|
143 |
+
z = self.encode_with_transformer(tokens.to(self.device))
|
|
|
|
|
144 |
return z
|
145 |
|
146 |
+
def encode_with_transformer(self, text):
|
147 |
+
x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model]
|
148 |
+
x = x + self.model.positional_embedding
|
149 |
+
x = x.permute(1, 0, 2) # NLD -> LND
|
150 |
+
x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
|
151 |
+
x = x.permute(1, 0, 2) # LND -> NLD
|
152 |
+
x = self.model.ln_final(x)
|
153 |
+
return x
|
154 |
+
|
155 |
+
def text_transformer_forward(self, x: torch.Tensor, attn_mask = None):
|
156 |
+
for i, r in enumerate(self.model.transformer.resblocks):
|
157 |
+
if i == len(self.model.transformer.resblocks) - self.layer_idx:
|
158 |
+
break
|
159 |
+
if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting():
|
160 |
+
x = checkpoint(r, x, attn_mask)
|
161 |
+
else:
|
162 |
+
x = r(x, attn_mask=attn_mask)
|
163 |
+
return x
|
164 |
+
|
165 |
def encode(self, text):
|
166 |
+
return self(text)
|
|
|
|
|
|
|
|
|
167 |
|
168 |
|
169 |
+
class FrozenCLIPT5Encoder(AbstractEncoder):
    """Text encoder that combines a ``FrozenCLIPEmbedder`` and a ``FrozenT5Embedder``.

    The same text is given to both sub-encoders and their outputs are
    returned together as ``[clip_output, t5_output]``.
    """

    def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_version="google/t5-v1_1-xl", device="cuda",
                 clip_max_length=77, t5_max_length=77):
        """Build both sub-encoders and print their parameter counts (in millions)."""
        super().__init__()
        self.clip_encoder = FrozenCLIPEmbedder(clip_version, device, max_length=clip_max_length)
        self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length)
        summary = (
            f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder)*1.e-6:.2f} M parameters, "
            f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder)*1.e-6:.2f} M params."
        )
        print(summary)

    def forward(self, text):
        """Return ``[clip_z, t5_z]``: the CLIP encoding followed by the T5 encoding of ``text``."""
        return [encoder.encode(text) for encoder in (self.clip_encoder, self.t5_encoder)]

    def encode(self, text):
        """Alias for calling the module on ``text``."""
        return self(text)
|
185 |
+
|
186 |
+
|
187 |
+
# code from sd-webui
|
188 |
+
re_attention = re.compile(r"""
|
189 |
+
\\\(|
|
190 |
+
\\\)|
|
191 |
+
\\\[|
|
192 |
+
\\]|
|
193 |
+
\\\\|
|
194 |
+
\\|
|
195 |
+
\(|
|
196 |
+
\[|
|
197 |
+
:([+-]?[.\d]+)\)|
|
198 |
+
\)|
|
199 |
+
]|
|
200 |
+
[^\\()\[\]:]+|
|
201 |
+
:
|
202 |
+
""", re.X)
|
203 |
+
|
204 |
+
|
205 |
+
def parse_prompt_attention(text):
|
206 |
+
"""
|
207 |
+
Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
|
208 |
+
Accepted tokens are:
|
209 |
+
(abc) - increases attention to abc by a multiplier of 1.1
|
210 |
+
(abc:3.12) - increases attention to abc by a multiplier of 3.12
|
211 |
+
[abc] - decreases attention to abc by a multiplier of 1.1
|
212 |
+
\( - literal character '('
|
213 |
+
\[ - literal character '['
|
214 |
+
\) - literal character ')'
|
215 |
+
\] - literal character ']'
|
216 |
+
\\ - literal character '\'
|
217 |
+
anything else - just text
|
218 |
+
|
219 |
+
>>> parse_prompt_attention('normal text')
|
220 |
+
[['normal text', 1.0]]
|
221 |
+
>>> parse_prompt_attention('an (important) word')
|
222 |
+
[['an ', 1.0], ['important', 1.1], [' word', 1.0]]
|
223 |
+
>>> parse_prompt_attention('(unbalanced')
|
224 |
+
[['unbalanced', 1.1]]
|
225 |
+
>>> parse_prompt_attention('\(literal\]')
|
226 |
+
[['(literal]', 1.0]]
|
227 |
+
>>> parse_prompt_attention('(unnecessary)(parens)')
|
228 |
+
[['unnecessaryparens', 1.1]]
|
229 |
+
>>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
|
230 |
+
[['a ', 1.0],
|
231 |
+
['house', 1.5730000000000004],
|
232 |
+
[' ', 1.1],
|
233 |
+
['on', 1.0],
|
234 |
+
[' a ', 1.1],
|
235 |
+
['hill', 0.55],
|
236 |
+
[', sun, ', 1.1],
|
237 |
+
['sky', 1.4641000000000006],
|
238 |
+
['.', 1.1]]
|
239 |
+
"""
|
240 |
|
241 |
+
res = []
|
242 |
+
round_brackets = []
|
243 |
+
square_brackets = []
|
244 |
+
|
245 |
+
round_bracket_multiplier = 1.1
|
246 |
+
square_bracket_multiplier = 1 / 1.1
|
247 |
+
|
248 |
+
def multiply_range(start_position, multiplier):
|
249 |
+
for p in range(start_position, len(res)):
|
250 |
+
res[p][1] *= multiplier
|
251 |
+
|
252 |
+
for m in re_attention.finditer(text):
|
253 |
+
text = m.group(0)
|
254 |
+
weight = m.group(1)
|
255 |
+
|
256 |
+
if text.startswith('\\'):
|
257 |
+
res.append([text[1:], 1.0])
|
258 |
+
elif text == '(':
|
259 |
+
round_brackets.append(len(res))
|
260 |
+
elif text == '[':
|
261 |
+
square_brackets.append(len(res))
|
262 |
+
elif weight is not None and len(round_brackets) > 0:
|
263 |
+
multiply_range(round_brackets.pop(), float(weight))
|
264 |
+
elif text == ')' and len(round_brackets) > 0:
|
265 |
+
multiply_range(round_brackets.pop(), round_bracket_multiplier)
|
266 |
+
elif text == ']' and len(square_brackets) > 0:
|
267 |
+
multiply_range(square_brackets.pop(), square_bracket_multiplier)
|
268 |
+
else:
|
269 |
+
res.append([text, 1.0])
|
270 |
+
|
271 |
+
for pos in round_brackets:
|
272 |
+
multiply_range(pos, round_bracket_multiplier)
|
273 |
+
|
274 |
+
for pos in square_brackets:
|
275 |
+
multiply_range(pos, square_bracket_multiplier)
|
276 |
+
|
277 |
+
if len(res) == 0:
|
278 |
+
res = [["", 1.0]]
|
279 |
+
|
280 |
+
# merge runs of identical weights
|
281 |
+
i = 0
|
282 |
+
while i + 1 < len(res):
|
283 |
+
if res[i][1] == res[i + 1][1]:
|
284 |
+
res[i][0] += res[i + 1][0]
|
285 |
+
res.pop(i + 1)
|
286 |
+
else:
|
287 |
+
i += 1
|
288 |
+
|
289 |
+
return res
|
290 |
+
|
291 |
+
class WebUIFrozenCLIPEmebedder(AbstractEncoder):
    """CLIP text encoder with sd-webui style prompt weighting and long-prompt chunking.

    Prompts are parsed with ``parse_prompt_attention``, tokenized without
    special tokens, padded to a multiple of 75 tokens and encoded 75 tokens at
    a time (each chunk wrapped in BOS/EOS). Per-token weights scale the
    resulting embeddings, after which the overall mean is restored.
    """

    def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", freeze=True, layer="penultimate"):
        """Load tokenizer and text model for ``version``; optionally freeze the weights."""
        super().__init__()
        self.tokenizer = CLIPTokenizer.from_pretrained(version)
        self.transformer = CLIPModel.from_pretrained(version).text_model
        self.device = device
        self.layer = layer
        if freeze:
            self.freeze()

        # Vocabulary id of the comma token, used to find soft break points.
        comma_ids = [token_id for piece, token_id in self.tokenizer.get_vocab().items() if piece == ',</w>']
        self.comma_token = comma_ids[0]
        # How far back from a 75-token boundary a comma may be to split there.
        self.comma_padding_backtrack = 20

    def freeze(self):
        """Switch the text model to eval mode and disable gradients on all parameters."""
        self.transformer = self.transformer.eval()
        for weight in self.parameters():
            weight.requires_grad = False

    def tokenize(self, texts):
        """Return the raw ``input_ids`` for ``texts`` (no truncation, no special tokens)."""
        return self.tokenizer(texts, truncation=False, add_special_tokens=False)["input_ids"]

    def encode_with_transformers(self, tokens):
        """Encode a token tensor, taking the output selected by ``self.layer``.

        ``'penultimate'`` uses the second-to-last hidden state followed by the
        model's final layer norm; any other value uses ``last_hidden_state``.
        """
        need_hidden_states = self.layer != 'last'
        outputs = self.transformer(input_ids=tokens, output_hidden_states=need_hidden_states)

        if self.layer != 'penultimate':
            return outputs.last_hidden_state
        return self.transformer.final_layer_norm(outputs.hidden_states[-2])

    def tokenize_line(self, line):
        """Tokenize one prompt into padded token ids and matching weights.

        Returns ``(remade_tokens, multipliers, token_count)`` where both lists
        are padded with EOS / 1.0 to a multiple of 75 and ``token_count`` is
        the length before that final padding.
        """
        parsed = parse_prompt_attention(line)
        tokenized = self.tokenize([fragment for fragment, _ in parsed])

        eos = self.tokenizer.eos_token_id
        remade_tokens = []
        multipliers = []
        last_comma = -1

        for fragment_tokens, (_, weight) in zip(tokenized, parsed):
            for token in fragment_tokens:
                at_chunk_boundary = max(len(remade_tokens), 1) % 75 == 0
                if token == self.comma_token:
                    last_comma = len(remade_tokens)
                elif (self.comma_padding_backtrack != 0 and at_chunk_boundary and last_comma != -1
                        and len(remade_tokens) - last_comma <= self.comma_padding_backtrack):
                    # A recent comma exists: pad after it so the trailing
                    # words move whole into the next 75-token chunk.
                    last_comma += 1
                    moved_tokens = remade_tokens[last_comma:]
                    moved_weights = multipliers[last_comma:]

                    remade_tokens = remade_tokens[:last_comma]
                    kept = len(remade_tokens)

                    pad = int(math.ceil(kept / 75)) * 75 - kept
                    remade_tokens += [eos] * pad + moved_tokens
                    multipliers = multipliers[:last_comma] + [1.0] * pad + moved_weights

                remade_tokens.append(token)
                multipliers.append(weight)

        token_count = len(remade_tokens)
        target_length = math.ceil(max(token_count, 1) / 75) * 75
        missing = target_length - len(remade_tokens)

        remade_tokens = remade_tokens + [eos] * missing
        multipliers = multipliers + [1.0] * missing

        return remade_tokens, multipliers, token_count

    def process_text(self, texts):
        """Tokenize every prompt in ``texts``, reusing results for repeated prompts.

        Returns ``(batch_multipliers, remade_batch_tokens, token_count)`` with
        ``token_count`` being the largest unpadded token count seen.
        """
        seen = {}
        remade_batch_tokens = []
        batch_multipliers = []
        token_count = 0

        for line in texts:
            if line not in seen:
                line_tokens, line_weights, line_count = self.tokenize_line(line)
                token_count = max(line_count, token_count)
                seen[line] = (line_tokens, line_weights)
            line_tokens, line_weights = seen[line]

            remade_batch_tokens.append(line_tokens)
            batch_multipliers.append(line_weights)

        return batch_multipliers, remade_batch_tokens, token_count

    def process_tokens(self, remade_batch_tokens, batch_multipliers):
        """Encode one chunk per prompt and apply the per-token weights.

        Each row is cut to 75 tokens and wrapped in BOS/EOS (weight 1.0 for
        both) before encoding.
        """
        bos = self.tokenizer.bos_token_id
        eos = self.tokenizer.eos_token_id
        remade_batch_tokens = [[bos] + row[:75] + [eos] for row in remade_batch_tokens]
        batch_multipliers = [[1.0] + row[:75] + [1.0] for row in batch_multipliers]

        tokens = torch.asarray(remade_batch_tokens).to(self.device)

        z = self.encode_with_transformers(tokens)

        # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
        padded_weights = [row + [1.0] * (75 - len(row)) for row in batch_multipliers]
        weights = torch.asarray(padded_weights).to(self.device)
        original_mean = z.mean()
        z *= weights.reshape(weights.shape + (1,)).expand(z.shape)
        new_mean = z.mean()
        z *= original_mean / new_mean

        return z

    def forward(self, text):
        """Encode a batch of prompts, concatenating 75-token chunks along the sequence axis.

        Prompts that run out of tokens before the longest one are fed chunks
        made purely of EOS tokens with weight 1.0.
        """
        batch_multipliers, remade_batch_tokens, token_count = self.process_text(text)

        eos = self.tokenizer.eos_token_id
        z = None
        while max(map(len, remade_batch_tokens)) != 0:
            chunk_tokens = []
            chunk_weights = []
            for row_tokens, row_weights in zip(remade_batch_tokens, batch_multipliers):
                if len(row_tokens) > 0:
                    chunk_tokens.append(row_tokens[:75])
                    chunk_weights.append(row_weights[:75])
                else:
                    chunk_tokens.append([eos] * 75)
                    chunk_weights.append([1.0] * 75)

            chunk_z = self.process_tokens(chunk_tokens, chunk_weights)
            z = chunk_z if z is None else torch.cat((z, chunk_z), axis=-2)

            # Drop the chunk just consumed from every row.
            remade_batch_tokens = [row[75:] for row in remade_batch_tokens]
            batch_multipliers = [row[75:] for row in batch_multipliers]

        return z

    def encode(self, text):
        """Alias for calling the module on ``text``."""
        return self(text)
|
436 |
|
|
|
|
|
|
|
437 |
|
438 |
|
439 |
if __name__ == "__main__":
    # Smoke check: build the default CLIP embedder and report its parameter count.
    embedder = FrozenCLIPEmbedder()
    count_params(embedder, verbose=True)
|
ldm/modules/{structure_condition β extra_condition}/__init__.py
RENAMED
File without changes
|
ldm/modules/extra_condition/api.py
ADDED
@@ -0,0 +1,269 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from enum import Enum, unique
|
2 |
+
|
3 |
+
import cv2
|
4 |
+
import torch
|
5 |
+
from basicsr.utils import img2tensor
|
6 |
+
from ldm.util import resize_numpy_image
|
7 |
+
from PIL import Image
|
8 |
+
from torch import autocast
|
9 |
+
|
10 |
+
|
11 |
+
@unique
class ExtraCondition(Enum):
    """Identifiers for the extra-condition types that ``get_cond_model`` dispatches on.

    ``@unique`` guarantees that no two members share a value.
    """
    sketch = 0
    keypose = 1
    seg = 2
    depth = 3
    canny = 4
    style = 5
    color = 6
    openpose = 7
|
21 |
+
|
22 |
+
|
23 |
+
def get_cond_model(opt, cond_type: ExtraCondition):
    """Build the model(s) used to extract a condition of ``cond_type`` from an image.

    Heavy dependencies are imported lazily inside the branch that needs them.

    Returns:
        A model placed on ``opt.device``, a dict of models/processors
        (``keypose``, ``style``), or ``None`` for ``canny`` and ``color``.

    Raises:
        NotImplementedError: for ``seg`` and for any unrecognised ``cond_type``.
    """
    if cond_type in (ExtraCondition.canny, ExtraCondition.color):
        return None

    if cond_type == ExtraCondition.sketch:
        from ldm.modules.extra_condition.model_edge import pidinet
        model = pidinet()
        state = torch.load('models/table5_pidinet.pth', map_location='cpu')['state_dict']
        # Strip the 'module.' fragment from checkpoint keys before loading.
        cleaned = {name.replace('module.', ''): tensor for name, tensor in state.items()}
        model.load_state_dict(cleaned, strict=True)
        model.to(opt.device)
        return model

    if cond_type == ExtraCondition.keypose:
        import mmcv
        from mmdet.apis import init_detector
        from mmpose.apis import init_pose_model
        det_config = 'configs/mm/faster_rcnn_r50_fpn_coco.py'
        det_checkpoint = 'models/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth'
        pose_config = 'configs/mm/hrnet_w48_coco_256x192.py'
        pose_checkpoint = 'models/hrnet_w48_coco_256x192-b9e0b3ab_20200708.pth'
        det_model = init_detector(mmcv.Config.fromfile(det_config), det_checkpoint, device=opt.device)
        pose_model = init_pose_model(mmcv.Config.fromfile(pose_config), pose_checkpoint, device=opt.device)
        return {'pose_model': pose_model, 'det_model': det_model}

    if cond_type == ExtraCondition.depth:
        from ldm.modules.extra_condition.midas.api import MiDaSInference
        return MiDaSInference(model_type='dpt_hybrid').to(opt.device)

    if cond_type == ExtraCondition.style:
        from transformers import CLIPProcessor, CLIPVisionModel
        version = 'openai/clip-vit-large-patch14'
        processor = CLIPProcessor.from_pretrained(version)
        clip_vision_model = CLIPVisionModel.from_pretrained(version).to(opt.device)
        return {'processor': processor, 'clip_vision_model': clip_vision_model}

    if cond_type == ExtraCondition.openpose:
        from ldm.modules.extra_condition.openpose.api import OpenposeInference
        return OpenposeInference().to(opt.device)

    # ExtraCondition.seg and anything unknown have no extractor.
    raise NotImplementedError
|
66 |
+
|
67 |
+
|
68 |
+
def get_cond_sketch(opt, cond_image, cond_inp_type, cond_model=None):
    """Produce a binarised edge tensor from a sketch or from a natural image.

    ``cond_image`` is either a path (read with ``cv2.imread``) or an RGB numpy
    array such as gradio provides (converted RGB -> BGR). The image is resized
    with ``resize_numpy_image`` and ``opt.H`` / ``opt.W`` are overwritten with
    the resized height and width.

    ``cond_inp_type`` selects the source: ``'sketch'`` uses the first channel
    of the image directly, ``'image'`` takes the last output of ``cond_model``.
    The result is thresholded at 0.5 and returned as a float tensor.

    Raises:
        NotImplementedError: for any other ``cond_inp_type``.
    """
    loaded_from_path = isinstance(cond_image, str)
    bgr = cv2.imread(cond_image) if loaded_from_path else cv2.cvtColor(cond_image, cv2.COLOR_RGB2BGR)
    bgr = resize_numpy_image(bgr, max_resolution=opt.max_resolution, resize_short_edge=opt.resize_short_edge)
    opt.H, opt.W = bgr.shape[:2]

    if cond_inp_type == 'sketch':
        edge = img2tensor(bgr)[0].unsqueeze(0).unsqueeze(0) / 255.
        edge = edge.to(opt.device)
    elif cond_inp_type == 'image':
        net_input = img2tensor(bgr).unsqueeze(0) / 255.
        edge = cond_model(net_input.to(opt.device))[-1]
    else:
        raise NotImplementedError

    # NOTE: for sketches on a white background the map would need inverting (1 - edge).
    return (edge > 0.5).float()
|
90 |
+
|
91 |
+
|
92 |
+
def get_cond_seg(opt, cond_image, cond_inp_type='image', cond_model=None):
    """Turn a ready-made segmentation image into a tensor scaled by 1/255 on ``opt.device``.

    ``cond_image`` is a path or an RGB numpy array (converted RGB -> BGR). The
    image is resized and ``opt.H`` / ``opt.W`` are overwritten with its size.
    Only ``cond_inp_type == 'seg'`` is supported; ``cond_model`` is unused.

    Raises:
        NotImplementedError: for any ``cond_inp_type`` other than ``'seg'``
            (including the default ``'image'``).
    """
    loaded_from_path = isinstance(cond_image, str)
    bgr = cv2.imread(cond_image) if loaded_from_path else cv2.cvtColor(cond_image, cv2.COLOR_RGB2BGR)
    bgr = resize_numpy_image(bgr, max_resolution=opt.max_resolution, resize_short_edge=opt.resize_short_edge)
    opt.H, opt.W = bgr.shape[:2]

    if cond_inp_type != 'seg':
        raise NotImplementedError

    seg = img2tensor(bgr).unsqueeze(0) / 255.
    return seg.to(opt.device)
|
106 |
+
|
107 |
+
|
108 |
+
def get_cond_keypose(opt, cond_image, cond_inp_type='image', cond_model=None):
    """Return a keypose condition tensor, detecting poses first when given a photo.

    ``cond_image`` is a path or an RGB numpy array (converted RGB -> BGR). The
    image is resized and ``opt.H`` / ``opt.W`` are overwritten with its size.

    * ``'keypose'``: the image already is a pose rendering and is only
      converted to a tensor.
    * ``'image'``: ``cond_model['det_model']`` finds person boxes,
      ``cond_model['pose_model']`` estimates keypoints, and the keypoints are
      drawn with ``imshow_keypoints`` before conversion.

    Raises:
        NotImplementedError: for any other ``cond_inp_type``.
    """
    loaded_from_path = isinstance(cond_image, str)
    frame = cv2.imread(cond_image) if loaded_from_path else cv2.cvtColor(cond_image, cv2.COLOR_RGB2BGR)
    frame = resize_numpy_image(frame, max_resolution=opt.max_resolution, resize_short_edge=opt.resize_short_edge)
    opt.H, opt.W = frame.shape[:2]

    if cond_inp_type == 'keypose':
        pose = img2tensor(frame).unsqueeze(0) / 255.
        return pose.to(opt.device)

    if cond_inp_type != 'image':
        raise NotImplementedError

    from ldm.modules.extra_condition.utils import imshow_keypoints
    from mmdet.apis import inference_detector
    from mmpose.apis import (inference_top_down_pose_model, process_mmdet_results)

    # mmpose seems not compatible with autocast fp16
    with autocast("cuda", dtype=torch.float32):
        detections = inference_detector(cond_model['det_model'], frame)
        # keep the person class bounding boxes.
        people = process_mmdet_results(detections, 1)

        dataset = cond_model['pose_model'].cfg.data['test']['type']
        # Heatmaps and intermediate layer outputs are not requested.
        pose_results, _ = inference_top_down_pose_model(
            cond_model['pose_model'],
            frame,
            people,
            bbox_thr=0.2,
            format='xyxy',
            dataset=dataset,
            dataset_info=None,
            return_heatmap=False,
            outputs=None)

    # Draw the detected keypoints and convert the drawing to a tensor.
    drawn = imshow_keypoints(frame, pose_results, radius=2, thickness=2)
    pose = img2tensor(drawn).unsqueeze(0) / 255.
    return pose.to(opt.device)
|
154 |
+
|
155 |
+
|
156 |
+
def get_cond_depth(opt, cond_image, cond_inp_type='image', cond_model=None):
    """Return a depth condition tensor, estimating depth first when given a photo.

    ``cond_image`` is a path or an RGB numpy array (converted RGB -> BGR). The
    image is resized and ``opt.H`` / ``opt.W`` are overwritten with its size.

    * ``'depth'``: the image is used directly, scaled by 1/255.
    * ``'image'``: the image is scaled to [-1, 1] and passed to ``cond_model``;
      the prediction is repeated 3x along dim 1 and min-max normalised in place.

    Raises:
        NotImplementedError: for any other ``cond_inp_type``.
    """
    loaded_from_path = isinstance(cond_image, str)
    bgr = cv2.imread(cond_image) if loaded_from_path else cv2.cvtColor(cond_image, cv2.COLOR_RGB2BGR)
    bgr = resize_numpy_image(bgr, max_resolution=opt.max_resolution, resize_short_edge=opt.resize_short_edge)
    opt.H, opt.W = bgr.shape[:2]

    if cond_inp_type == 'depth':
        depth = img2tensor(bgr).unsqueeze(0) / 255.
        return depth.to(opt.device)

    if cond_inp_type != 'image':
        raise NotImplementedError

    net_input = img2tensor(bgr).unsqueeze(0) / 127.5 - 1.0
    depth = cond_model(net_input.to(opt.device)).repeat(1, 3, 1, 1)
    # Shift to start at zero, then divide by the new maximum.
    depth -= torch.min(depth)
    depth /= torch.max(depth)
    return depth
|
175 |
+
|
176 |
+
|
177 |
+
def get_cond_canny(opt, cond_image, cond_inp_type='image', cond_model=None):
    """Return a Canny edge condition tensor scaled by 1/255 on ``opt.device``.

    ``cond_image`` is a path or an RGB numpy array (converted RGB -> BGR). The
    image is resized and ``opt.H`` / ``opt.W`` are overwritten with its size.

    * ``'canny'``: the image already is an edge map; its first channel is kept.
    * ``'image'``: edges are computed with ``cv2.Canny`` (thresholds 100, 200).

    ``cond_model`` is unused.

    Raises:
        NotImplementedError: for any other ``cond_inp_type``.
    """
    loaded_from_path = isinstance(cond_image, str)
    bgr = cv2.imread(cond_image) if loaded_from_path else cv2.cvtColor(cond_image, cv2.COLOR_RGB2BGR)
    bgr = resize_numpy_image(bgr, max_resolution=opt.max_resolution, resize_short_edge=opt.resize_short_edge)
    opt.H, opt.W = bgr.shape[:2]

    if cond_inp_type == 'canny':
        canny = img2tensor(bgr)[0:1].unsqueeze(0) / 255.
    elif cond_inp_type == 'image':
        # Add a trailing channel axis to the 2-D edge map before conversion.
        edges = cv2.Canny(bgr, 100, 200)[..., None]
        canny = img2tensor(edges).unsqueeze(0) / 255.
    else:
        raise NotImplementedError

    return canny.to(opt.device)
|
195 |
+
|
196 |
+
|
197 |
+
def get_cond_style(opt, cond_image, cond_inp_type='image', cond_model=None):
    """Extract style features from an image with the CLIP vision model.

    ``cond_image`` is a path (opened with PIL) or a numpy array (wrapped with
    ``Image.fromarray``). ``cond_model`` must provide ``'processor'`` and
    ``'clip_vision_model'``. Returns the vision model's ``last_hidden_state``.

    Only ``cond_inp_type == 'image'`` is accepted (enforced by ``assert``).
    """
    assert cond_inp_type == 'image'
    # numpy image to PIL image when the input is not a path
    style = Image.open(cond_image) if isinstance(cond_image, str) else Image.fromarray(cond_image)

    processed = cond_model['processor'](images=style, return_tensors="pt")
    pixel_values = processed['pixel_values'].to(opt.device)
    return cond_model['clip_vision_model'](pixel_values)['last_hidden_state']
|
209 |
+
|
210 |
+
|
211 |
+
def get_cond_color(opt, cond_image, cond_inp_type='image', cond_model=None):
    """Return a coarse colour-palette condition tensor scaled by 1/255 on ``opt.device``.

    ``cond_image`` is a path or an RGB numpy array (converted RGB -> BGR). The
    image is resized and ``opt.H`` / ``opt.W`` are overwritten with its size.
    For ``cond_inp_type == 'image'`` the picture is shrunk by a factor of 64
    (bicubic) and blown back up with nearest-neighbour interpolation, leaving
    flat colour blocks. ``cond_model`` is unused.
    """
    loaded_from_path = isinstance(cond_image, str)
    bgr = cv2.imread(cond_image) if loaded_from_path else cv2.cvtColor(cond_image, cv2.COLOR_RGB2BGR)
    bgr = resize_numpy_image(bgr, max_resolution=opt.max_resolution, resize_short_edge=opt.resize_short_edge)
    opt.H, opt.W = bgr.shape[:2]

    if cond_inp_type == 'image':
        coarse_size = (opt.W//64, opt.H//64)
        full_size = (opt.W, opt.H)
        bgr = cv2.resize(bgr, coarse_size, interpolation=cv2.INTER_CUBIC)
        bgr = cv2.resize(bgr, full_size, interpolation=cv2.INTER_NEAREST)

    # NOTE(review): indentation was lost in this listing; the tensor conversion
    # is assumed to apply to every input type, not only 'image' - confirm.
    color = img2tensor(bgr).unsqueeze(0) / 255.
    return color.to(opt.device)
|
224 |
+
|
225 |
+
|
226 |
+
def get_cond_openpose(opt, cond_image, cond_inp_type='image', cond_model=None):
    """Return an openpose condition tensor scaled by 1/255 on ``opt.device``.

    ``cond_image`` is a path or an RGB numpy array (converted RGB -> BGR). The
    image is resized and ``opt.H`` / ``opt.W`` are overwritten with its size.

    * ``'openpose'``: the image already is a pose rendering.
    * ``'image'``: ``cond_model`` is run on the image (under float32 autocast)
      and its output is converted to a tensor.

    Raises:
        NotImplementedError: for any other ``cond_inp_type``.
    """
    loaded_from_path = isinstance(cond_image, str)
    bgr = cv2.imread(cond_image) if loaded_from_path else cv2.cvtColor(cond_image, cv2.COLOR_RGB2BGR)
    bgr = resize_numpy_image(
        bgr, max_resolution=opt.max_resolution, resize_short_edge=opt.resize_short_edge)
    opt.H, opt.W = bgr.shape[:2]

    if cond_inp_type == 'openpose':
        rendering = bgr
    elif cond_inp_type == 'image':
        with autocast('cuda', dtype=torch.float32):
            rendering = cond_model(bgr)
    else:
        raise NotImplementedError

    openpose_keypose = img2tensor(rendering).unsqueeze(0) / 255.
    return openpose_keypose.to(opt.device)
|
247 |
+
|
248 |
+
|
249 |
+
def get_adapter_feature(inputs, adapters):
    """Run each adapter on its input and combine the resulting features.

    ``inputs`` and ``adapters`` are parallel lists; a single (non-list) input
    is wrapped together with its adapter. Each adapter is a mapping with a
    callable under ``'model'`` and a scalar under ``'cond_weight'``.

    * An adapter output that is a ``list`` is treated as a feature map: its
      entries are scaled by ``cond_weight`` and summed element-wise across
      adapters.
    * Any other output is treated as a feature sequence and concatenated
      along dim 1 across adapters (``cond_weight`` is not applied here).

    Returns:
        ``(ret_feat_map, ret_feat_seq)``; either is ``None`` when no adapter
        produced that kind of feature.
    """
    if not isinstance(inputs, list):
        inputs, adapters = [inputs], [adapters]

    feat_map = None
    feat_seq = None
    for cond, adapter in zip(inputs, adapters):
        feature = adapter['model'](cond)

        if not isinstance(feature, list):
            feat_seq = feature if feat_seq is None else torch.cat([feat_seq, feature], dim=1)
            continue

        if feat_map is None:
            feat_map = [level * adapter['cond_weight'] for level in feature]
        else:
            feat_map = [acc + level * adapter['cond_weight'] for acc, level in zip(feat_map, feature)]

    return feat_map, feat_seq
|
ldm/modules/{structure_condition/midas β extra_condition}/midas/__init__.py
RENAMED
File without changes
|
ldm/modules/{structure_condition β extra_condition}/midas/api.py
RENAMED
@@ -6,10 +6,10 @@ import torch
|
|
6 |
import torch.nn as nn
|
7 |
from torchvision.transforms import Compose
|
8 |
|
9 |
-
from ldm.modules.
|
10 |
-
from ldm.modules.
|
11 |
-
from ldm.modules.
|
12 |
-
from ldm.modules.
|
13 |
|
14 |
|
15 |
ISL_PATHS = {
|
|
|
6 |
import torch.nn as nn
|
7 |
from torchvision.transforms import Compose
|
8 |
|
9 |
+
from ldm.modules.extra_condition.midas.midas.dpt_depth import DPTDepthModel
|
10 |
+
from ldm.modules.extra_condition.midas.midas.midas_net import MidasNet
|
11 |
+
from ldm.modules.extra_condition.midas.midas.midas_net_custom import MidasNet_small
|
12 |
+
from ldm.modules.extra_condition.midas.midas.transforms import Resize, NormalizeImage, PrepareForNet
|
13 |
|
14 |
|
15 |
ISL_PATHS = {
|
ldm/modules/{structure_condition/openpose β extra_condition/midas/midas}/__init__.py
RENAMED
File without changes
|
ldm/modules/{structure_condition β extra_condition}/midas/midas/base_model.py
RENAMED
File without changes
|
ldm/modules/{structure_condition β extra_condition}/midas/midas/blocks.py
RENAMED
File without changes
|
ldm/modules/{structure_condition β extra_condition}/midas/midas/dpt_depth.py
RENAMED
File without changes
|
ldm/modules/{structure_condition β extra_condition}/midas/midas/midas_net.py
RENAMED
File without changes
|
ldm/modules/{structure_condition β extra_condition}/midas/midas/midas_net_custom.py
RENAMED
File without changes
|
ldm/modules/{structure_condition β extra_condition}/midas/midas/transforms.py
RENAMED
File without changes
|