ChenyangSi commited on
Commit
713ec7d
·
1 Parent(s): 453154b

Upload 3 files

Browse files
Files changed (3) hide show
  1. __init__.py +1 -0
  2. app.py +162 -0
  3. free_lunch_utils.py +304 -0
__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from free_lunch_utils import register_upblock2d, register_free_upblock2d, register_crossattn_upblock2d, register_free_crossattn_upblock2d
app.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # import argparse, os, sys, glob
2
+ # sys.path.append(os.path.split(sys.path[0])[0])
3
+
4
+ from diffusers import StableDiffusionPipeline
5
+ import torch
6
+ from free_lunch_utils import register_free_upblock2d, register_free_crossattn_upblock2d
7
+
8
+ import gradio as gr
9
+ from PIL import Image
10
+ import torch
11
+ from muse import PipelineMuse
12
+ from diffusers import AutoPipelineForText2Image, UniPCMultistepScheduler
13
+
14
+
15
+ if sd_options == 'SD1.4':
16
+ model_id = "CompVis/stable-diffusion-v1-4"
17
+ elif sd_options == 'SD1.5':
18
+ model_id = "runwayml/stable-diffusion-v1-5"
19
+ elif sd_options == 'SD2.1':
20
+ model_id = "stabilityai/stable-diffusion-2-1"
21
+
22
+ pip_sd = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
23
+ pip_sd = pip_sd.to("cuda")
24
+
25
+
26
+ pip_freeu = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
27
+ pip_freeu = pip_freeu.to("cuda")
28
+ # -------- freeu block registration
29
+ register_free_upblock2d(pipe, b1=1.2, b2=1.4, s1=0.9, s2=0.2)
30
+ register_free_crossattn_upblock2d(pipe, b1=1.2, b2=1.4, s1=0.9, s2=0.2)
31
+ # -------- freeu block registration
32
+
33
+
34
+ def infer(prompt):
35
+
36
+ print("Generating SD:")
37
+ sd_image = pip_sd(prompt).images[0]
38
+
39
+ print("Generating FreeU:")
40
+ freeu_image = pip_freeu(prompt).images[0]
41
+
42
+ # First SD, then freeu
43
+ images = [sd_image, freeu_image]
44
+
45
+ return images
46
+
47
+
48
+ examples = [
49
+ [
50
+ "A small cabin on top of a snowy mountain in the style of Disney, artstation",
51
+ ],
52
+ [
53
+ "a monkey doing yoga on the beach",
54
+ ],
55
+ [
56
+ "half human half cat, a human cat hybrid",
57
+ ],
58
+ [
59
+ "a hedgehog using a calculator",
60
+ ],
61
+ [
62
+ "kanye west | diffuse lighting | fantasy | intricate elegant highly detailed lifelike photorealistic digital painting | artstation",
63
+ ],
64
+ [
65
+ "astronaut pig",
66
+ ],
67
+ [
68
+ "two people shouting at each other",
69
+ ],
70
+ [
71
+ "A linked in profile picture of Elon Musk",
72
+ ],
73
+ [
74
+ "A man looking out of a rainy window",
75
+ ],
76
+ [
77
+ "close up, iron man, eating breakfast in a cabin, symmetrical balance, hyper-realistic --ar 16:9 --style raw"
78
+ ],
79
+ [
80
+ 'A high tech solarpunk utopia in the Amazon rainforest',
81
+ ],
82
+ [
83
+ 'A pikachu fine dining with a view to the Eiffel Tower',
84
+ ],
85
+ [
86
+ 'A mecha robot in a favela in expressionist style',
87
+ ],
88
+ [
89
+ 'an insect robot preparing a delicious meal',
90
+ ],
91
+ ]
92
+
93
+
94
+ css = """
95
+ h1 {
96
+ text-align: center;
97
+ }
98
+
99
+ #component-0 {
100
+ max-width: 730px;
101
+ margin: auto;
102
+ }
103
+ """
104
+
105
+ block = gr.Blocks(css=css)
106
+
107
+ options = ['SD1.4', 'SD1.5', 'SD2.1']
108
+
109
+ with block:
110
+ gr.Markdown("SD vs. FreeU.")
111
+ with gr.Group():
112
+ with gr.Row(elem_id="prompt-container").style(mobile_collapse=False, equal_height=True):
113
+ with gr.Column():
114
+ text = gr.Textbox(
115
+ label="Enter your prompt",
116
+ show_label=False,
117
+ max_lines=1,
118
+ placeholder="Enter your prompt",
119
+ container=False,
120
+ )
121
+ btn = gr.Button("Generate image", scale=0)
122
+
123
+ with gr.Accordion('FreeU Parameters', open=False):
124
+ sd_options = gr.Dropdown(options, label="SD options")
125
+
126
+ b1 = gr.Slider(label='b1: backbone factor of the first stage block of decoder',
127
+ minimum=1,
128
+ maximum=1.6,
129
+ step=0.01,
130
+ value=1)
131
+ b2 = gr.Slider(label='b2: backbone factor of the second stage block of decoder',
132
+ minimum=1,
133
+ maximum=1.6,
134
+ step=0.01,
135
+ value=1)
136
+ s1 = gr.Slider(label='s1: skip factor of the first stage block of decoder',
137
+ minimum=0,
138
+ maximum=1,
139
+ step=0.1,
140
+ value=1)
141
+ s2 = gr.Slider(label='s2: skip factor of the second stage block of decoder',
142
+ minimum=0,
143
+ maximum=1,
144
+ step=0.1,
145
+ value=1)
146
+
147
+ with gr.Row():
148
+ with gr.Column(min_width=256) as c1:
149
+ image_1 = gr.Image(interactive=False)
150
+ image_1_label = gr.Markdown("SD")
151
+ with gr.Column(min_width=256) as c2:
152
+ image_2 = gr.Image(interactive=False)
153
+ image_2_label = gr.Markdown("FreeU")
154
+
155
+
156
+ ex = gr.Examples(examples=examples, fn=infer, inputs=[text], outputs=[image_1, image_2], cache_examples=False)
157
+ ex.dataset.headers = [""]
158
+
159
+ text.submit(infer, inputs=[text], outputs=[image_1, image_2])
160
+ btn.click(infer, inputs=[text], outputs=[image_1, image_2])
161
+
162
+ block.launch()
free_lunch_utils.py ADDED
@@ -0,0 +1,304 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.fft as fft
3
+ from diffusers.models.unet_2d_condition import logger
4
+ from diffusers.utils import is_torch_version
5
+ from typing import Any, Dict, List, Optional, Tuple, Union
6
+
7
+
8
+ def isinstance_str(x: object, cls_name: str):
9
+ """
10
+ Checks whether x has any class *named* cls_name in its ancestry.
11
+ Doesn't require access to the class's implementation.
12
+
13
+ Useful for patching!
14
+ """
15
+
16
+ for _cls in x.__class__.__mro__:
17
+ if _cls.__name__ == cls_name:
18
+ return True
19
+
20
+ return False
21
+
22
+
23
+ def Fourier_filter(x, threshold, scale):
24
+ dtype = x.dtype
25
+ x = x.type(torch.float32)
26
+ # FFT
27
+ x_freq = fft.fftn(x, dim=(-2, -1))
28
+ x_freq = fft.fftshift(x_freq, dim=(-2, -1))
29
+
30
+ B, C, H, W = x_freq.shape
31
+ mask = torch.ones((B, C, H, W)).cuda()
32
+
33
+ crow, ccol = H // 2, W //2
34
+ mask[..., crow - threshold:crow + threshold, ccol - threshold:ccol + threshold] = scale
35
+ x_freq = x_freq * mask
36
+
37
+ # IFFT
38
+ x_freq = fft.ifftshift(x_freq, dim=(-2, -1))
39
+ x_filtered = fft.ifftn(x_freq, dim=(-2, -1)).real
40
+
41
+ x_filtered = x_filtered.type(dtype)
42
+ return x_filtered
43
+
44
+
45
+ def register_upblock2d(model):
46
+ def up_forward(self):
47
+ def forward(hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
48
+ for resnet in self.resnets:
49
+ # pop res hidden states
50
+ res_hidden_states = res_hidden_states_tuple[-1]
51
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
52
+ #print(f"in upblock2d, hidden states shape: {hidden_states.shape}")
53
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
54
+
55
+ if self.training and self.gradient_checkpointing:
56
+
57
+ def create_custom_forward(module):
58
+ def custom_forward(*inputs):
59
+ return module(*inputs)
60
+
61
+ return custom_forward
62
+
63
+ if is_torch_version(">=", "1.11.0"):
64
+ hidden_states = torch.utils.checkpoint.checkpoint(
65
+ create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
66
+ )
67
+ else:
68
+ hidden_states = torch.utils.checkpoint.checkpoint(
69
+ create_custom_forward(resnet), hidden_states, temb
70
+ )
71
+ else:
72
+ hidden_states = resnet(hidden_states, temb)
73
+
74
+ if self.upsamplers is not None:
75
+ for upsampler in self.upsamplers:
76
+ hidden_states = upsampler(hidden_states, upsample_size)
77
+
78
+ return hidden_states
79
+
80
+ return forward
81
+
82
+ for i, upsample_block in enumerate(model.unet.up_blocks):
83
+ if isinstance_str(upsample_block, "UpBlock2D"):
84
+ upsample_block.forward = up_forward(upsample_block)
85
+
86
+
87
+ def register_free_upblock2d(model, b1=1.2, b2=1.4, s1=0.9, s2=0.2):
88
+ def up_forward(self):
89
+ def forward(hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
90
+ for resnet in self.resnets:
91
+ # pop res hidden states
92
+ res_hidden_states = res_hidden_states_tuple[-1]
93
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
94
+ #print(f"in free upblock2d, hidden states shape: {hidden_states.shape}")
95
+
96
+ # --------------- FreeU code -----------------------
97
+ # Only operate on the first two stages
98
+ if hidden_states.shape[1] == 1280:
99
+ hidden_states[:,:640] = hidden_states[:,:640] * self.b1
100
+ res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s1)
101
+ if hidden_states.shape[1] == 640:
102
+ hidden_states[:,:320] = hidden_states[:,:320] * self.b2
103
+ res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s2)
104
+ # ---------------------------------------------------------
105
+
106
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
107
+
108
+ if self.training and self.gradient_checkpointing:
109
+
110
+ def create_custom_forward(module):
111
+ def custom_forward(*inputs):
112
+ return module(*inputs)
113
+
114
+ return custom_forward
115
+
116
+ if is_torch_version(">=", "1.11.0"):
117
+ hidden_states = torch.utils.checkpoint.checkpoint(
118
+ create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
119
+ )
120
+ else:
121
+ hidden_states = torch.utils.checkpoint.checkpoint(
122
+ create_custom_forward(resnet), hidden_states, temb
123
+ )
124
+ else:
125
+ hidden_states = resnet(hidden_states, temb)
126
+
127
+ if self.upsamplers is not None:
128
+ for upsampler in self.upsamplers:
129
+ hidden_states = upsampler(hidden_states, upsample_size)
130
+
131
+ return hidden_states
132
+
133
+ return forward
134
+
135
+ for i, upsample_block in enumerate(model.unet.up_blocks):
136
+ if isinstance_str(upsample_block, "UpBlock2D"):
137
+ upsample_block.forward = up_forward(upsample_block)
138
+ setattr(upsample_block, 'b1', b1)
139
+ setattr(upsample_block, 'b2', b2)
140
+ setattr(upsample_block, 's1', s1)
141
+ setattr(upsample_block, 's2', s2)
142
+
143
+
144
+ def register_crossattn_upblock2d(model):
145
+ def up_forward(self):
146
+ def forward(
147
+ hidden_states: torch.FloatTensor,
148
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
149
+ temb: Optional[torch.FloatTensor] = None,
150
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
151
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
152
+ upsample_size: Optional[int] = None,
153
+ attention_mask: Optional[torch.FloatTensor] = None,
154
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
155
+ ):
156
+ for resnet, attn in zip(self.resnets, self.attentions):
157
+ # pop res hidden states
158
+ #print(f"in crossatten upblock2d, hidden states shape: {hidden_states.shape}")
159
+ res_hidden_states = res_hidden_states_tuple[-1]
160
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
161
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
162
+
163
+ if self.training and self.gradient_checkpointing:
164
+
165
+ def create_custom_forward(module, return_dict=None):
166
+ def custom_forward(*inputs):
167
+ if return_dict is not None:
168
+ return module(*inputs, return_dict=return_dict)
169
+ else:
170
+ return module(*inputs)
171
+
172
+ return custom_forward
173
+
174
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
175
+ hidden_states = torch.utils.checkpoint.checkpoint(
176
+ create_custom_forward(resnet),
177
+ hidden_states,
178
+ temb,
179
+ **ckpt_kwargs,
180
+ )
181
+ hidden_states = torch.utils.checkpoint.checkpoint(
182
+ create_custom_forward(attn, return_dict=False),
183
+ hidden_states,
184
+ encoder_hidden_states,
185
+ None, # timestep
186
+ None, # class_labels
187
+ cross_attention_kwargs,
188
+ attention_mask,
189
+ encoder_attention_mask,
190
+ **ckpt_kwargs,
191
+ )[0]
192
+ else:
193
+ hidden_states = resnet(hidden_states, temb)
194
+ hidden_states = attn(
195
+ hidden_states,
196
+ encoder_hidden_states=encoder_hidden_states,
197
+ cross_attention_kwargs=cross_attention_kwargs,
198
+ attention_mask=attention_mask,
199
+ encoder_attention_mask=encoder_attention_mask,
200
+ return_dict=False,
201
+ )[0]
202
+
203
+ if self.upsamplers is not None:
204
+ for upsampler in self.upsamplers:
205
+ hidden_states = upsampler(hidden_states, upsample_size)
206
+
207
+ return hidden_states
208
+
209
+ return forward
210
+
211
+ for i, upsample_block in enumerate(model.unet.up_blocks):
212
+ if isinstance_str(upsample_block, "CrossAttnUpBlock2D"):
213
+ upsample_block.forward = up_forward(upsample_block)
214
+
215
+
216
+ def register_free_crossattn_upblock2d(model, b1=1.2, b2=1.4, s1=0.9, s2=0.2):
217
+ def up_forward(self):
218
+ def forward(
219
+ hidden_states: torch.FloatTensor,
220
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
221
+ temb: Optional[torch.FloatTensor] = None,
222
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
223
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
224
+ upsample_size: Optional[int] = None,
225
+ attention_mask: Optional[torch.FloatTensor] = None,
226
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
227
+ ):
228
+ for resnet, attn in zip(self.resnets, self.attentions):
229
+ # pop res hidden states
230
+ #print(f"in free crossatten upblock2d, hidden states shape: {hidden_states.shape}")
231
+ res_hidden_states = res_hidden_states_tuple[-1]
232
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
233
+
234
+ # --------------- FreeU code -----------------------
235
+ # Only operate on the first two stages
236
+ if hidden_states.shape[1] == 1280:
237
+ hidden_states[:,:640] = hidden_states[:,:640] * self.b1
238
+ res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s1)
239
+ if hidden_states.shape[1] == 640:
240
+ hidden_states[:,:320] = hidden_states[:,:320] * self.b2
241
+ res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s2)
242
+ # ---------------------------------------------------------
243
+
244
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
245
+
246
+ if self.training and self.gradient_checkpointing:
247
+
248
+ def create_custom_forward(module, return_dict=None):
249
+ def custom_forward(*inputs):
250
+ if return_dict is not None:
251
+ return module(*inputs, return_dict=return_dict)
252
+ else:
253
+ return module(*inputs)
254
+
255
+ return custom_forward
256
+
257
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
258
+ hidden_states = torch.utils.checkpoint.checkpoint(
259
+ create_custom_forward(resnet),
260
+ hidden_states,
261
+ temb,
262
+ **ckpt_kwargs,
263
+ )
264
+ hidden_states = torch.utils.checkpoint.checkpoint(
265
+ create_custom_forward(attn, return_dict=False),
266
+ hidden_states,
267
+ encoder_hidden_states,
268
+ None, # timestep
269
+ None, # class_labels
270
+ cross_attention_kwargs,
271
+ attention_mask,
272
+ encoder_attention_mask,
273
+ **ckpt_kwargs,
274
+ )[0]
275
+ else:
276
+ hidden_states = resnet(hidden_states, temb)
277
+ # hidden_states = attn(
278
+ # hidden_states,
279
+ # encoder_hidden_states=encoder_hidden_states,
280
+ # cross_attention_kwargs=cross_attention_kwargs,
281
+ # encoder_attention_mask=encoder_attention_mask,
282
+ # return_dict=False,
283
+ # )[0]
284
+ hidden_states = attn(
285
+ hidden_states,
286
+ encoder_hidden_states=encoder_hidden_states,
287
+ cross_attention_kwargs=cross_attention_kwargs,
288
+ )[0]
289
+
290
+ if self.upsamplers is not None:
291
+ for upsampler in self.upsamplers:
292
+ hidden_states = upsampler(hidden_states, upsample_size)
293
+
294
+ return hidden_states
295
+
296
+ return forward
297
+
298
+ for i, upsample_block in enumerate(model.unet.up_blocks):
299
+ if isinstance_str(upsample_block, "CrossAttnUpBlock2D"):
300
+ upsample_block.forward = up_forward(upsample_block)
301
+ setattr(upsample_block, 'b1', b1)
302
+ setattr(upsample_block, 'b2', b2)
303
+ setattr(upsample_block, 's1', s1)
304
+ setattr(upsample_block, 's2', s2)