new functions
- app.py +20 -14
- demo/demos.py +110 -5
- demo/model.py +277 -2
- ldm/models/diffusion/plms.py +14 -2
- ldm/modules/attention.py +126 -47
- ldm/modules/diffusionmodules/model.py +96 -79
- ldm/modules/encoders/adapter.py +128 -0
- ldm/modules/structure_condition/openpose/__init__.py +0 -0
- ldm/modules/structure_condition/openpose/api.py +36 -0
- ldm/modules/structure_condition/openpose/body.py +224 -0
- ldm/modules/structure_condition/openpose/hand.py +86 -0
- ldm/modules/structure_condition/openpose/model.py +219 -0
- ldm/modules/structure_condition/openpose/util.py +203 -0
app.py
CHANGED
@@ -8,14 +8,14 @@ os.system('mim install mmcv-full==1.7.0')
 
 from demo.model import Model_all
 import gradio as gr
-from demo.demos import create_demo_keypose, create_demo_sketch, create_demo_draw, create_demo_seg, create_demo_depth, create_demo_depth_keypose
+from demo.demos import create_demo_keypose, create_demo_sketch, create_demo_draw, create_demo_seg, create_demo_depth, create_demo_depth_keypose, create_demo_color, create_demo_color_sketch, create_demo_openpose, create_demo_style_sketch
 import torch
 import subprocess
 import shlex
 from huggingface_hub import hf_hub_url
 
 urls = {
-    'TencentARC/T2I-Adapter':['models/t2iadapter_keypose_sd14v1.pth', 'models/t2iadapter_seg_sd14v1.pth', 'models/t2iadapter_sketch_sd14v1.pth', 'models/t2iadapter_depth_sd14v1.pth'],
+    'TencentARC/T2I-Adapter':['models/t2iadapter_keypose_sd14v1.pth', 'models/t2iadapter_color_sd14v1.pth', 'models/t2iadapter_openpose_sd14v1.pth', 'models/t2iadapter_seg_sd14v1.pth', 'models/t2iadapter_sketch_sd14v1.pth', 'models/t2iadapter_depth_sd14v1.pth', 'third-party-models/body_pose_model.pth', "models/t2iadapter_style_sd14v1.pth"],
     'CompVis/stable-diffusion-v-1-4-original':['sd-v1-4.ckpt'],
     'andite/anything-v4.0':['anything-v4.0-pruned.ckpt', 'anything-v4.0.vae.pt'],
 }
@@ -44,37 +44,43 @@ for url in urls_mmpose:
 device = 'cuda' if torch.cuda.is_available() else 'cpu'
 model = Model_all(device)
 
-DESCRIPTION = '''# T2I-Adapter
-[Paper](https://arxiv.org/abs/2302.08453) [GitHub](https://github.com/TencentARC/T2I-Adapter)
-
-
-
-
+DESCRIPTION = '''# T2I-Adapter
+
+Gradio demo for **T2I-Adapter**: [[GitHub]](https://github.com/TencentARC/T2I-Adapter), [[Paper]](https://arxiv.org/abs/2302.08453).
+
+It also supports **multiple adapters** in the follwing tabs showing **"A adapter + B adapter"**.
+
+If T2I-Adapter is helpful, please help to ⭐ the [Github Repo](https://github.com/TencentARC/T2I-Adapter) and recommend it to your friends 😊
 '''
 
 with gr.Blocks(css='style.css') as demo:
     gr.Markdown(DESCRIPTION)
 
-    gr.HTML("""
-    <p>For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings.
+    gr.HTML("""<p>For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings.
     <br/>
     <a href="https://huggingface.co/spaces/Adapter/T2I-Adapter?duplicate=true">
     <img style="margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>
    <p/>""")
 
    with gr.Tabs():
+        with gr.TabItem('Openpose'):
+            create_demo_openpose(model.process_openpose)
        with gr.TabItem('Keypose'):
            create_demo_keypose(model.process_keypose)
        with gr.TabItem('Sketch'):
            create_demo_sketch(model.process_sketch)
        with gr.TabItem('Draw'):
            create_demo_draw(model.process_draw)
-        with gr.TabItem('Segmentation'):
-            create_demo_seg(model.process_seg)
        with gr.TabItem('Depth'):
            create_demo_depth(model.process_depth)
-        with gr.TabItem('
+        with gr.TabItem('Depth + Keypose'):
            create_demo_depth_keypose(model.process_depth_keypose)
-
+        with gr.TabItem('Color'):
+            create_demo_color(model.process_color)
+        with gr.TabItem('Color + Sketch'):
+            create_demo_color_sketch(model.process_color_sketch)
+        with gr.TabItem('Style + Sketch'):
+            create_demo_style_sketch(model.process_style_sketch)
+        with gr.TabItem('Segmentation'):
+            create_demo_seg(model.process_seg)
    demo.queue().launch(debug=True, server_name='0.0.0.0')
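
The pattern above is that each `create_demo_*` factory receives one of `Model_all`'s `process_*` callables and builds its own `gr.Blocks` tab around it. A minimal, standalone sketch of that wiring (the `create_demo_dummy` and `dummy_process` names are illustrative, not part of the repo) looks like this:

```python
# Minimal sketch of the tab/factory wiring used in app.py, with a dummy
# process callable standing in for Model_all.process_* (hypothetical names).
import gradio as gr
import numpy as np

def create_demo_dummy(process):
    # Each factory builds its own Blocks UI and binds the Run button to `process`.
    with gr.Blocks() as demo:
        input_img = gr.Image(type="numpy")
        prompt = gr.Textbox(label="Prompt")
        run_button = gr.Button("Run")
        result = gr.Gallery(label="Output")
        run_button.click(fn=process, inputs=[input_img, prompt], outputs=[result])
    return demo

def dummy_process(img, prompt):
    # Stand-in for a real adapter pipeline: just echo the input image.
    return [img if img is not None else np.zeros((64, 64, 3), dtype=np.uint8)]

if __name__ == '__main__':
    with gr.Blocks() as app:
        with gr.Tabs():
            with gr.TabItem('Dummy'):
                create_demo_dummy(dummy_process)
    app.queue().launch()
```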
demo/demos.py
CHANGED
@@ -18,11 +18,6 @@ def create_demo_keypose(process):
     with gr.Blocks() as demo:
         with gr.Row():
             gr.Markdown('## T2I-Adapter (Keypose)')
-        # with gr.Row():
-        #     with gr.Column():
-        #         gr.Textbox(value="Hello Memory")
-        #     with gr.Column():
-        #         gr.JSON(get_system_memory, every=1)
         with gr.Row():
             with gr.Column():
                 input_img = gr.Image(source='upload', type="numpy")
@@ -44,6 +39,31 @@ def create_demo_keypose(process):
         run_button.click(fn=process, inputs=ips, outputs=[result])
     return demo
 
+def create_demo_openpose(process):
+    with gr.Blocks() as demo:
+        with gr.Row():
+            gr.Markdown('## T2I-Adapter (Openpose)')
+        with gr.Row():
+            with gr.Column():
+                input_img = gr.Image(source='upload', type="numpy")
+                prompt = gr.Textbox(label="Prompt")
+                neg_prompt = gr.Textbox(label="Negative Prompt",
+                                        value='ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, bad anatomy, watermark, signature, cut off, low contrast, underexposed, overexposed, bad art, beginner, amateur, distorted face')
+                pos_prompt = gr.Textbox(label="Positive Prompt",
+                                        value = 'crafted, elegant, meticulous, magnificent, maximum details, extremely hyper aesthetic, intricately detailed')
+                with gr.Row():
+                    type_in = gr.inputs.Radio(['Openpose', 'Image'], type="value", default='Image', label='Input Types\n (You can input an image or a openpose map)')
+                    fix_sample = gr.inputs.Radio(['True', 'False'], type="value", default='False', label='Fix Sampling\n (Fix the random seed to produce a fixed output)')
+                run_button = gr.Button(label="Run")
+                con_strength = gr.Slider(label="Controling Strength (The guidance strength of the openpose to the result)", minimum=0, maximum=1, value=1, step=0.1)
+                scale = gr.Slider(label="Guidance Scale (Classifier free guidance)", minimum=0.1, maximum=30.0, value=7.5, step=0.1)
+                base_model = gr.inputs.Radio(['sd-v1-4.ckpt', 'anything-v4.0-pruned.ckpt'], type="value", default='sd-v1-4.ckpt', label='The base model you want to use')
+            with gr.Column():
+                result = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
+        ips = [input_img, type_in, prompt, neg_prompt, pos_prompt, fix_sample, scale, con_strength, base_model]
+        run_button.click(fn=process, inputs=ips, outputs=[result])
+    return demo
+
 def create_demo_sketch(process):
     with gr.Blocks() as demo:
         with gr.Row():
@@ -70,6 +90,91 @@ def create_demo_sketch(process):
         run_button.click(fn=process, inputs=ips, outputs=[result])
     return demo
 
+def create_demo_color_sketch(process):
+    with gr.Blocks() as demo:
+        with gr.Row():
+            gr.Markdown('## T2I-Adapter (Color + Sketch)')
+        with gr.Row():
+            with gr.Column():
+                with gr.Row():
+                    input_img_sketch = gr.Image(source='upload', type="numpy", label='Sketch guidance')
+                    input_img_color = gr.Image(source='upload', type="numpy", label='Color guidance')
+                prompt = gr.Textbox(label="Prompt")
+                neg_prompt = gr.Textbox(label="Negative Prompt",
+                                        value='ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, bad anatomy, watermark, signature, cut off, low contrast, underexposed, overexposed, bad art, beginner, amateur, distorted face')
+                pos_prompt = gr.Textbox(label="Positive Prompt",
+                                        value = 'crafted, elegant, meticulous, magnificent, maximum details, extremely hyper aesthetic, intricately detailed')
+                type_in_color = gr.inputs.Radio(['ColorMap', 'Image'], type="value", default='Image', label='Input Types of Color\n (You can input an image or a color map)')
+                with gr.Row():
+                    type_in = gr.inputs.Radio(['Sketch', 'Image'], type="value", default='Image', label='Input Types of Sketch\n (You can input an image or a sketch)')
+                    color_back = gr.inputs.Radio(['White', 'Black'], type="value", default='Black', label='Color of the sketch background\n (Only work for sketch input)')
+                with gr.Row():
+                    w_sketch = gr.Slider(label="Depth guidance weight", minimum=0, maximum=2, value=1.0, step=0.1)
+                    w_color = gr.Slider(label="Color guidance weight", minimum=0, maximum=2, value=1.2, step=0.1)
+                run_button = gr.Button(label="Run")
+                con_strength = gr.Slider(label="Controling Strength (The guidance strength of the sketch to the result)", minimum=0, maximum=1, value=0.4, step=0.1)
+                scale = gr.Slider(label="Guidance Scale (Classifier free guidance)", minimum=0.1, maximum=30.0, value=7.5, step=0.1)
+                fix_sample = gr.inputs.Radio(['True', 'False'], type="value", default='False', label='Fix Sampling\n (Fix the random seed)')
+                base_model = gr.inputs.Radio(['sd-v1-4.ckpt', 'anything-v4.0-pruned.ckpt'], type="value", default='sd-v1-4.ckpt', label='The base model you want to use')
+            with gr.Column():
+                result = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=3, height='auto')
+        ips = [input_img_sketch, input_img_color, type_in, type_in_color, w_sketch, w_color, color_back, prompt, neg_prompt, pos_prompt, fix_sample, scale, con_strength, base_model]
+        run_button.click(fn=process, inputs=ips, outputs=[result])
+    return demo
+
+def create_demo_style_sketch(process):
+    with gr.Blocks() as demo:
+        with gr.Row():
+            gr.Markdown('## T2I-Adapter (Style + Sketch)')
+        with gr.Row():
+            with gr.Column():
+                with gr.Row():
+                    input_img_sketch = gr.Image(source='upload', type="numpy", label='Sketch guidance')
+                    input_img_style = gr.Image(source='upload', type="numpy", label='Style guidance')
+                prompt = gr.Textbox(label="Prompt")
+                neg_prompt = gr.Textbox(label="Negative Prompt",
+                                        value='ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, bad anatomy, watermark, signature, cut off, low contrast, underexposed, overexposed, bad art, beginner, amateur, distorted face')
+                pos_prompt = gr.Textbox(label="Positive Prompt",
+                                        value = 'crafted, elegant, meticulous, magnificent, maximum details, extremely hyper aesthetic, intricately detailed')
+                with gr.Row():
+                    type_in = gr.inputs.Radio(['Sketch', 'Image'], type="value", default='Image', label='Input Types of Sketch\n (You can input an image or a sketch)')
+                    color_back = gr.inputs.Radio(['White', 'Black'], type="value", default='Black', label='Color of the sketch background\n (Only work for sketch input)')
+                run_button = gr.Button(label="Run")
+                con_strength = gr.Slider(label="Controling Strength (The guidance strength of the sketch to the result)", minimum=0, maximum=1, value=1, step=0.1)
+                scale = gr.Slider(label="Guidance Scale (Classifier free guidance)", minimum=0.1, maximum=30.0, value=7.5, step=0.1)
+                fix_sample = gr.inputs.Radio(['True', 'False'], type="value", default='False', label='Fix Sampling\n (Fix the random seed)')
+                base_model = gr.inputs.Radio(['sd-v1-4.ckpt', 'anything-v4.0-pruned.ckpt'], type="value", default='sd-v1-4.ckpt', label='The base model you want to use')
+            with gr.Column():
+                result = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
+        ips = [input_img_sketch, input_img_style, type_in, color_back, prompt, neg_prompt, pos_prompt, fix_sample, scale, con_strength, base_model]
+        run_button.click(fn=process, inputs=ips, outputs=[result])
+    return demo
+
+def create_demo_color(process):
+    with gr.Blocks() as demo:
+        with gr.Row():
+            gr.Markdown('## T2I-Adapter (Color)')
+        with gr.Row():
+            with gr.Column():
+                input_img = gr.Image(source='upload', type="numpy", label='Color guidance')
+                prompt = gr.Textbox(label="Prompt")
+                neg_prompt = gr.Textbox(label="Negative Prompt",
+                                        value='ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, bad anatomy, watermark, signature, cut off, low contrast, underexposed, overexposed, bad art, beginner, amateur, distorted face')
+                pos_prompt = gr.Textbox(label="Positive Prompt",
+                                        value = 'crafted, elegant, meticulous, magnificent, maximum details, extremely hyper aesthetic, intricately detailed')
+                type_in_color = gr.inputs.Radio(['ColorMap', 'Image'], type="value", default='Image', label='Input Types of Color\n (You can input an image or a color map)')
+                w_color = gr.Slider(label="Color guidance weight", minimum=0, maximum=2, value=1, step=0.1)
+                run_button = gr.Button(label="Run")
+                con_strength = gr.Slider(label="Controling Strength (The guidance strength of the sketch to the result)", minimum=0, maximum=1, value=1, step=0.1)
+                scale = gr.Slider(label="Guidance Scale (Classifier free guidance)", minimum=0.1, maximum=30.0, value=7.5, step=0.1)
+                fix_sample = gr.inputs.Radio(['True', 'False'], type="value", default='False', label='Fix Sampling\n (Fix the random seed)')
+                base_model = gr.inputs.Radio(['sd-v1-4.ckpt', 'anything-v4.0-pruned.ckpt'], type="value", default='sd-v1-4.ckpt', label='The base model you want to use')
+            with gr.Column():
+                result = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
+        ips = [input_img, prompt, neg_prompt, pos_prompt, w_color, type_in_color, fix_sample, scale, con_strength, base_model]
+        run_button.click(fn=process, inputs=ips, outputs=[result])
+    return demo
+
 def create_demo_seg(process):
     with gr.Blocks() as demo:
         with gr.Row():
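
A detail worth keeping in mind when reading these factories: Gradio passes the `ips` components to the bound callable positionally, so the order of `ips` must match the parameter order of the corresponding `process_*` method (for example, `create_demo_color_sketch`'s 14-element `ips` list mirrors `process_color_sketch`'s signature one-to-one). A tiny illustrative check, not part of the repo, could look like this:

```python
# Illustrative helper (hypothetical, not in the repo): verify that a UI's
# component order matches the positional signature of the bound process callable.
import inspect

def check_ips_order(process, ips_names):
    params = [p for p in inspect.signature(process).parameters if p != 'self']
    return params == ips_names

def demo_process(input_img, prompt, scale):
    return [input_img]

print(check_ips_order(demo_process, ['input_img', 'prompt', 'scale']))  # True
```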
demo/model.py
CHANGED
@@ -2,7 +2,7 @@ import torch
 from basicsr.utils import img2tensor, tensor2img
 from pytorch_lightning import seed_everything
 from ldm.models.diffusion.plms import PLMSSampler
-from ldm.modules.encoders.adapter import Adapter
+from ldm.modules.encoders.adapter import Adapter, Adapter_light, StyleAdapter
 from ldm.util import instantiate_from_config
 from ldm.modules.structure_condition.model_edge import pidinet
 from ldm.modules.structure_condition.model_seg import seger, Colorize
@@ -16,6 +16,8 @@ import os
 import cv2
 import numpy as np
 import torch.nn.functional as F
+from transformers import CLIPProcessor, CLIPVisionModel
+from PIL import Image
 
 
 def preprocessing(image, device):
@@ -151,9 +153,9 @@ class Model_all:
         self.model_seg = Adapter(cin=int(3 * 64), channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True,
                                  use_conv=False).to(device)
         self.model_seg.load_state_dict(torch.load("models/t2iadapter_seg_sd14v1.pth", map_location=device))
-        self.depth_model = MiDaSInference(model_type='dpt_hybrid').to(device)
 
         # depth part
+        self.depth_model = MiDaSInference(model_type='dpt_hybrid').to(device)
         self.model_depth = Adapter(cin=3 * 64, channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True,
                                    use_conv=False).to(device)
         self.model_depth.load_state_dict(torch.load("models/t2iadapter_depth_sd14v1.pth", map_location=device))
@@ -162,6 +164,23 @@ class Model_all:
         self.model_pose = Adapter(cin=int(3 * 64), channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True,
                                   use_conv=False).to(device)
         self.model_pose.load_state_dict(torch.load("models/t2iadapter_keypose_sd14v1.pth", map_location=device))
+
+        # openpose part
+        self.model_openpose = Adapter(cin=int(3 * 64), channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True,
+                                      use_conv=False).to(device)
+        self.model_openpose.load_state_dict(torch.load("models/t2iadapter_openpose_sd14v1.pth", map_location=device))
+
+        # color part
+        self.model_color = Adapter_light(cin=int(3 * 64), channels=[320, 640, 1280, 1280], nums_rb=4).to(device)
+        self.model_color.load_state_dict(torch.load("models/t2iadapter_color_sd14v1_small.pth", map_location=device))
+
+        # style part
+        self.model_style = StyleAdapter(width=1024, context_dim=768, num_head=8, n_layes=3, num_token=8).to(device)
+        self.model_style.load_state_dict(torch.load("models/t2iadapter_style_sd14v1.pth", map_location=device))
+        self.clip_processor = CLIPProcessor.from_pretrained('openai/clip-vit-large-patch14')
+        self.clip_vision_model = CLIPVisionModel.from_pretrained('openai/clip-vit-large-patch14').to(device)
+
+        device = 'cpu'
         ## mmpose
         det_config = 'models/faster_rcnn_r50_fpn_coco.py'
         det_checkpoint = 'models/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth'
@@ -257,7 +276,202 @@ class Model_all:
         x_samples_ddim = x_samples_ddim.astype(np.uint8)
 
         return [im_edge, x_samples_ddim]
+
+    @torch.no_grad()
+    def process_color_sketch(self, input_img_sketch, input_img_color, type_in, type_in_color, w_sketch, w_color, color_back, prompt, neg_prompt, pos_prompt, fix_sample, scale, con_strength, base_model):
+        if self.current_base != base_model:
+            ckpt = os.path.join("models", base_model)
+            pl_sd = torch.load(ckpt, map_location="cuda")
+            if "state_dict" in pl_sd:
+                sd = pl_sd["state_dict"]
+            else:
+                sd = pl_sd
+            self.base_model.load_state_dict(sd, strict=False)
+            self.current_base = base_model
+            if 'anything' in base_model.lower():
+                self.load_vae()
+
+        con_strength = int((1 - con_strength) * 50)
+        if fix_sample == 'True':
+            seed_everything(42)
+        im = cv2.resize(input_img_sketch, (512, 512))
+
+        if type_in == 'Sketch':
+            if color_back == 'White':
+                im = 255 - im
+            im_edge = im.copy()
+            im = img2tensor(im)[0].unsqueeze(0).unsqueeze(0) / 255.
+            im = im > 0.5
+            im = im.float()
+        elif type_in == 'Image':
+            im = img2tensor(im).unsqueeze(0) / 255.
+            im = self.model_edge(im.to(self.device))[-1]#.cuda()
+            im = im > 0.5
+            im = im.float()
+            im_edge = tensor2img(im)
+        if type_in_color == 'Image':
+            input_img_color = cv2.resize(input_img_color,(512//64, 512//64), interpolation=cv2.INTER_CUBIC)
+            input_img_color = cv2.resize(input_img_color,(512,512), interpolation=cv2.INTER_NEAREST)
+        else:
+            input_img_color = cv2.resize(input_img_color, (512, 512))
+        im_color = input_img_color.copy()
+        im_color_tensor = img2tensor(input_img_color, bgr2rgb=False).unsqueeze(0) / 255.
+
+        # extract condition features
+        c = self.base_model.get_learned_conditioning([prompt + ', ' + pos_prompt])
+        nc = self.base_model.get_learned_conditioning([neg_prompt])
+        features_adapter_sketch = self.model_sketch(im.to(self.device))
+        features_adapter_color = self.model_color(im_color_tensor.to(self.device))
+        features_adapter = [fs*w_sketch+fc*w_color for fs, fc in zip(features_adapter_sketch,features_adapter_color)]
+        shape = [4, 64, 64]
+
+        # sampling
+        samples_ddim, _ = self.sampler.sample(S=50,
+                                              conditioning=c,
+                                              batch_size=1,
+                                              shape=shape,
+                                              verbose=False,
+                                              unconditional_guidance_scale=scale,
+                                              unconditional_conditioning=nc,
+                                              eta=0.0,
+                                              x_T=None,
+                                              features_adapter1=features_adapter,
+                                              mode='sketch',
+                                              con_strength=con_strength)
+
+        x_samples_ddim = self.base_model.decode_first_stage(samples_ddim)
+        x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
+        x_samples_ddim = x_samples_ddim.to('cpu')
+        x_samples_ddim = x_samples_ddim.permute(0, 2, 3, 1).numpy()[0]
+        x_samples_ddim = 255. * x_samples_ddim
+        x_samples_ddim = x_samples_ddim.astype(np.uint8)
 
+        return [im_edge, im_color, x_samples_ddim]
+
+    @torch.no_grad()
+    def process_style_sketch(self, input_img_sketch, input_img_style, type_in, color_back, prompt, neg_prompt, pos_prompt, fix_sample, scale, con_strength, base_model):
+        if self.current_base != base_model:
+            ckpt = os.path.join("models", base_model)
+            pl_sd = torch.load(ckpt, map_location="cuda")
+            if "state_dict" in pl_sd:
+                sd = pl_sd["state_dict"]
+            else:
+                sd = pl_sd
+            self.base_model.load_state_dict(sd, strict=False)
+            self.current_base = base_model
+            if 'anything' in base_model.lower():
+                self.load_vae()
+
+        con_strength = int((1 - con_strength) * 50)
+        if fix_sample == 'True':
+            seed_everything(42)
+        im = cv2.resize(input_img_sketch, (512, 512))
+
+        if type_in == 'Sketch':
+            if color_back == 'White':
+                im = 255 - im
+            im_edge = im.copy()
+            im = img2tensor(im)[0].unsqueeze(0).unsqueeze(0) / 255.
+            im = im > 0.5
+            im = im.float()
+        elif type_in == 'Image':
+            im = img2tensor(im).unsqueeze(0) / 255.
+            im = self.model_edge(im.to(self.device))[-1]#.cuda()
+            im = im > 0.5
+            im = im.float()
+            im_edge = tensor2img(im)
+
+        style = Image.fromarray(input_img_style)
+        style_for_clip = self.clip_processor(images=style, return_tensors="pt")['pixel_values']
+        style_feat = self.clip_vision_model(style_for_clip.to(self.device))['last_hidden_state']
+        style_feat = self.model_style(style_feat)
+
+        # extract condition features
+        c = self.base_model.get_learned_conditioning([prompt + ', ' + pos_prompt])
+        nc = self.base_model.get_learned_conditioning([neg_prompt])
+        features_adapter = self.model_sketch(im.to(self.device))
+        shape = [4, 64, 64]
+
+        # sampling
+        samples_ddim, _ = self.sampler.sample(S=50,
+                                              conditioning=c,
+                                              batch_size=1,
+                                              shape=shape,
+                                              verbose=False,
+                                              unconditional_guidance_scale=scale,
+                                              unconditional_conditioning=nc,
+                                              eta=0.0,
+                                              x_T=None,
+                                              features_adapter1=features_adapter,
+                                              mode='style',
+                                              con_strength=con_strength,
+                                              style_feature=style_feat)
+
+        x_samples_ddim = self.base_model.decode_first_stage(samples_ddim)
+        x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
+        x_samples_ddim = x_samples_ddim.to('cpu')
+        x_samples_ddim = x_samples_ddim.permute(0, 2, 3, 1).numpy()[0]
+        x_samples_ddim = 255. * x_samples_ddim
+        x_samples_ddim = x_samples_ddim.astype(np.uint8)
+
+        return [im_edge, x_samples_ddim]
+
+    @torch.no_grad()
+    def process_color(self, input_img, prompt, neg_prompt, pos_prompt, w_color, type_in_color, fix_sample, scale, con_strength, base_model):
+        if self.current_base != base_model:
+            ckpt = os.path.join("models", base_model)
+            pl_sd = torch.load(ckpt, map_location="cuda")
+            if "state_dict" in pl_sd:
+                sd = pl_sd["state_dict"]
+            else:
+                sd = pl_sd
+            self.base_model.load_state_dict(sd, strict=False)
+            self.current_base = base_model
+            if 'anything' in base_model.lower():
+                self.load_vae()
+
+        con_strength = int((1 - con_strength) * 50)
+        if fix_sample == 'True':
+            seed_everything(42)
+        if type_in_color == 'Image':
+            input_img = cv2.resize(input_img,(512//64, 512//64), interpolation=cv2.INTER_CUBIC)
+            input_img = cv2.resize(input_img,(512,512), interpolation=cv2.INTER_NEAREST)
+        else:
+            input_img = cv2.resize(input_img, (512, 512))
+
+        im_color = input_img.copy()
+        im = img2tensor(input_img, bgr2rgb=False).unsqueeze(0) / 255.
+
+        # extract condition features
+        c = self.base_model.get_learned_conditioning([prompt + ', ' + pos_prompt])
+        nc = self.base_model.get_learned_conditioning([neg_prompt])
+        features_adapter = self.model_color(im.to(self.device))
+        features_adapter = [fi*w_color for fi in features_adapter]
+        shape = [4, 64, 64]
+
+        # sampling
+        samples_ddim, _ = self.sampler.sample(S=50,
+                                              conditioning=c,
+                                              batch_size=1,
+                                              shape=shape,
+                                              verbose=False,
+                                              unconditional_guidance_scale=scale,
+                                              unconditional_conditioning=nc,
+                                              eta=0.0,
+                                              x_T=None,
+                                              features_adapter1=features_adapter,
+                                              mode='sketch',
+                                              con_strength=con_strength)
+
+        x_samples_ddim = self.base_model.decode_first_stage(samples_ddim)
+        x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
+        x_samples_ddim = x_samples_ddim.to('cpu')
+        x_samples_ddim = x_samples_ddim.permute(0, 2, 3, 1).numpy()[0]
+        x_samples_ddim = 255. * x_samples_ddim
+        x_samples_ddim = x_samples_ddim.astype(np.uint8)
+
+        return [im_color, x_samples_ddim]
+
     @torch.no_grad()
     def process_depth(self, input_img, type_in, prompt, neg_prompt, pos_prompt, fix_sample, scale,
                       con_strength, base_model):
@@ -638,6 +852,67 @@ class Model_all:
         x_samples_ddim = x_samples_ddim.astype(np.uint8)
 
         return [im_pose[:, :, ::-1].astype(np.uint8), x_samples_ddim]
+
+    @torch.no_grad()
+    def process_openpose(self, input_img, type_in, prompt, neg_prompt, pos_prompt, fix_sample, scale, con_strength,
+                         base_model):
+        if self.current_base != base_model:
+            ckpt = os.path.join("models", base_model)
+            pl_sd = torch.load(ckpt, map_location="cuda")
+            if "state_dict" in pl_sd:
+                sd = pl_sd["state_dict"]
+            else:
+                sd = pl_sd
+            self.base_model.load_state_dict(sd, strict=False)
+            self.current_base = base_model
+            if 'anything' in base_model.lower():
+                self.load_vae()
+
+        con_strength = int((1 - con_strength) * 50)
+        if fix_sample == 'True':
+            seed_everything(42)
+        im = cv2.resize(input_img, (512, 512))
+
+        if type_in == 'Openpose':
+            im_pose = im.copy()[:,:,::-1]
+        elif type_in == 'Image':
+            from ldm.modules.structure_condition.openpose.api import OpenposeInference
+            model = OpenposeInference()
+            keypose = model(im)
+            im_pose = keypose.copy()[:,:,::-1]
+            # keypose = img2tensor(keypose).unsqueeze(0) / 255.
+
+        # extract condition features
+        c = self.base_model.get_learned_conditioning([prompt + ', ' + pos_prompt])
+        nc = self.base_model.get_learned_conditioning([neg_prompt])
+        pose = img2tensor(im_pose, bgr2rgb=True, float32=True) / 255.
+        pose = pose.unsqueeze(0)
+        features_adapter = self.model_openpose(pose.to(self.device))
+
+        shape = [4, 64, 64]
+
+        # sampling
+        samples_ddim, _ = self.sampler.sample(S=50,
+                                              conditioning=c,
+                                              batch_size=1,
+                                              shape=shape,
+                                              verbose=False,
+                                              unconditional_guidance_scale=scale,
+                                              unconditional_conditioning=nc,
+                                              eta=0.0,
+                                              x_T=None,
+                                              features_adapter1=features_adapter,
+                                              mode='sketch',
+                                              con_strength=con_strength)
+
+        x_samples_ddim = self.base_model.decode_first_stage(samples_ddim)
+        x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
+        x_samples_ddim = x_samples_ddim.to('cpu')
+        x_samples_ddim = x_samples_ddim.permute(0, 2, 3, 1).numpy()[0]
+        x_samples_ddim = 255. * x_samples_ddim
+        x_samples_ddim = x_samples_ddim.astype(np.uint8)
+
+        return [im_pose[:, :, ::-1].astype(np.uint8), x_samples_ddim]
 
 
 if __name__ == '__main__':
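
The color preprocessing in `process_color` and `process_color_sketch` reduces an ordinary image to a coarse 8x8 grid of colors (512//64 = 8) and blows it back up with nearest-neighbour interpolation, so the adapter only receives a spatial color layout rather than fine detail. A standalone sketch of just that step (the `make_color_hint` name is illustrative):

```python
# Standalone sketch of the color-hint preprocessing used above: downsample to an
# 8x8 grid with cubic interpolation, then upsample back with nearest-neighbour so
# the result is a blocky color layout map.
import cv2
import numpy as np

def make_color_hint(img_bgr: np.ndarray, size: int = 512, cells: int = 64) -> np.ndarray:
    small = cv2.resize(img_bgr, (size // cells, size // cells), interpolation=cv2.INTER_CUBIC)
    return cv2.resize(small, (size, size), interpolation=cv2.INTER_NEAREST)

if __name__ == '__main__':
    dummy = (np.random.rand(480, 640, 3) * 255).astype(np.uint8)
    hint = make_color_hint(dummy)
    print(hint.shape)  # (512, 512, 3)
```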
ldm/models/diffusion/plms.py
CHANGED
@@ -79,6 +79,7 @@ class PLMSSampler(object):
                features_adapter2=None,
                mode = 'sketch',
                con_strength=30,
+               style_feature=None,
                # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
                **kwargs
                ):
@@ -115,7 +116,8 @@ class PLMSSampler(object):
                                                     features_adapter1=copy.deepcopy(features_adapter1),
                                                     features_adapter2=copy.deepcopy(features_adapter2),
                                                     mode = mode,
-                                                    con_strength = con_strength
+                                                    con_strength = con_strength,
+                                                    style_feature=style_feature
                                                     )
         return samples, intermediates
 
@@ -125,7 +127,7 @@ class PLMSSampler(object):
                       callback=None, timesteps=None, quantize_denoised=False,
                       mask=None, x0=None, img_callback=None, log_every_t=100,
                       temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
-                      unconditional_guidance_scale=1., unconditional_conditioning=None,features_adapter1=None, features_adapter2=None, mode='sketch', con_strength=30):
+                      unconditional_guidance_scale=1., unconditional_conditioning=None,features_adapter1=None, features_adapter2=None, mode='sketch', con_strength=30, style_feature=None):
         device = self.model.betas.device
         b = shape[0]
         if x_T is None:
@@ -161,6 +163,16 @@ class PLMSSampler(object):
                 features_adapter = None
             else:
                 features_adapter = features_adapter1
+        elif mode == 'style':
+            if index<con_strength:
+                features_adapter = None
+            else:
+                features_adapter = features_adapter1
+
+            if index>25:
+                cond = torch.cat([cond, style_feature], dim=1)
+                unconditional_conditioning = torch.cat(
+                    [unconditional_conditioning, unconditional_conditioning[:, -8:, :]], dim=1)
         elif mode == 'mul':
             features_adapter = [a1i*0.5 + a2i for a1i, a2i in zip(features_adapter1, features_adapter2)]
         else:
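
In the new 'style' branch, `index` counts down over the 50 PLMS steps, so `index > 25` corresponds to the earlier, noisier steps: there the 8 style tokens produced by the StyleAdapter are appended to the text conditioning, and the unconditional branch is padded with its own last 8 tokens so both sequences stay the same length for classifier-free guidance. A shape-level sketch of just that concatenation (tensor sizes are assumptions based on CLIP's 77-token text context and `num_token=8`):

```python
# Shape-level sketch of the style-token concatenation above; values are random
# stand-ins, the dimensions (77 text tokens, 8 style tokens, width 768) are assumed.
import torch

cond = torch.randn(1, 77, 768)            # text conditioning
uncond = torch.randn(1, 77, 768)          # negative / unconditional conditioning
style_feature = torch.randn(1, 8, 768)    # StyleAdapter output (num_token=8)

cond = torch.cat([cond, style_feature], dim=1)             # -> (1, 85, 768)
uncond = torch.cat([uncond, uncond[:, -8:, :]], dim=1)     # -> (1, 85, 768), length-matched
print(cond.shape, uncond.shape)
```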
ldm/modules/attention.py
CHANGED
@@ -4,10 +4,22 @@ import torch
 import torch.nn.functional as F
 from torch import nn, einsum
 from einops import rearrange, repeat
+from typing import Optional, Any
 
 from ldm.modules.diffusionmodules.util import checkpoint
 
 
+try:
+    import xformers
+    import xformers.ops
+    XFORMERS_IS_AVAILBLE = True
+except:
+    XFORMERS_IS_AVAILBLE = False
+
+# CrossAttn precision handling
+import os
+_ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")
+
 def exists(val):
     return val is not None
 
@@ -77,25 +89,6 @@ def Normalize(in_channels):
     return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
 
 
-class LinearAttention(nn.Module):
-    def __init__(self, dim, heads=4, dim_head=32):
-        super().__init__()
-        self.heads = heads
-        hidden_dim = dim_head * heads
-        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
-        self.to_out = nn.Conv2d(hidden_dim, dim, 1)
-
-    def forward(self, x):
-        b, c, h, w = x.shape
-        qkv = self.to_qkv(x)
-        q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
-        k = k.softmax(dim=-1)
-        context = torch.einsum('bhdn,bhen->bhde', k, v)
-        out = torch.einsum('bhde,bhdn->bhen', context, q)
-        out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
-        return self.to_out(out)
-
-
 class SpatialSelfAttention(nn.Module):
     def __init__(self, in_channels):
         super().__init__()
@@ -177,7 +170,15 @@ class CrossAttention(nn.Module):
 
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
 
-        sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
+        # force cast to fp32 to avoid overflowing
+        if _ATTN_PRECISION =="fp32":
+            with torch.autocast(enabled=False, device_type = 'cuda'):
+                q, k = q.float(), k.float()
+                sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
+        else:
+            sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
+
+        del q, k
 
         if exists(mask):
             mask = rearrange(mask, 'b ... -> b (...)')
@@ -186,20 +187,79 @@ class CrossAttention(nn.Module):
             sim.masked_fill_(~mask, max_neg_value)
 
         # attention, what we cannot get enough of
-        attn = sim.softmax(dim=-1)
+        sim = sim.softmax(dim=-1)
 
-        out = einsum('b i j, b j d -> b i d', attn, v)
+        out = einsum('b i j, b j d -> b i d', sim, v)
         out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
         return self.to_out(out)
 
 
+class MemoryEfficientCrossAttention(nn.Module):
+    # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
+    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
+        super().__init__()
+        print(f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
+              f"{heads} heads.")
+        inner_dim = dim_head * heads
+        context_dim = default(context_dim, query_dim)
+
+        self.heads = heads
+        self.dim_head = dim_head
+
+        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
+        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
+        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+
+        self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
+        self.attention_op: Optional[Any] = None
+
+    def forward(self, x, context=None, mask=None):
+        q = self.to_q(x)
+        context = default(context, x)
+        k = self.to_k(context)
+        v = self.to_v(context)
+
+        b, _, _ = q.shape
+        q, k, v = map(
+            lambda t: t.unsqueeze(3)
+            .reshape(b, t.shape[1], self.heads, self.dim_head)
+            .permute(0, 2, 1, 3)
+            .reshape(b * self.heads, t.shape[1], self.dim_head)
+            .contiguous(),
+            (q, k, v),
+        )
+
+        # actually compute the attention, what we cannot get enough of
+        out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
+
+        if exists(mask):
+            raise NotImplementedError
+        out = (
+            out.unsqueeze(0)
+            .reshape(b, self.heads, out.shape[1], self.dim_head)
+            .permute(0, 2, 1, 3)
+            .reshape(b, out.shape[1], self.heads * self.dim_head)
+        )
+        return self.to_out(out)
+
+
 class BasicTransformerBlock(nn.Module):
-    def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True):
+    ATTENTION_MODES = {
+        "softmax": CrossAttention,  # vanilla attention
+        "softmax-xformers": MemoryEfficientCrossAttention
+    }
+    def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
+                 disable_self_attn=False):
         super().__init__()
-        self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout)  # is a self-attention
+        attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax"
+        assert attn_mode in self.ATTENTION_MODES
+        attn_cls = self.ATTENTION_MODES[attn_mode]
+        self.disable_self_attn = disable_self_attn
+        self.attn1 = attn_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
+                              context_dim=context_dim if self.disable_self_attn else None)  # is a self-attention if not self.disable_self_attn
         self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
-        self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
-                                    heads=n_heads, dim_head=d_head, dropout=dropout)  # is self-attn if context is none
+        self.attn2 = attn_cls(query_dim=dim, context_dim=context_dim,
+                              heads=n_heads, dim_head=d_head, dropout=dropout)  # is self-attn if context is none
         self.norm1 = nn.LayerNorm(dim)
         self.norm2 = nn.LayerNorm(dim)
         self.norm3 = nn.LayerNorm(dim)
@@ -209,7 +269,7 @@ class BasicTransformerBlock(nn.Module):
         return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
 
     def _forward(self, x, context=None):
-        x = self.attn1(self.norm1(x)) + x
+        x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x
         x = self.attn2(self.norm2(x), context=context) + x
         x = self.ff(self.norm3(x)) + x
         return x
@@ -222,40 +282,59 @@ class SpatialTransformer(nn.Module):
     and reshape to b, t, d.
     Then apply standard transformer action.
     Finally, reshape to image
+    NEW: use_linear for more efficiency instead of the 1x1 convs
     """
     def __init__(self, in_channels, n_heads, d_head,
-                 depth=1, dropout=0., context_dim=None):
+                 depth=1, dropout=0., context_dim=None,
+                 disable_self_attn=False, use_linear=False,
+                 use_checkpoint=True):
         super().__init__()
+        if exists(context_dim) and not isinstance(context_dim, list):
+            context_dim = [context_dim]
         self.in_channels = in_channels
         inner_dim = n_heads * d_head
        self.norm = Normalize(in_channels)
-
-        self.proj_in = nn.Conv2d(in_channels,
-                                 inner_dim,
-                                 kernel_size=1,
-                                 stride=1,
-                                 padding=0)
+        if not use_linear:
+            self.proj_in = nn.Conv2d(in_channels,
+                                     inner_dim,
+                                     kernel_size=1,
+                                     stride=1,
+                                     padding=0)
+        else:
+            self.proj_in = nn.Linear(in_channels, inner_dim)
 
         self.transformer_blocks = nn.ModuleList(
-            [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
+            [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
+                                   disable_self_attn=disable_self_attn, checkpoint=use_checkpoint)
                 for d in range(depth)]
         )
-        self.proj_out = zero_module(nn.Conv2d(inner_dim,
-                                              in_channels,
-                                              kernel_size=1,
-                                              stride=1,
-                                              padding=0))
+        if not use_linear:
+            self.proj_out = zero_module(nn.Conv2d(inner_dim,
+                                                  in_channels,
+                                                  kernel_size=1,
+                                                  stride=1,
+                                                  padding=0))
+        else:
+            self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
+        self.use_linear = use_linear
 
     def forward(self, x, context=None):
         # note: if no context is given, cross-attention defaults to self-attention
+        if not isinstance(context, list):
+            context = [context]
         b, c, h, w = x.shape
         x_in = x
         x = self.norm(x)
-        x = self.proj_in(x)
-        x = rearrange(x, 'b c h w -> b (h w) c')
-        for block in self.transformer_blocks:
-            x = block(x, context=context)
-        x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
-        x = self.proj_out(x)
+        if not self.use_linear:
+            x = self.proj_in(x)
+        x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
+        if self.use_linear:
+            x = self.proj_in(x)
+        for i, block in enumerate(self.transformer_blocks):
+            x = block(x, context=context[i])
+        if self.use_linear:
+            x = self.proj_out(x)
+        x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
+        if not self.use_linear:
+            x = self.proj_out(x)
         return x + x_in
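
The memory-efficient path above flattens the head dimension into the batch, i.e. it reshapes `(B, N, heads*dim_head)` into `(B*heads, N, dim_head)` before calling `xformers.ops.memory_efficient_attention`, then merges the heads back. A small xformers-free sketch of that same split/merge, with plain softmax attention standing in for the xformers kernel so it runs anywhere:

```python
# Minimal, xformers-free sketch of the head split/merge used by
# MemoryEfficientCrossAttention; plain softmax attention stands in for
# xformers.ops.memory_efficient_attention, which computes the same result.
import torch

def attention_with_head_split(q, k, v, heads):
    b, n, inner = q.shape
    d = inner // heads
    # (b, n, heads*d) -> (b*heads, n, d), same reshape as in the diff
    split = lambda t: t.reshape(b, t.shape[1], heads, d).permute(0, 2, 1, 3).reshape(b * heads, t.shape[1], d)
    q, k, v = map(split, (q, k, v))
    attn = torch.softmax(q @ k.transpose(-1, -2) / d ** 0.5, dim=-1)
    out = attn @ v
    # merge heads back: (b*heads, n, d) -> (b, n, heads*d)
    return out.reshape(b, heads, n, d).permute(0, 2, 1, 3).reshape(b, n, heads * d)

q = torch.randn(2, 16, 8 * 64)
k = torch.randn(2, 10, 8 * 64)
v = torch.randn(2, 10, 8 * 64)
print(attention_with_head_split(q, k, v, heads=8).shape)  # torch.Size([2, 16, 512])
```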
ldm/modules/diffusionmodules/model.py
CHANGED
@@ -4,9 +4,17 @@ import torch
|
|
4 |
import torch.nn as nn
|
5 |
import numpy as np
|
6 |
from einops import rearrange
|
|
|
7 |
|
8 |
-
from ldm.
|
9 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
|
11 |
|
12 |
def get_timestep_embedding(timesteps, embedding_dim):
|
@@ -141,12 +149,6 @@ class ResnetBlock(nn.Module):
|
|
141 |
return x+h
|
142 |
|
143 |
|
144 |
-
class LinAttnBlock(LinearAttention):
|
145 |
-
"""to match AttnBlock usage"""
|
146 |
-
def __init__(self, in_channels):
|
147 |
-
super().__init__(dim=in_channels, heads=1, dim_head=in_channels)
|
148 |
-
|
149 |
-
|
150 |
class AttnBlock(nn.Module):
|
151 |
def __init__(self, in_channels):
|
152 |
super().__init__()
|
@@ -174,7 +176,6 @@ class AttnBlock(nn.Module):
|
|
174 |
stride=1,
|
175 |
padding=0)
|
176 |
|
177 |
-
|
178 |
def forward(self, x):
|
179 |
h_ = x
|
180 |
h_ = self.norm(h_)
|
@@ -201,16 +202,99 @@ class AttnBlock(nn.Module):
|
|
201 |
|
202 |
return x+h_
|
203 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
204 |
|
205 |
-
|
206 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
207 |
print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
|
208 |
if attn_type == "vanilla":
|
|
|
209 |
return AttnBlock(in_channels)
|
|
|
|
|
|
|
|
|
|
|
|
|
210 |
elif attn_type == "none":
|
211 |
return nn.Identity(in_channels)
|
212 |
else:
|
213 |
-
|
214 |
|
215 |
|
216 |
class Model(nn.Module):
|
@@ -766,70 +850,3 @@ class Resize(nn.Module):
|
|
766 |
else:
|
767 |
x = torch.nn.functional.interpolate(x, mode=self.mode, align_corners=False, scale_factor=scale_factor)
|
768 |
return x
|
769 |
-
|
770 |
-
class FirstStagePostProcessor(nn.Module):
|
771 |
-
|
772 |
-
def __init__(self, ch_mult:list, in_channels,
|
773 |
-
pretrained_model:nn.Module=None,
|
774 |
-
reshape=False,
|
775 |
-
n_channels=None,
|
776 |
-
dropout=0.,
|
777 |
-
pretrained_config=None):
|
778 |
-
super().__init__()
|
779 |
-
if pretrained_config is None:
|
780 |
-
assert pretrained_model is not None, 'Either "pretrained_model" or "pretrained_config" must not be None'
|
781 |
-
self.pretrained_model = pretrained_model
|
782 |
-
else:
|
783 |
-
assert pretrained_config is not None, 'Either "pretrained_model" or "pretrained_config" must not be None'
|
784 |
-
self.instantiate_pretrained(pretrained_config)
|
785 |
-
|
786 |
-
self.do_reshape = reshape
|
787 |
-
|
788 |
-
if n_channels is None:
|
789 |
-
n_channels = self.pretrained_model.encoder.ch
|
790 |
-
|
791 |
-
self.proj_norm = Normalize(in_channels,num_groups=in_channels//2)
|
792 |
-
self.proj = nn.Conv2d(in_channels,n_channels,kernel_size=3,
|
793 |
-
stride=1,padding=1)
|
794 |
-
|
795 |
-
blocks = []
|
796 |
-
downs = []
|
797 |
-
ch_in = n_channels
|
798 |
-
for m in ch_mult:
|
799 |
-
blocks.append(ResnetBlock(in_channels=ch_in,out_channels=m*n_channels,dropout=dropout))
|
800 |
-
ch_in = m * n_channels
|
801 |
-
downs.append(Downsample(ch_in, with_conv=False))
|
802 |
-
|
803 |
-
self.model = nn.ModuleList(blocks)
|
804 |
-
self.downsampler = nn.ModuleList(downs)
|
805 |
-
|
806 |
-
|
807 |
-
def instantiate_pretrained(self, config):
|
808 |
-
model = instantiate_from_config(config)
|
809 |
-
self.pretrained_model = model.eval()
|
810 |
-
# self.pretrained_model.train = False
|
811 |
-
for param in self.pretrained_model.parameters():
|
812 |
-
param.requires_grad = False
|
813 |
-
|
814 |
-
|
815 |
-
@torch.no_grad()
|
816 |
-
def encode_with_pretrained(self,x):
|
817 |
-
c = self.pretrained_model.encode(x)
|
818 |
-
if isinstance(c, DiagonalGaussianDistribution):
|
819 |
-
c = c.mode()
|
820 |
-
return c
|
821 |
-
|
822 |
-
def forward(self,x):
|
823 |
-
z_fs = self.encode_with_pretrained(x)
|
824 |
-
z = self.proj_norm(z_fs)
|
825 |
-
z = self.proj(z)
|
826 |
-
z = nonlinearity(z)
|
827 |
-
|
828 |
-
for submodel, downmodel in zip(self.model,self.downsampler):
|
829 |
-
z = submodel(z,temb=None)
|
830 |
-
z = downmodel(z)
|
831 |
-
|
832 |
-
if self.do_reshape:
|
833 |
-
z = rearrange(z,'b c h w -> b (h w) c')
|
834 |
-
return z
|
835 |
-
|
|
|
4 |
import torch.nn as nn
|
5 |
import numpy as np
|
6 |
from einops import rearrange
|
7 |
+
from typing import Optional, Any
|
8 |
|
9 |
+
from ldm.modules.attention import MemoryEfficientCrossAttention
|
10 |
+
|
11 |
+
try:
|
12 |
+
import xformers
|
13 |
+
import xformers.ops
|
14 |
+
XFORMERS_IS_AVAILBLE = True
|
15 |
+
except:
|
16 |
+
XFORMERS_IS_AVAILBLE = False
|
17 |
+
print("No module 'xformers'. Proceeding without it.")
|
18 |
|
19 |
|
20 |
def get_timestep_embedding(timesteps, embedding_dim):
        return x+h


class AttnBlock(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
...
                                        stride=1,
                                        padding=0)

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
...

        return x+h_

+class MemoryEfficientAttnBlock(nn.Module):
+    """
+        Uses xformers efficient implementation,
+        see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
+        Note: this is a single-head self-attention operation
+    """
+    #
+    def __init__(self, in_channels):
+        super().__init__()
+        self.in_channels = in_channels
+
+        self.norm = Normalize(in_channels)
+        self.q = torch.nn.Conv2d(in_channels,
+                                 in_channels,
+                                 kernel_size=1,
+                                 stride=1,
+                                 padding=0)
+        self.k = torch.nn.Conv2d(in_channels,
+                                 in_channels,
+                                 kernel_size=1,
+                                 stride=1,
+                                 padding=0)
+        self.v = torch.nn.Conv2d(in_channels,
+                                 in_channels,
+                                 kernel_size=1,
+                                 stride=1,
+                                 padding=0)
+        self.proj_out = torch.nn.Conv2d(in_channels,
+                                        in_channels,
+                                        kernel_size=1,
+                                        stride=1,
+                                        padding=0)
+        self.attention_op: Optional[Any] = None
+
+    def forward(self, x):
+        h_ = x
+        h_ = self.norm(h_)
+        q = self.q(h_)
+        k = self.k(h_)
+        v = self.v(h_)
+
+        # compute attention
+        B, C, H, W = q.shape
+        q, k, v = map(lambda x: rearrange(x, 'b c h w -> b (h w) c'), (q, k, v))
+
+        q, k, v = map(
+            lambda t: t.unsqueeze(3)
+            .reshape(B, t.shape[1], 1, C)
+            .permute(0, 2, 1, 3)
+            .reshape(B * 1, t.shape[1], C)
+            .contiguous(),
+            (q, k, v),
+        )
+        out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
+
+        out = (
+            out.unsqueeze(0)
+            .reshape(B, 1, out.shape[1], C)
+            .permute(0, 2, 1, 3)
+            .reshape(B, out.shape[1], C)
+        )
+        out = rearrange(out, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C)
+        out = self.proj_out(out)
+        return x+out
+
+
+class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention):
+    def forward(self, x, context=None, mask=None):
+        b, c, h, w = x.shape
+        x = rearrange(x, 'b c h w -> b (h w) c')
+        out = super().forward(x, context=context, mask=mask)
+        out = rearrange(out, 'b (h w) c -> b c h w', h=h, w=w, c=c)
+        return x + out
+
+
+def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
+    assert attn_type in ["vanilla", "vanilla-xformers", "memory-efficient-cross-attn", "linear", "none"], f'attn_type {attn_type} unknown'
+    if XFORMERS_IS_AVAILBLE and attn_type == "vanilla":
+        attn_type = "vanilla-xformers"
    print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
    if attn_type == "vanilla":
+        assert attn_kwargs is None
        return AttnBlock(in_channels)
+    elif attn_type == "vanilla-xformers":
+        print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...")
+        return MemoryEfficientAttnBlock(in_channels)
+    elif attn_type == "memory-efficient-cross-attn":  # fixed: the original compared the builtin `type` here, so this branch could never be taken
+        attn_kwargs["query_dim"] = in_channels
+        return MemoryEfficientCrossAttentionWrapper(**attn_kwargs)
    elif attn_type == "none":
        return nn.Identity(in_channels)
    else:
+        raise NotImplementedError()


class Model(nn.Module):
...
        else:
            x = torch.nn.functional.interpolate(x, mode=self.mode, align_corners=False, scale_factor=scale_factor)
        return x
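For readers skimming the diff, a minimal sketch of how the new attention factory behaves; the 64-channel block and the input shape are illustrative, not values taken from the commit.

import torch
from ldm.modules.diffusionmodules.model import make_attn

# "vanilla" is silently upgraded to the memory-efficient block when xformers imports;
# in practice the xformers path expects CUDA tensors.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
attn = make_attn(64, attn_type="vanilla").to(device)
x = torch.randn(1, 64, 32, 32, device=device)   # (B, C, H, W) feature map
out = attn(x)                                    # residual self-attention, same shape as x
assert out.shape == x.shape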
ldm/modules/encoders/adapter.py
CHANGED
@@ -2,6 +2,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from ldm.modules.attention import SpatialTransformer, BasicTransformerBlock
+from collections import OrderedDict

 def conv_nd(dims, *args, **kwargs):
     """
@@ -121,3 +122,130 @@ class Adapter(nn.Module):
             features.append(x)

         return features
+
+
+class ResnetBlock_light(nn.Module):
+    def __init__(self, in_c):
+        super().__init__()
+        self.block1 = nn.Conv2d(in_c, in_c, 3, 1, 1)
+        self.act = nn.ReLU()
+        self.block2 = nn.Conv2d(in_c, in_c, 3, 1, 1)
+
+    def forward(self, x):
+        h = self.block1(x)
+        h = self.act(h)
+        h = self.block2(h)
+
+        return h + x
+
+
+class extractor(nn.Module):
+    def __init__(self, in_c, inter_c, out_c, nums_rb, down=False):
+        super().__init__()
+        self.in_conv = nn.Conv2d(in_c, inter_c, 1, 1, 0)
+        self.body = []
+        for _ in range(nums_rb):
+            self.body.append(ResnetBlock_light(inter_c))
+        self.body = nn.Sequential(*self.body)
+        self.out_conv = nn.Conv2d(inter_c, out_c, 1, 1, 0)
+        self.down = down
+        if self.down == True:
+            self.down_opt = Downsample(in_c, use_conv=False)
+
+    def forward(self, x):
+        if self.down == True:
+            x = self.down_opt(x)
+        x = self.in_conv(x)
+        x = self.body(x)
+        x = self.out_conv(x)
+
+        return x
+
+
+class Adapter_light(nn.Module):
+    def __init__(self, channels=[320, 640, 1280, 1280], nums_rb=3, cin=64):
+        super(Adapter_light, self).__init__()
+        self.unshuffle = nn.PixelUnshuffle(8)
+        self.channels = channels
+        self.nums_rb = nums_rb
+        self.body = []
+        for i in range(len(channels)):
+            if i == 0:
+                self.body.append(extractor(in_c=cin, inter_c=channels[i]//4, out_c=channels[i], nums_rb=nums_rb, down=False))
+            else:
+                self.body.append(extractor(in_c=channels[i-1], inter_c=channels[i]//4, out_c=channels[i], nums_rb=nums_rb, down=True))
+        self.body = nn.ModuleList(self.body)
+
+    def forward(self, x):
+        # unshuffle
+        x = self.unshuffle(x)
+        # extract features
+        features = []
+        for i in range(len(self.channels)):
+            x = self.body[i](x)
+            features.append(x)
+
+        return features
+
+
+class QuickGELU(nn.Module):
+
+    def forward(self, x: torch.Tensor):
+        return x * torch.sigmoid(1.702 * x)
+
+
+class ResidualAttentionBlock(nn.Module):
+
+    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
+        super().__init__()
+
+        self.attn = nn.MultiheadAttention(d_model, n_head)
+        self.ln_1 = LayerNorm(d_model)
+        self.mlp = nn.Sequential(
+            OrderedDict([("c_fc", nn.Linear(d_model, d_model * 4)), ("gelu", QuickGELU()),
+                         ("c_proj", nn.Linear(d_model * 4, d_model))]))
+        self.ln_2 = LayerNorm(d_model)
+        self.attn_mask = attn_mask
+
+    def attention(self, x: torch.Tensor):
+        self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
+        return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
+
+    def forward(self, x: torch.Tensor):
+        x = x + self.attention(self.ln_1(x))
+        x = x + self.mlp(self.ln_2(x))
+        return x
+
+
+class LayerNorm(nn.LayerNorm):
+    """Subclass torch's LayerNorm to handle fp16."""
+
+    def forward(self, x: torch.Tensor):
+        orig_type = x.dtype
+        ret = super().forward(x.type(torch.float32))
+        return ret.type(orig_type)
+
+
+class StyleAdapter(nn.Module):
+
+    def __init__(self, width=1024, context_dim=768, num_head=8, n_layes=3, num_token=4):
+        super().__init__()
+
+        scale = width ** -0.5
+        self.transformer_layes = nn.Sequential(*[ResidualAttentionBlock(width, num_head) for _ in range(n_layes)])
+        self.num_token = num_token
+        self.style_embedding = nn.Parameter(torch.randn(1, num_token, width) * scale)
+        self.ln_post = LayerNorm(width)
+        self.ln_pre = LayerNorm(width)
+        self.proj = nn.Parameter(scale * torch.randn(width, context_dim))
+
+    def forward(self, x):
+        # x shape [N, HW+1, C]
+        style_embedding = self.style_embedding + torch.zeros(
+            (x.shape[0], self.num_token, self.style_embedding.shape[-1]), device=x.device)
+        x = torch.cat([x, style_embedding], dim=1)
+        x = self.ln_pre(x)
+        x = x.permute(1, 0, 2)  # NLD -> LND
+        x = self.transformer_layes(x)
+        x = x.permute(1, 0, 2)  # LND -> NLD
+
+        x = self.ln_post(x[:, -self.num_token:, :])
+        x = x @ self.proj
+
+        return x
ldm/modules/structure_condition/openpose/__init__.py
ADDED
File without changes
ldm/modules/structure_condition/openpose/api.py
ADDED
@@ -0,0 +1,36 @@
import os

import numpy as np
import torch.nn as nn

os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

import cv2
import torch

from . import util
from .body import Body

remote_model_path = "https://drive.google.com/file/d/1EULkcH_hhSU28qVc1jSJpCh2hGOrzpjK/view?usp=share_link"


class OpenposeInference(nn.Module):

    def __init__(self):
        super().__init__()
        body_modelpath = os.path.join('models', "body_pose_model.pth")

        if not os.path.exists(body_modelpath):
            from basicsr.utils.download_util import load_file_from_url
            load_file_from_url(remote_model_path, model_dir='models')

        self.body_estimation = Body(body_modelpath)

    def forward(self, x):
        x = x[:, :, ::-1].copy()
        with torch.no_grad():
            candidate, subset = self.body_estimation(x)
            canvas = np.zeros_like(x)
            canvas = util.draw_bodypose(canvas, candidate, subset)
            canvas = cv2.cvtColor(canvas, cv2.COLOR_RGB2BGR)
        return canvas
ldm/modules/structure_condition/openpose/body.py
ADDED
@@ -0,0 +1,224 @@
import math
import time

import cv2
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import torch
from scipy.ndimage.filters import gaussian_filter
from torchvision import transforms

from . import util
from .model import bodypose_model


class Body(object):

    def __init__(self, model_path):
        self.model = bodypose_model()
        if torch.cuda.is_available():
            self.model = self.model.cuda()
            print('cuda')
        model_dict = util.transfer(self.model, torch.load(model_path))
        self.model.load_state_dict(model_dict)
        self.model.eval()

    def __call__(self, oriImg):
        # scale_search = [0.5, 1.0, 1.5, 2.0]
        scale_search = [0.5]
        boxsize = 368
        stride = 8
        padValue = 128
        thre1 = 0.1
        thre2 = 0.05
        multiplier = [x * boxsize / oriImg.shape[0] for x in scale_search]
        heatmap_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 19))
        paf_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 38))

        for m in range(len(multiplier)):
            scale = multiplier[m]
            imageToTest = cv2.resize(oriImg, (0, 0), fx=scale, fy=scale, interpolation=cv2.INTER_CUBIC)
            imageToTest_padded, pad = util.padRightDownCorner(imageToTest, stride, padValue)
            im = np.transpose(np.float32(imageToTest_padded[:, :, :, np.newaxis]), (3, 2, 0, 1)) / 256 - 0.5
            im = np.ascontiguousarray(im)

            data = torch.from_numpy(im).float()
            if torch.cuda.is_available():
                data = data.cuda()
            # data = data.permute([2, 0, 1]).unsqueeze(0).float()
            with torch.no_grad():
                Mconv7_stage6_L1, Mconv7_stage6_L2 = self.model(data)
            Mconv7_stage6_L1 = Mconv7_stage6_L1.cpu().numpy()
            Mconv7_stage6_L2 = Mconv7_stage6_L2.cpu().numpy()

            # extract outputs, resize, and remove padding
            # heatmap = np.transpose(np.squeeze(net.blobs[output_blobs.keys()[1]].data), (1, 2, 0))  # output 1 is heatmaps
            heatmap = np.transpose(np.squeeze(Mconv7_stage6_L2), (1, 2, 0))  # output 1 is heatmaps
            heatmap = cv2.resize(heatmap, (0, 0), fx=stride, fy=stride, interpolation=cv2.INTER_CUBIC)
            heatmap = heatmap[:imageToTest_padded.shape[0] - pad[2], :imageToTest_padded.shape[1] - pad[3], :]
            heatmap = cv2.resize(heatmap, (oriImg.shape[1], oriImg.shape[0]), interpolation=cv2.INTER_CUBIC)

            # paf = np.transpose(np.squeeze(net.blobs[output_blobs.keys()[0]].data), (1, 2, 0))  # output 0 is PAFs
            paf = np.transpose(np.squeeze(Mconv7_stage6_L1), (1, 2, 0))  # output 0 is PAFs
            paf = cv2.resize(paf, (0, 0), fx=stride, fy=stride, interpolation=cv2.INTER_CUBIC)
            paf = paf[:imageToTest_padded.shape[0] - pad[2], :imageToTest_padded.shape[1] - pad[3], :]
            paf = cv2.resize(paf, (oriImg.shape[1], oriImg.shape[0]), interpolation=cv2.INTER_CUBIC)

            heatmap_avg += heatmap_avg + heatmap / len(multiplier)
            paf_avg += +paf / len(multiplier)

        all_peaks = []
        peak_counter = 0

        for part in range(18):
            map_ori = heatmap_avg[:, :, part]
            one_heatmap = gaussian_filter(map_ori, sigma=3)

            map_left = np.zeros(one_heatmap.shape)
            map_left[1:, :] = one_heatmap[:-1, :]
            map_right = np.zeros(one_heatmap.shape)
            map_right[:-1, :] = one_heatmap[1:, :]
            map_up = np.zeros(one_heatmap.shape)
            map_up[:, 1:] = one_heatmap[:, :-1]
            map_down = np.zeros(one_heatmap.shape)
            map_down[:, :-1] = one_heatmap[:, 1:]

            peaks_binary = np.logical_and.reduce((one_heatmap >= map_left, one_heatmap >= map_right,
                                                  one_heatmap >= map_up, one_heatmap >= map_down, one_heatmap > thre1))
            peaks = list(zip(np.nonzero(peaks_binary)[1], np.nonzero(peaks_binary)[0]))  # note reverse
            peaks_with_score = [x + (map_ori[x[1], x[0]], ) for x in peaks]
            peak_id = range(peak_counter, peak_counter + len(peaks))
            peaks_with_score_and_id = [peaks_with_score[i] + (peak_id[i], ) for i in range(len(peak_id))]

            all_peaks.append(peaks_with_score_and_id)
            peak_counter += len(peaks)

        # find connection in the specified sequence, center 29 is in the position 15
        limbSeq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10], \
                   [10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17], \
                   [1, 16], [16, 18], [3, 17], [6, 18]]
        # the middle joints heatmap correspondence
        mapIdx = [[31, 32], [39, 40], [33, 34], [35, 36], [41, 42], [43, 44], [19, 20], [21, 22], \
                  [23, 24], [25, 26], [27, 28], [29, 30], [47, 48], [49, 50], [53, 54], [51, 52], \
                  [55, 56], [37, 38], [45, 46]]

        connection_all = []
        special_k = []
        mid_num = 10

        for k in range(len(mapIdx)):
            score_mid = paf_avg[:, :, [x - 19 for x in mapIdx[k]]]
            candA = all_peaks[limbSeq[k][0] - 1]
            candB = all_peaks[limbSeq[k][1] - 1]
            nA = len(candA)
            nB = len(candB)
            indexA, indexB = limbSeq[k]
            if (nA != 0 and nB != 0):
                connection_candidate = []
                for i in range(nA):
                    for j in range(nB):
                        vec = np.subtract(candB[j][:2], candA[i][:2])
                        norm = math.sqrt(vec[0] * vec[0] + vec[1] * vec[1])
                        norm = max(0.001, norm)
                        vec = np.divide(vec, norm)

                        startend = list(zip(np.linspace(candA[i][0], candB[j][0], num=mid_num), \
                                            np.linspace(candA[i][1], candB[j][1], num=mid_num)))

                        vec_x = np.array([score_mid[int(round(startend[I][1])), int(round(startend[I][0])), 0] \
                                          for I in range(len(startend))])
                        vec_y = np.array([score_mid[int(round(startend[I][1])), int(round(startend[I][0])), 1] \
                                          for I in range(len(startend))])

                        score_midpts = np.multiply(vec_x, vec[0]) + np.multiply(vec_y, vec[1])
                        score_with_dist_prior = sum(score_midpts) / len(score_midpts) + min(
                            0.5 * oriImg.shape[0] / norm - 1, 0)
                        criterion1 = len(np.nonzero(score_midpts > thre2)[0]) > 0.8 * len(score_midpts)
                        criterion2 = score_with_dist_prior > 0
                        if criterion1 and criterion2:
                            connection_candidate.append(
                                [i, j, score_with_dist_prior, score_with_dist_prior + candA[i][2] + candB[j][2]])

                connection_candidate = sorted(connection_candidate, key=lambda x: x[2], reverse=True)
                connection = np.zeros((0, 5))
                for c in range(len(connection_candidate)):
                    i, j, s = connection_candidate[c][0:3]
                    if (i not in connection[:, 3] and j not in connection[:, 4]):
                        connection = np.vstack([connection, [candA[i][3], candB[j][3], s, i, j]])
                        if (len(connection) >= min(nA, nB)):
                            break

                connection_all.append(connection)
            else:
                special_k.append(k)
                connection_all.append([])

        # last number in each row is the total parts number of that person
        # the second last number in each row is the score of the overall configuration
        subset = -1 * np.ones((0, 20))
        candidate = np.array([item for sublist in all_peaks for item in sublist])

        for k in range(len(mapIdx)):
            if k not in special_k:
                partAs = connection_all[k][:, 0]
                partBs = connection_all[k][:, 1]
                indexA, indexB = np.array(limbSeq[k]) - 1

                for i in range(len(connection_all[k])):  # = 1:size(temp,1)
                    found = 0
                    subset_idx = [-1, -1]
                    for j in range(len(subset)):  # 1:size(subset,1):
                        if subset[j][indexA] == partAs[i] or subset[j][indexB] == partBs[i]:
                            subset_idx[found] = j
                            found += 1

                    if found == 1:
                        j = subset_idx[0]
                        if subset[j][indexB] != partBs[i]:
                            subset[j][indexB] = partBs[i]
                            subset[j][-1] += 1
                            subset[j][-2] += candidate[partBs[i].astype(int), 2] + connection_all[k][i][2]
                    elif found == 2:  # if found 2 and disjoint, merge them
                        j1, j2 = subset_idx
                        membership = ((subset[j1] >= 0).astype(int) + (subset[j2] >= 0).astype(int))[:-2]
                        if len(np.nonzero(membership == 2)[0]) == 0:  # merge
                            subset[j1][:-2] += (subset[j2][:-2] + 1)
                            subset[j1][-2:] += subset[j2][-2:]
                            subset[j1][-2] += connection_all[k][i][2]
                            subset = np.delete(subset, j2, 0)
                        else:  # as like found == 1
                            subset[j1][indexB] = partBs[i]
                            subset[j1][-1] += 1
                            subset[j1][-2] += candidate[partBs[i].astype(int), 2] + connection_all[k][i][2]

                    # if partA is not found in any existing subset, create a new one
                    elif not found and k < 17:
                        row = -1 * np.ones(20)
                        row[indexA] = partAs[i]
                        row[indexB] = partBs[i]
                        row[-1] = 2
                        row[-2] = sum(candidate[connection_all[k][i, :2].astype(int), 2]) + connection_all[k][i][2]
                        subset = np.vstack([subset, row])
        # delete rows of subset that have too few parts
        deleteIdx = []
        for i in range(len(subset)):
            if subset[i][-1] < 4 or subset[i][-2] / subset[i][-1] < 0.4:
                deleteIdx.append(i)
        subset = np.delete(subset, deleteIdx, axis=0)

        # subset: n*20 array, 0-17 is the index in candidate, 18 is the total score, 19 is the total parts
        # candidate: x, y, score, id
        return candidate, subset


if __name__ == "__main__":
    body_estimation = Body('../model/body_pose_model.pth')

    test_image = '/group/30042/liangbinxie/Projects/mmpose/test_data/twitter/1.png'
    oriImg = cv2.imread(test_image)  # B,G,R order
    candidate, subset = body_estimation(oriImg)
    print(candidate, subset)
    canvas = util.draw_bodypose(oriImg, candidate, subset)
    plt.imshow(canvas[:, :, [2, 1, 0]])
    plt.show()
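For orientation, a short sketch of how the candidate/subset arrays returned above are typically read; the field layout follows the comments at the end of __call__, and the weight and image paths are placeholders.

import cv2
from ldm.modules.structure_condition.openpose.body import Body

body_estimation = Body('models/body_pose_model.pth')
candidate, subset = body_estimation(cv2.imread('person.jpg'))   # BGR image (placeholder path)

for person in subset:              # one row per detected person
    for part in range(18):         # 18 body keypoints
        idx = int(person[part])
        if idx == -1:              # this keypoint was not found for this person
            continue
        x, y, score = candidate[idx][:3]
        print(part, x, y, score)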
ldm/modules/structure_condition/openpose/hand.py
ADDED
@@ -0,0 +1,86 @@
import cv2
import json
import numpy as np
import math
import time
from scipy.ndimage.filters import gaussian_filter
import matplotlib.pyplot as plt
import matplotlib
import torch
from skimage.measure import label

from .model import handpose_model
from . import util


class Hand(object):
    def __init__(self, model_path):
        self.model = handpose_model()
        if torch.cuda.is_available():
            self.model = self.model.cuda()
            print('cuda')
        model_dict = util.transfer(self.model, torch.load(model_path))
        self.model.load_state_dict(model_dict)
        self.model.eval()

    def __call__(self, oriImg):
        scale_search = [0.5, 1.0, 1.5, 2.0]
        # scale_search = [0.5]
        boxsize = 368
        stride = 8
        padValue = 128
        thre = 0.05
        multiplier = [x * boxsize / oriImg.shape[0] for x in scale_search]
        heatmap_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 22))
        # paf_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 38))

        for m in range(len(multiplier)):
            scale = multiplier[m]
            imageToTest = cv2.resize(oriImg, (0, 0), fx=scale, fy=scale, interpolation=cv2.INTER_CUBIC)
            imageToTest_padded, pad = util.padRightDownCorner(imageToTest, stride, padValue)
            im = np.transpose(np.float32(imageToTest_padded[:, :, :, np.newaxis]), (3, 2, 0, 1)) / 256 - 0.5
            im = np.ascontiguousarray(im)

            data = torch.from_numpy(im).float()
            if torch.cuda.is_available():
                data = data.cuda()
            # data = data.permute([2, 0, 1]).unsqueeze(0).float()
            with torch.no_grad():
                output = self.model(data).cpu().numpy()
                # output = self.model(data).numpy()

            # extract outputs, resize, and remove padding
            heatmap = np.transpose(np.squeeze(output), (1, 2, 0))  # output 1 is heatmaps
            heatmap = cv2.resize(heatmap, (0, 0), fx=stride, fy=stride, interpolation=cv2.INTER_CUBIC)
            heatmap = heatmap[:imageToTest_padded.shape[0] - pad[2], :imageToTest_padded.shape[1] - pad[3], :]
            heatmap = cv2.resize(heatmap, (oriImg.shape[1], oriImg.shape[0]), interpolation=cv2.INTER_CUBIC)

            heatmap_avg += heatmap / len(multiplier)

        all_peaks = []
        for part in range(21):
            map_ori = heatmap_avg[:, :, part]
            one_heatmap = gaussian_filter(map_ori, sigma=3)
            binary = np.ascontiguousarray(one_heatmap > thre, dtype=np.uint8)
            # all values are below the threshold
            if np.sum(binary) == 0:
                all_peaks.append([0, 0])
                continue
            label_img, label_numbers = label(binary, return_num=True, connectivity=binary.ndim)
            max_index = np.argmax([np.sum(map_ori[label_img == i]) for i in range(1, label_numbers + 1)]) + 1
            label_img[label_img != max_index] = 0
            map_ori[label_img == 0] = 0

            y, x = util.npmax(map_ori)
            all_peaks.append([x, y])
        return np.array(all_peaks)


if __name__ == "__main__":
    hand_estimation = Hand('../model/hand_pose_model.pth')

    # test_image = '../images/hand.jpg'
    test_image = '../images/hand.jpg'
    oriImg = cv2.imread(test_image)  # B,G,R order
    peaks = hand_estimation(oriImg)
    canvas = util.draw_handpose(oriImg, peaks, True)
    cv2.imshow('', canvas)
    cv2.waitKey(0)
ldm/modules/structure_condition/openpose/model.py
ADDED
@@ -0,0 +1,219 @@
import torch
import torch.nn as nn
from collections import OrderedDict


def make_layers(block, no_relu_layers):
    layers = []
    for layer_name, v in block.items():
        if 'pool' in layer_name:
            layer = nn.MaxPool2d(kernel_size=v[0], stride=v[1],
                                 padding=v[2])
            layers.append((layer_name, layer))
        else:
            conv2d = nn.Conv2d(in_channels=v[0], out_channels=v[1],
                               kernel_size=v[2], stride=v[3],
                               padding=v[4])
            layers.append((layer_name, conv2d))
            if layer_name not in no_relu_layers:
                layers.append(('relu_'+layer_name, nn.ReLU(inplace=True)))

    return nn.Sequential(OrderedDict(layers))


class bodypose_model(nn.Module):
    def __init__(self):
        super(bodypose_model, self).__init__()

        # these layers have no relu layer
        no_relu_layers = ['conv5_5_CPM_L1', 'conv5_5_CPM_L2', 'Mconv7_stage2_L1',\
                          'Mconv7_stage2_L2', 'Mconv7_stage3_L1', 'Mconv7_stage3_L2',\
                          'Mconv7_stage4_L1', 'Mconv7_stage4_L2', 'Mconv7_stage5_L1',\
                          'Mconv7_stage5_L2', 'Mconv7_stage6_L1', 'Mconv7_stage6_L1']
        blocks = {}
        block0 = OrderedDict([
            ('conv1_1', [3, 64, 3, 1, 1]),
            ('conv1_2', [64, 64, 3, 1, 1]),
            ('pool1_stage1', [2, 2, 0]),
            ('conv2_1', [64, 128, 3, 1, 1]),
            ('conv2_2', [128, 128, 3, 1, 1]),
            ('pool2_stage1', [2, 2, 0]),
            ('conv3_1', [128, 256, 3, 1, 1]),
            ('conv3_2', [256, 256, 3, 1, 1]),
            ('conv3_3', [256, 256, 3, 1, 1]),
            ('conv3_4', [256, 256, 3, 1, 1]),
            ('pool3_stage1', [2, 2, 0]),
            ('conv4_1', [256, 512, 3, 1, 1]),
            ('conv4_2', [512, 512, 3, 1, 1]),
            ('conv4_3_CPM', [512, 256, 3, 1, 1]),
            ('conv4_4_CPM', [256, 128, 3, 1, 1])
        ])

        # Stage 1
        block1_1 = OrderedDict([
            ('conv5_1_CPM_L1', [128, 128, 3, 1, 1]),
            ('conv5_2_CPM_L1', [128, 128, 3, 1, 1]),
            ('conv5_3_CPM_L1', [128, 128, 3, 1, 1]),
            ('conv5_4_CPM_L1', [128, 512, 1, 1, 0]),
            ('conv5_5_CPM_L1', [512, 38, 1, 1, 0])
        ])

        block1_2 = OrderedDict([
            ('conv5_1_CPM_L2', [128, 128, 3, 1, 1]),
            ('conv5_2_CPM_L2', [128, 128, 3, 1, 1]),
            ('conv5_3_CPM_L2', [128, 128, 3, 1, 1]),
            ('conv5_4_CPM_L2', [128, 512, 1, 1, 0]),
            ('conv5_5_CPM_L2', [512, 19, 1, 1, 0])
        ])
        blocks['block1_1'] = block1_1
        blocks['block1_2'] = block1_2

        self.model0 = make_layers(block0, no_relu_layers)

        # Stages 2 - 6
        for i in range(2, 7):
            blocks['block%d_1' % i] = OrderedDict([
                ('Mconv1_stage%d_L1' % i, [185, 128, 7, 1, 3]),
                ('Mconv2_stage%d_L1' % i, [128, 128, 7, 1, 3]),
                ('Mconv3_stage%d_L1' % i, [128, 128, 7, 1, 3]),
                ('Mconv4_stage%d_L1' % i, [128, 128, 7, 1, 3]),
                ('Mconv5_stage%d_L1' % i, [128, 128, 7, 1, 3]),
                ('Mconv6_stage%d_L1' % i, [128, 128, 1, 1, 0]),
                ('Mconv7_stage%d_L1' % i, [128, 38, 1, 1, 0])
            ])

            blocks['block%d_2' % i] = OrderedDict([
                ('Mconv1_stage%d_L2' % i, [185, 128, 7, 1, 3]),
                ('Mconv2_stage%d_L2' % i, [128, 128, 7, 1, 3]),
                ('Mconv3_stage%d_L2' % i, [128, 128, 7, 1, 3]),
                ('Mconv4_stage%d_L2' % i, [128, 128, 7, 1, 3]),
                ('Mconv5_stage%d_L2' % i, [128, 128, 7, 1, 3]),
                ('Mconv6_stage%d_L2' % i, [128, 128, 1, 1, 0]),
                ('Mconv7_stage%d_L2' % i, [128, 19, 1, 1, 0])
            ])

        for k in blocks.keys():
            blocks[k] = make_layers(blocks[k], no_relu_layers)

        self.model1_1 = blocks['block1_1']
        self.model2_1 = blocks['block2_1']
        self.model3_1 = blocks['block3_1']
        self.model4_1 = blocks['block4_1']
        self.model5_1 = blocks['block5_1']
        self.model6_1 = blocks['block6_1']

        self.model1_2 = blocks['block1_2']
        self.model2_2 = blocks['block2_2']
        self.model3_2 = blocks['block3_2']
        self.model4_2 = blocks['block4_2']
        self.model5_2 = blocks['block5_2']
        self.model6_2 = blocks['block6_2']

    def forward(self, x):

        out1 = self.model0(x)

        out1_1 = self.model1_1(out1)
        out1_2 = self.model1_2(out1)
        out2 = torch.cat([out1_1, out1_2, out1], 1)

        out2_1 = self.model2_1(out2)
        out2_2 = self.model2_2(out2)
        out3 = torch.cat([out2_1, out2_2, out1], 1)

        out3_1 = self.model3_1(out3)
        out3_2 = self.model3_2(out3)
        out4 = torch.cat([out3_1, out3_2, out1], 1)

        out4_1 = self.model4_1(out4)
        out4_2 = self.model4_2(out4)
        out5 = torch.cat([out4_1, out4_2, out1], 1)

        out5_1 = self.model5_1(out5)
        out5_2 = self.model5_2(out5)
        out6 = torch.cat([out5_1, out5_2, out1], 1)

        out6_1 = self.model6_1(out6)
        out6_2 = self.model6_2(out6)

        return out6_1, out6_2


class handpose_model(nn.Module):
    def __init__(self):
        super(handpose_model, self).__init__()

        # these layers have no relu layer
        no_relu_layers = ['conv6_2_CPM', 'Mconv7_stage2', 'Mconv7_stage3',\
                          'Mconv7_stage4', 'Mconv7_stage5', 'Mconv7_stage6']
        # stage 1
        block1_0 = OrderedDict([
            ('conv1_1', [3, 64, 3, 1, 1]),
            ('conv1_2', [64, 64, 3, 1, 1]),
            ('pool1_stage1', [2, 2, 0]),
            ('conv2_1', [64, 128, 3, 1, 1]),
            ('conv2_2', [128, 128, 3, 1, 1]),
            ('pool2_stage1', [2, 2, 0]),
            ('conv3_1', [128, 256, 3, 1, 1]),
            ('conv3_2', [256, 256, 3, 1, 1]),
            ('conv3_3', [256, 256, 3, 1, 1]),
            ('conv3_4', [256, 256, 3, 1, 1]),
            ('pool3_stage1', [2, 2, 0]),
            ('conv4_1', [256, 512, 3, 1, 1]),
            ('conv4_2', [512, 512, 3, 1, 1]),
            ('conv4_3', [512, 512, 3, 1, 1]),
            ('conv4_4', [512, 512, 3, 1, 1]),
            ('conv5_1', [512, 512, 3, 1, 1]),
            ('conv5_2', [512, 512, 3, 1, 1]),
            ('conv5_3_CPM', [512, 128, 3, 1, 1])
        ])

        block1_1 = OrderedDict([
            ('conv6_1_CPM', [128, 512, 1, 1, 0]),
            ('conv6_2_CPM', [512, 22, 1, 1, 0])
        ])

        blocks = {}
        blocks['block1_0'] = block1_0
        blocks['block1_1'] = block1_1

        # stage 2-6
        for i in range(2, 7):
            blocks['block%d' % i] = OrderedDict([
                ('Mconv1_stage%d' % i, [150, 128, 7, 1, 3]),
                ('Mconv2_stage%d' % i, [128, 128, 7, 1, 3]),
                ('Mconv3_stage%d' % i, [128, 128, 7, 1, 3]),
                ('Mconv4_stage%d' % i, [128, 128, 7, 1, 3]),
                ('Mconv5_stage%d' % i, [128, 128, 7, 1, 3]),
                ('Mconv6_stage%d' % i, [128, 128, 1, 1, 0]),
                ('Mconv7_stage%d' % i, [128, 22, 1, 1, 0])
            ])

        for k in blocks.keys():
            blocks[k] = make_layers(blocks[k], no_relu_layers)

        self.model1_0 = blocks['block1_0']
        self.model1_1 = blocks['block1_1']
        self.model2 = blocks['block2']
        self.model3 = blocks['block3']
        self.model4 = blocks['block4']
        self.model5 = blocks['block5']
        self.model6 = blocks['block6']

    def forward(self, x):
        out1_0 = self.model1_0(x)
        out1_1 = self.model1_1(out1_0)
        concat_stage2 = torch.cat([out1_1, out1_0], 1)
        out_stage2 = self.model2(concat_stage2)
        concat_stage3 = torch.cat([out_stage2, out1_0], 1)
        out_stage3 = self.model3(concat_stage3)
        concat_stage4 = torch.cat([out_stage3, out1_0], 1)
        out_stage4 = self.model4(concat_stage4)
        concat_stage5 = torch.cat([out_stage4, out1_0], 1)
        out_stage5 = self.model5(concat_stage5)
        concat_stage6 = torch.cat([out_stage5, out1_0], 1)
        out_stage6 = self.model6(concat_stage6)
        return out_stage6
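A small sketch of what bodypose_model returns; the channel counts follow the block definitions above, and the 368x368 input mirrors the boxsize used in Body but is otherwise arbitrary.

import torch
from ldm.modules.structure_condition.openpose.model import bodypose_model

net = bodypose_model().eval()
x = torch.randn(1, 3, 368, 368)          # RGB crop; any size divisible by 8 downsamples cleanly
with torch.no_grad():
    pafs, heatmaps = net(x)              # stage-6 outputs (Mconv7_stage6_L1 / _L2)
print(pafs.shape)      # (1, 38, 46, 46): 19 limb part-affinity fields, x and y components
print(heatmaps.shape)  # (1, 19, 46, 46): 18 keypoints + background, at stride 8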
ldm/modules/structure_condition/openpose/util.py
ADDED
@@ -0,0 +1,203 @@
import math

import cv2
import matplotlib
import numpy as np


def padRightDownCorner(img, stride, padValue):
    h = img.shape[0]
    w = img.shape[1]

    pad = 4 * [None]
    pad[0] = 0  # up
    pad[1] = 0  # left
    pad[2] = 0 if (h % stride == 0) else stride - (h % stride)  # down
    pad[3] = 0 if (w % stride == 0) else stride - (w % stride)  # right

    img_padded = img
    pad_up = np.tile(img_padded[0:1, :, :] * 0 + padValue, (pad[0], 1, 1))
    img_padded = np.concatenate((pad_up, img_padded), axis=0)
    pad_left = np.tile(img_padded[:, 0:1, :] * 0 + padValue, (1, pad[1], 1))
    img_padded = np.concatenate((pad_left, img_padded), axis=1)
    pad_down = np.tile(img_padded[-2:-1, :, :] * 0 + padValue, (pad[2], 1, 1))
    img_padded = np.concatenate((img_padded, pad_down), axis=0)
    pad_right = np.tile(img_padded[:, -2:-1, :] * 0 + padValue, (1, pad[3], 1))
    img_padded = np.concatenate((img_padded, pad_right), axis=1)

    return img_padded, pad


# transfer the caffe model to pytorch so the layer names match
def transfer(model, model_weights):
    transfered_model_weights = {}
    for weights_name in model.state_dict().keys():
        transfered_model_weights[weights_name] = model_weights['.'.join(weights_name.split('.')[1:])]
    return transfered_model_weights


# draw the body keypoints and limbs
def draw_bodypose(canvas, candidate, subset):
    stickwidth = 4
    limbSeq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10], \
               [10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17], \
               [1, 16], [16, 18], [3, 17], [6, 18]]

    colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], \
              [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], \
              [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]]
    for i in range(18):
        for n in range(len(subset)):
            index = int(subset[n][i])
            if index == -1:
                continue
            x, y = candidate[index][0:2]
            cv2.circle(canvas, (int(x), int(y)), 4, colors[i], thickness=-1)
    for i in range(17):
        for n in range(len(subset)):
            index = subset[n][np.array(limbSeq[i]) - 1]
            if -1 in index:
                continue
            cur_canvas = canvas.copy()
            Y = candidate[index.astype(int), 0]
            X = candidate[index.astype(int), 1]
            mX = np.mean(X)
            mY = np.mean(Y)
            length = ((X[0] - X[1])**2 + (Y[0] - Y[1])**2)**0.5
            angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
            polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1)
            cv2.fillConvexPoly(cur_canvas, polygon, colors[i])
            canvas = cv2.addWeighted(canvas, 0.4, cur_canvas, 0.6, 0)
    # plt.imsave("preview.jpg", canvas[:, :, [2, 1, 0]])
    # plt.imshow(canvas[:, :, [2, 1, 0]])
    return canvas


# hand keypoints drawn directly with opencv do not look good
def draw_handpose(canvas, all_hand_peaks, show_number=False):
    edges = [[0, 1], [1, 2], [2, 3], [3, 4], [0, 5], [5, 6], [6, 7], [7, 8], [0, 9], [9, 10], \
             [10, 11], [11, 12], [0, 13], [13, 14], [14, 15], [15, 16], [0, 17], [17, 18], [18, 19], [19, 20]]

    for peaks in all_hand_peaks:
        for ie, e in enumerate(edges):
            if np.sum(np.all(peaks[e], axis=1) == 0) == 0:
                x1, y1 = peaks[e[0]]
                x2, y2 = peaks[e[1]]
                cv2.line(
                    canvas, (x1, y1), (x2, y2),
                    matplotlib.colors.hsv_to_rgb([ie / float(len(edges)), 1.0, 1.0]) * 255,
                    thickness=2)

        for i, keypoint in enumerate(peaks):
            x, y = keypoint
            cv2.circle(canvas, (x, y), 4, (0, 0, 255), thickness=-1)
            if show_number:
                cv2.putText(canvas, str(i), (x, y), cv2.FONT_HERSHEY_SIMPLEX, 0.3, (0, 0, 0), lineType=cv2.LINE_AA)
    return canvas


# detect hands from the body pose keypoints
# please refer to https://github.com/CMU-Perceptual-Computing-Lab/openpose/blob/master/src/openpose/hand/handDetector.cpp
def handDetect(candidate, subset, oriImg):
    # right hand: wrist 4, elbow 3, shoulder 2
    # left hand: wrist 7, elbow 6, shoulder 5
    ratioWristElbow = 0.33
    detect_result = []
    image_height, image_width = oriImg.shape[0:2]
    for person in subset.astype(int):
        # skip if none of the three joints of either arm is detected
        has_left = np.sum(person[[5, 6, 7]] == -1) == 0
        has_right = np.sum(person[[2, 3, 4]] == -1) == 0
        if not (has_left or has_right):
            continue
        hands = []
        # left hand
        if has_left:
            left_shoulder_index, left_elbow_index, left_wrist_index = person[[5, 6, 7]]
            x1, y1 = candidate[left_shoulder_index][:2]
            x2, y2 = candidate[left_elbow_index][:2]
            x3, y3 = candidate[left_wrist_index][:2]
            hands.append([x1, y1, x2, y2, x3, y3, True])
        # right hand
        if has_right:
            right_shoulder_index, right_elbow_index, right_wrist_index = person[[2, 3, 4]]
            x1, y1 = candidate[right_shoulder_index][:2]
            x2, y2 = candidate[right_elbow_index][:2]
            x3, y3 = candidate[right_wrist_index][:2]
            hands.append([x1, y1, x2, y2, x3, y3, False])

        for x1, y1, x2, y2, x3, y3, is_left in hands:
            # pos_hand = pos_wrist + ratio * (pos_wrist - pos_elbox) = (1 + ratio) * pos_wrist - ratio * pos_elbox
            # handRectangle.x = posePtr[wrist*3] + ratioWristElbow * (posePtr[wrist*3] - posePtr[elbow*3]);
            # handRectangle.y = posePtr[wrist*3+1] + ratioWristElbow * (posePtr[wrist*3+1] - posePtr[elbow*3+1]);
            # const auto distanceWristElbow = getDistance(poseKeypoints, person, wrist, elbow);
            # const auto distanceElbowShoulder = getDistance(poseKeypoints, person, elbow, shoulder);
            # handRectangle.width = 1.5f * fastMax(distanceWristElbow, 0.9f * distanceElbowShoulder);
            x = x3 + ratioWristElbow * (x3 - x2)
            y = y3 + ratioWristElbow * (y3 - y2)
            distanceWristElbow = math.sqrt((x3 - x2)**2 + (y3 - y2)**2)
            distanceElbowShoulder = math.sqrt((x2 - x1)**2 + (y2 - y1)**2)
            width = 1.5 * max(distanceWristElbow, 0.9 * distanceElbowShoulder)
            # x-y refers to the center --> offset to topLeft point
            # handRectangle.x -= handRectangle.width / 2.f;
            # handRectangle.y -= handRectangle.height / 2.f;
            x -= width / 2
            y -= width / 2  # width = height
            # clamp to the image borders
            if x < 0: x = 0
            if y < 0: y = 0
            width1 = width
            width2 = width
            if x + width > image_width: width1 = image_width - x
            if y + width > image_height: width2 = image_height - y
            width = min(width1, width2)
            # keep only hand boxes that are at least 20 pixels wide
            if width >= 20:
                detect_result.append([int(x), int(y), int(width), is_left])

    '''
    return value: [[x, y, w, True if left hand else False]].
    width=height since the network requires a square input.
    x, y is the coordinate of the top-left corner
    '''
    return detect_result


# get max index of 2d array
def npmax(array):
    arrayindex = array.argmax(1)
    arrayvalue = array.max(1)
    i = arrayvalue.argmax()
    j = arrayindex[i]
    return i, j


def HWC3(x):
    assert x.dtype == np.uint8
    if x.ndim == 2:
        x = x[:, :, None]
    assert x.ndim == 3
    H, W, C = x.shape
    assert C == 1 or C == 3 or C == 4
    if C == 3:
        return x
    if C == 1:
        return np.concatenate([x, x, x], axis=2)
    if C == 4:
        color = x[:, :, 0:3].astype(np.float32)
        alpha = x[:, :, 3:4].astype(np.float32) / 255.0
        y = color * alpha + 255.0 * (1.0 - alpha)
        y = y.clip(0, 255).astype(np.uint8)
        return y


def resize_image(input_image, resolution):
    H, W, C = input_image.shape
    H = float(H)
    W = float(W)
    k = float(resolution) / min(H, W)
    H *= k
    W *= k
    H = int(np.round(H / 64.0)) * 64
    W = int(np.round(W / 64.0)) * 64
    img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA)
    return img
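Finally, a short sketch of the image helpers at the bottom of util.py; input.png is a placeholder path.

import cv2
from ldm.modules.structure_condition.openpose import util

img = cv2.imread('input.png')        # uint8, H x W x C (placeholder image)
img = util.HWC3(img)                 # force 3 channels (grayscale and alpha inputs handled)
img = util.resize_image(img, 512)    # shortest side scaled to ~512, snapped to a multiple of 64
print(img.shape)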
|