owiedotch committed on
Commit 49b322e
1 Parent(s): 0582ce0

Update app.py

Files changed (1)
  1. app.py +79 -32
app.py CHANGED
@@ -61,7 +61,6 @@ ckpt = torch.load("weights/real-world_ccsr.ckpt", map_location="cpu")
 load_state_dict(model, ckpt, strict=True)
 model.freeze()
 
-# Check if CUDA is available, otherwise use CPU
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 model.to(device)
 
@@ -85,27 +84,26 @@ def process(
     vae_encoder_tile_size: int,
     vae_decoder_tile_size: int
 ):
-    print(
-        f"control image shape={control_img.size}\n"
-        f"num_samples={num_samples}, sr_scale={sr_scale}, strength={strength}\n"
-        f"positive_prompt='{positive_prompt}', negative_prompt='{negative_prompt}'\n"
-        f"cdf scale={cfg_scale}, steps={steps}, use_color_fix={use_color_fix}\n"
-        f"seed={seed}\n"
-        f"tile_diffusion={tile_diffusion}, tile_diffusion_size={tile_diffusion_size}, tile_diffusion_stride={tile_diffusion_stride}"
-        f"tile_vae={tile_vae}, vae_encoder_tile_size={vae_encoder_tile_size}, vae_decoder_tile_size={vae_decoder_tile_size}"
-    )
+    print(f"control image shape={control_img.size}\n"
+          f"num_samples={num_samples}, sr_scale={sr_scale}, strength={strength}\n"
+          f"positive_prompt='{positive_prompt}', negative_prompt='{negative_prompt}'\n"
+          f"cfg scale={cfg_scale}, steps={steps}, use_color_fix={use_color_fix}\n"
+          f"seed={seed}\n"
+          f"tile_diffusion={tile_diffusion}, tile_diffusion_size={tile_diffusion_size}, tile_diffusion_stride={tile_diffusion_stride}"
+          f"tile_vae={tile_vae}, vae_encoder_tile_size={vae_encoder_tile_size}, vae_decoder_tile_size={vae_decoder_tile_size}")
+
     pl.seed_everything(seed)
 
-    # Resize lr
+    # Resize input image
     if sr_scale != 1:
         control_img = control_img.resize(
             tuple(math.ceil(x * sr_scale) for x in control_img.size),
             Image.BICUBIC
         )
-
+
     input_size = control_img.size
 
-    # Resize the lr image
+    # Resize the image
     if not tile_diffusion:
         control_img = auto_resize(control_img, 512)
     else:
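A quick sanity check of the pre-resize step in this hunk (a minimal sketch with a stand-in image; `auto_resize` is the app's own helper and is not reproduced here): each axis is scaled by `sr_scale` and ceil-rounded before diffusion.

```python
import math
from PIL import Image

img = Image.new("RGB", (320, 240))  # stand-in for an uploaded control image
sr_scale = 4
upscaled = img.resize(
    tuple(math.ceil(x * sr_scale) for x in img.size),  # per-axis ceil, as in process()
    Image.BICUBIC
)
print(upscaled.size)  # (1280, 960)
```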
@@ -129,39 +127,28 @@ def process(
     shape = (1, 4, height // 8, width // 8)
     x_T = torch.randn(shape, device=device, dtype=torch.float32)
 
-    # Modify the get_learned_conditioning method to handle the attention mask issue
-    def modified_get_learned_conditioning(model, prompt):
-        tokens = model.cond_stage_model.tokenizer.encode(prompt)
-        tokens = torch.LongTensor(tokens).to(model.device).unsqueeze(0)
-        encoder_hidden_states = model.cond_stage_model.transformer(input_ids=tokens).last_hidden_state
-        return encoder_hidden_states
-
-    cond = modified_get_learned_conditioning(model, positive_prompt)
-    uncond = modified_get_learned_conditioning(model, negative_prompt)
-
     if not tile_diffusion and not tile_vae:
         samples = sampler.sample_ccsr(
             steps=steps, t_max=0.6667, t_min=0.3333, shape=shape, cond_img=control,
-            positive_prompt=cond, negative_prompt=uncond, x_T=x_T,
+            positive_prompt=positive_prompt, negative_prompt=negative_prompt, x_T=x_T,
             cfg_scale=cfg_scale,
             color_fix_type="adain" if use_color_fix else "none"
         )
     else:
         if tile_vae:
-            # Note: Tiled VAE is not implemented in this version
-            pass
+            model._init_tiled_vae(encoder_tile_size=vae_encoder_tile_size, decoder_tile_size=vae_decoder_tile_size)
         if tile_diffusion:
             samples = sampler.sample_with_tile_ccsr(
                 tile_size=tile_diffusion_size, tile_stride=tile_diffusion_stride,
                 steps=steps, t_max=0.6667, t_min=0.3333, shape=shape, cond_img=control,
-                positive_prompt=cond, negative_prompt=uncond, x_T=x_T,
+                positive_prompt=positive_prompt, negative_prompt=negative_prompt, x_T=x_T,
                 cfg_scale=cfg_scale,
                 color_fix_type="adain" if use_color_fix else "none"
             )
         else:
             samples = sampler.sample_ccsr(
                 steps=steps, t_max=0.6667, t_min=0.3333, shape=shape, cond_img=control,
-                positive_prompt=cond, negative_prompt=uncond, x_T=x_T,
+                positive_prompt=positive_prompt, negative_prompt=negative_prompt, x_T=x_T,
                 cfg_scale=cfg_scale,
                 color_fix_type="adain" if use_color_fix else "none"
             )
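Two behavioral changes land in this hunk: the sampler now receives the raw prompt strings (the `modified_get_learned_conditioning` workaround is removed, so the sampler's own conditioning path does the text encoding), and `tile_vae` now actually initializes tiled VAE processing via `model._init_tiled_vae` instead of silently passing. The diff does not show `_init_tiled_vae`'s internals; the sketch below is only the general idea behind tiled VAE decoding, processing the latent in overlapping tiles so peak VRAM stays bounded, then averaging the overlaps. Here `vae_decode` is a hypothetical callable mapping a latent tile to an image tile at 8x resolution, and real implementations use feathered blending rather than a plain mean.

```python
import torch

def tiled_decode(vae_decode, z: torch.Tensor, tile: int = 64, overlap: int = 8) -> torch.Tensor:
    """Decode a latent z of shape (B, C, H, W) tile by tile to bound peak memory."""
    B, _, H, W = z.shape
    stride = tile - overlap
    out = weight = None
    for top in range(0, H, stride):
        for left in range(0, W, stride):
            bottom, right = min(top + tile, H), min(left + tile, W)
            img = vae_decode(z[:, :, top:bottom, left:right])  # latent tile -> image tile at 8x
            if out is None:
                out = z.new_zeros(B, img.shape[1], H * 8, W * 8)
                weight = z.new_zeros(1, 1, H * 8, W * 8)
            out[:, :, top * 8:bottom * 8, left * 8:right * 8] += img
            weight[:, :, top * 8:bottom * 8, left * 8:right * 8] += 1
    return out / weight  # plain averaging over the overlap regions
```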
@@ -180,12 +167,31 @@ def update_output_resolution(image, scale):
         return f"Current resolution: {width}x{height}. Output resolution: {int(width*scale)}x{int(height*scale)}"
     return "Upload an image to see the output resolution"
 
+def update_scale_choices(image):
+    if image is not None:
+        width, height = image.size
+        aspect_ratio = width / height
+        common_resolutions = [
+            (1280, 720), (1920, 1080), (2560, 1440), (3840, 2160),  # 16:9
+            (1440, 1440), (2048, 2048), (2560, 2560), (3840, 3840)  # 1:1
+        ]
+        choices = []
+        for w, h in common_resolutions:
+            if abs(w / h - aspect_ratio) < 0.1:  # allow some tolerance for aspect ratio
+                scale = max(w / width, h / height)
+                if scale > 1:
+                    choices.append(f"{w}x{h} ({scale:.2f}x)")
+        choices.append("Custom")
+        return gr.update(choices=choices, value=choices[1] if len(choices) > 1 else "Custom")
+    return gr.update(choices=["Custom"], value="Custom")
+
 # Improved UI design
 css = """
 .container {max-width: 1200px; margin: auto; padding: 20px;}
 .input-image {width: 100%; max-height: 500px; object-fit: contain;}
 .output-gallery {display: flex; flex-wrap: wrap; justify-content: center;}
 .output-image {margin: 10px; max-width: 45%; height: auto;}
+.gr-form {border: 1px solid #e0e0e0; border-radius: 8px; padding: 16px; margin-bottom: 16px;}
 """
 
 with gr.Blocks(css=css) as block:
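For intuition on how `update_scale_choices` populates the new dropdown (a sketch assuming the app's namespace, where `gr.update` returns an update dict): a 1280x720 upload matches only the 16:9 presets strictly larger than itself, since the 1:1 presets fail the aspect-ratio tolerance and the 1280x720 preset yields a scale of exactly 1.

```python
from PIL import Image

img = Image.new("RGB", (1280, 720))  # 16:9 stand-in upload
# Presets with |w/h - 16/9| < 0.1 and scale > 1 survive, so choices become:
# ["1920x1080 (1.50x)", "2560x1440 (2.00x)", "3840x2160 (3.00x)", "Custom"]
print(update_scale_choices(img))
```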
@@ -194,7 +200,20 @@ with gr.Blocks(css=css) as block:
     with gr.Row():
         with gr.Column(scale=1):
             input_image = gr.Image(type="pil", label="Input Image", elem_classes="input-image")
-            sr_scale = gr.Slider(label="SR Scale", minimum=1, maximum=8, value=4, step=0.1, info="Super-resolution scale factor.")
+            sr_scale = gr.Dropdown(
+                label="Output Resolution",
+                choices=["Custom"],
+                value="Custom",
+                interactive=True
+            )
+            custom_scale = gr.Slider(
+                label="Custom Scale",
+                minimum=1,
+                maximum=8,
+                value=4,
+                step=0.1,
+                visible=True
+            )
             output_resolution = gr.Markdown("Upload an image to see the output resolution")
             run_button = gr.Button(value="Run", variant="primary")
@@ -221,15 +240,43 @@ with gr.Blocks(css=css) as block:
     with gr.Row():
         result_gallery = gr.Gallery(label="Output", show_label=False, elem_id="gallery", elem_classes="output-gallery")
 
+    def update_custom_scale(choice):
+        return gr.update(visible=choice == "Custom")
+
+    sr_scale.change(update_custom_scale, inputs=[sr_scale], outputs=[custom_scale])
+
+    def get_scale_value(choice, custom):
+        if choice == "Custom":
+            return custom
+        return float(choice.split()[-1].strip("()x"))
+
     inputs = [
         input_image, num_samples, sr_scale, strength, positive_prompt, negative_prompt,
         cfg_scale, steps, use_color_fix, seed, tile_diffusion, tile_diffusion_size,
         tile_diffusion_stride, tile_vae, vae_encoder_tile_size, vae_decoder_tile_size,
     ]
-    run_button.click(fn=process, inputs=inputs, outputs=[result_gallery])
+    run_button.click(
+        fn=lambda *args: process(*args[:1], args[1], get_scale_value(args[2], args[-1]), *args[3:-1]),
+        inputs=inputs + [custom_scale],
+        outputs=[result_gallery]
+    )
 
-    input_image.change(update_output_resolution, inputs=[input_image, sr_scale], outputs=[output_resolution])
-    sr_scale.change(update_output_resolution, inputs=[input_image, sr_scale], outputs=[output_resolution])
+    input_image.change(
+        update_scale_choices,
+        inputs=[input_image],
+        outputs=[sr_scale]
+    )
+
+    input_image.change(
+        update_output_resolution,
+        inputs=[input_image, sr_scale],
+        outputs=[output_resolution]
+    )
+    sr_scale.change(
+        update_output_resolution,
+        inputs=[input_image, sr_scale],
+        outputs=[output_resolution]
+    )
 
     input_image.change(
         lambda x: gr.update(interactive=x is not None),
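The `run_button.click` lambda splices the slider value into the `process` call: `custom_scale` is appended as the last input, `get_scale_value` resolves the dropdown string to a float (e.g. "1920x1080 (1.50x)" parses to 1.5, while "Custom" falls back to the slider value) at position 2, and the trailing slider value is dropped from the remaining arguments. A more explicit equivalent of that lambda (hypothetical name, same behavior):

```python
def run_process(image, num_samples, scale_choice, *rest_and_custom):
    *rest, custom = rest_and_custom             # custom_scale rides along as the last input
    sr = get_scale_value(scale_choice, custom)  # "1920x1080 (1.50x)" -> 1.5; "Custom" -> slider value
    return process(image, num_samples, sr, *rest)

# run_button.click(fn=run_process, inputs=inputs + [custom_scale], outputs=[result_gallery])
```

Note that `update_output_resolution` is still wired to the dropdown directly, so it receives the choice string rather than a float; routing its `scale` argument through `get_scale_value` in the same way would keep the `int(width*scale)` arithmetic in the resolution preview valid.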
 