owiedotch commited on
Commit
f1ca883
1 Parent(s): 49b322e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -32
app.py CHANGED
@@ -79,18 +79,14 @@ def process(
79
  seed: int,
80
  tile_diffusion: bool,
81
  tile_diffusion_size: int,
82
- tile_diffusion_stride: int,
83
- tile_vae: bool,
84
- vae_encoder_tile_size: int,
85
- vae_decoder_tile_size: int
86
  ):
87
  print(f"control image shape={control_img.size}\n"
88
  f"num_samples={num_samples}, sr_scale={sr_scale}, strength={strength}\n"
89
  f"positive_prompt='{positive_prompt}', negative_prompt='{negative_prompt}'\n"
90
  f"cfg scale={cfg_scale}, steps={steps}, use_color_fix={use_color_fix}\n"
91
  f"seed={seed}\n"
92
- f"tile_diffusion={tile_diffusion}, tile_diffusion_size={tile_diffusion_size}, tile_diffusion_stride={tile_diffusion_stride}"
93
- f"tile_vae={tile_vae}, vae_encoder_tile_size={vae_encoder_tile_size}, vae_decoder_tile_size={vae_decoder_tile_size}")
94
 
95
  pl.seed_everything(seed)
96
 
@@ -127,7 +123,7 @@ def process(
127
  shape = (1, 4, height // 8, width // 8)
128
  x_T = torch.randn(shape, device=device, dtype=torch.float32)
129
 
130
- if not tile_diffusion and not tile_vae:
131
  samples = sampler.sample_ccsr(
132
  steps=steps, t_max=0.6667, t_min=0.3333, shape=shape, cond_img=control,
133
  positive_prompt=positive_prompt, negative_prompt=negative_prompt, x_T=x_T,
@@ -135,23 +131,13 @@ def process(
135
  color_fix_type="adain" if use_color_fix else "none"
136
  )
137
  else:
138
- if tile_vae:
139
- model._init_tiled_vae(encoder_tile_size=vae_encoder_tile_size, decoder_tile_size=vae_decoder_tile_size)
140
- if tile_diffusion:
141
- samples = sampler.sample_with_tile_ccsr(
142
- tile_size=tile_diffusion_size, tile_stride=tile_diffusion_stride,
143
- steps=steps, t_max=0.6667, t_min=0.3333, shape=shape, cond_img=control,
144
- positive_prompt=positive_prompt, negative_prompt=negative_prompt, x_T=x_T,
145
- cfg_scale=cfg_scale,
146
- color_fix_type="adain" if use_color_fix else "none"
147
- )
148
- else:
149
- samples = sampler.sample_ccsr(
150
- steps=steps, t_max=0.6667, t_min=0.3333, shape=shape, cond_img=control,
151
- positive_prompt=positive_prompt, negative_prompt=negative_prompt, x_T=x_T,
152
- cfg_scale=cfg_scale,
153
- color_fix_type="adain" if use_color_fix else "none"
154
- )
155
 
156
  x_samples = samples.clamp(0, 1)
157
  x_samples = (einops.rearrange(x_samples, "b c h w -> b h w c") * 255).cpu().numpy().clip(0, 255).astype(np.uint8)
@@ -161,9 +147,13 @@ def process(
161
 
162
  return preds
163
 
164
- def update_output_resolution(image, scale):
165
  if image is not None:
166
  width, height = image.size
 
 
 
 
167
  return f"Current resolution: {width}x{height}. Output resolution: {int(width*scale)}x{int(height*scale)}"
168
  return "Upload an image to see the output resolution"
169
 
@@ -233,9 +223,6 @@ with gr.Blocks(css=css) as block:
233
  tile_diffusion = gr.Checkbox(label="Tile diffusion", value=False)
234
  tile_diffusion_size = gr.Slider(label="Tile diffusion size", minimum=512, maximum=1024, value=512, step=256)
235
  tile_diffusion_stride = gr.Slider(label="Tile diffusion stride", minimum=256, maximum=512, value=256, step=128)
236
- tile_vae = gr.Checkbox(label="Tile VAE", value=True)
237
- vae_encoder_tile_size = gr.Slider(label="Encoder tile size", minimum=512, maximum=5000, value=1024, step=256)
238
- vae_decoder_tile_size = gr.Slider(label="Decoder tile size", minimum=64, maximum=512, value=224, step=128)
239
 
240
  with gr.Row():
241
  result_gallery = gr.Gallery(label="Output", show_label=False, elem_id="gallery", elem_classes="output-gallery")
@@ -253,10 +240,10 @@ with gr.Blocks(css=css) as block:
253
  inputs = [
254
  input_image, num_samples, sr_scale, strength, positive_prompt, negative_prompt,
255
  cfg_scale, steps, use_color_fix, seed, tile_diffusion, tile_diffusion_size,
256
- tile_diffusion_stride, tile_vae, vae_encoder_tile_size, vae_decoder_tile_size,
257
  ]
258
  run_button.click(
259
- fn=lambda *args: process(*args[:1], args[1], get_scale_value(args[2], args[-1]), *args[3:-1]),
260
  inputs=inputs + [custom_scale],
261
  outputs=[result_gallery]
262
  )
@@ -269,12 +256,17 @@ with gr.Blocks(css=css) as block:
269
 
270
  input_image.change(
271
  update_output_resolution,
272
- inputs=[input_image, sr_scale],
273
  outputs=[output_resolution]
274
  )
275
  sr_scale.change(
276
  update_output_resolution,
277
- inputs=[input_image, sr_scale],
 
 
 
 
 
278
  outputs=[output_resolution]
279
  )
280
 
 
79
  seed: int,
80
  tile_diffusion: bool,
81
  tile_diffusion_size: int,
82
+ tile_diffusion_stride: int
 
 
 
83
  ):
84
  print(f"control image shape={control_img.size}\n"
85
  f"num_samples={num_samples}, sr_scale={sr_scale}, strength={strength}\n"
86
  f"positive_prompt='{positive_prompt}', negative_prompt='{negative_prompt}'\n"
87
  f"cfg scale={cfg_scale}, steps={steps}, use_color_fix={use_color_fix}\n"
88
  f"seed={seed}\n"
89
+ f"tile_diffusion={tile_diffusion}, tile_diffusion_size={tile_diffusion_size}, tile_diffusion_stride={tile_diffusion_stride}")
 
90
 
91
  pl.seed_everything(seed)
92
 
 
123
  shape = (1, 4, height // 8, width // 8)
124
  x_T = torch.randn(shape, device=device, dtype=torch.float32)
125
 
126
+ if not tile_diffusion:
127
  samples = sampler.sample_ccsr(
128
  steps=steps, t_max=0.6667, t_min=0.3333, shape=shape, cond_img=control,
129
  positive_prompt=positive_prompt, negative_prompt=negative_prompt, x_T=x_T,
 
131
  color_fix_type="adain" if use_color_fix else "none"
132
  )
133
  else:
134
+ samples = sampler.sample_with_tile_ccsr(
135
+ tile_size=tile_diffusion_size, tile_stride=tile_diffusion_stride,
136
+ steps=steps, t_max=0.6667, t_min=0.3333, shape=shape, cond_img=control,
137
+ positive_prompt=positive_prompt, negative_prompt=negative_prompt, x_T=x_T,
138
+ cfg_scale=cfg_scale,
139
+ color_fix_type="adain" if use_color_fix else "none"
140
+ )
 
 
 
 
 
 
 
 
 
 
141
 
142
  x_samples = samples.clamp(0, 1)
143
  x_samples = (einops.rearrange(x_samples, "b c h w -> b h w c") * 255).cpu().numpy().clip(0, 255).astype(np.uint8)
 
147
 
148
  return preds
149
 
150
+ def update_output_resolution(image, scale_choice, custom_scale):
151
  if image is not None:
152
  width, height = image.size
153
+ if scale_choice == "Custom":
154
+ scale = custom_scale
155
+ else:
156
+ scale = float(scale_choice.split()[-1].strip("()x"))
157
  return f"Current resolution: {width}x{height}. Output resolution: {int(width*scale)}x{int(height*scale)}"
158
  return "Upload an image to see the output resolution"
159
 
 
223
  tile_diffusion = gr.Checkbox(label="Tile diffusion", value=False)
224
  tile_diffusion_size = gr.Slider(label="Tile diffusion size", minimum=512, maximum=1024, value=512, step=256)
225
  tile_diffusion_stride = gr.Slider(label="Tile diffusion stride", minimum=256, maximum=512, value=256, step=128)
 
 
 
226
 
227
  with gr.Row():
228
  result_gallery = gr.Gallery(label="Output", show_label=False, elem_id="gallery", elem_classes="output-gallery")
 
240
  inputs = [
241
  input_image, num_samples, sr_scale, strength, positive_prompt, negative_prompt,
242
  cfg_scale, steps, use_color_fix, seed, tile_diffusion, tile_diffusion_size,
243
+ tile_diffusion_stride
244
  ]
245
  run_button.click(
246
+ fn=lambda *args: process(*args[:1], args[1], get_scale_value(args[2], args[-1]), *args[3:]),
247
  inputs=inputs + [custom_scale],
248
  outputs=[result_gallery]
249
  )
 
256
 
257
  input_image.change(
258
  update_output_resolution,
259
+ inputs=[input_image, sr_scale, custom_scale],
260
  outputs=[output_resolution]
261
  )
262
  sr_scale.change(
263
  update_output_resolution,
264
+ inputs=[input_image, sr_scale, custom_scale],
265
+ outputs=[output_resolution]
266
+ )
267
+ custom_scale.change(
268
+ update_output_resolution,
269
+ inputs=[input_image, sr_scale, custom_scale],
270
  outputs=[output_resolution]
271
  )
272