owiedotch committed on
Commit
10d6431
1 Parent(s): fa33488

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -13
app.py CHANGED
@@ -8,7 +8,6 @@ from omegaconf import OmegaConf
8
  import subprocess
9
  from tqdm import tqdm
10
  import requests
11
- import spaces
12
  import einops
13
  import math
14
  import random
@@ -62,9 +61,6 @@ load_state_dict(model, ckpt, strict=True)
62
  model.freeze()
63
  model.to("cuda")
64
 
65
- sampler = SpacedSampler(model, var_type="fixed_small")
66
-
67
- @spaces.GPU
68
  @torch.no_grad()
69
  def process(
70
  control_img: Image.Image,
@@ -79,7 +75,10 @@ def process(
79
  seed: int,
80
  tile_diffusion: bool,
81
  tile_diffusion_size: int,
82
- tile_diffusion_stride: int
 
 
 
83
  ):
84
  print(
85
  f"control image shape={control_img.size}\n"
@@ -88,6 +87,7 @@ def process(
88
  f"cdf scale={cfg_scale}, steps={steps}, use_color_fix={use_color_fix}\n"
89
  f"seed={seed}\n"
90
  f"tile_diffusion={tile_diffusion}, tile_diffusion_size={tile_diffusion_size}, tile_diffusion_stride={tile_diffusion_stride}"
 
91
  )
92
  pl.seed_everything(seed)
93
 
@@ -118,12 +118,13 @@ def process(
118
  height, width = control.size(-2), control.size(-1)
119
  model.control_scales = [strength] * 13
120
 
 
121
  preds = []
122
  for _ in tqdm(range(num_samples)):
123
  shape = (1, 4, height // 8, width // 8)
124
  x_T = torch.randn(shape, device=model.device, dtype=torch.float32)
125
 
126
- if not tile_diffusion:
127
  samples = sampler.sample_ccsr(
128
  steps=steps, t_max=0.6667, t_min=0.3333, shape=shape, cond_img=control,
129
  positive_prompt=positive_prompt, negative_prompt=negative_prompt, x_T=x_T,
@@ -131,13 +132,23 @@ def process(
131
  color_fix_type="adain" if use_color_fix else "none"
132
  )
133
  else:
134
- samples = sampler.sample_with_tile_ccsr(
135
- tile_size=tile_diffusion_size, tile_stride=tile_diffusion_stride,
136
- steps=steps, t_max=0.6667, t_min=0.3333, shape=shape, cond_img=control,
137
- positive_prompt=positive_prompt, negative_prompt=negative_prompt, x_T=x_T,
138
- cfg_scale=cfg_scale,
139
- color_fix_type="adain" if use_color_fix else "none"
140
- )
 
 
 
 
 
 
 
 
 
 
141
 
142
  x_samples = samples.clamp(0, 1)
143
  x_samples = (einops.rearrange(x_samples, "b c h w -> b h w c") * 255).cpu().numpy().clip(0, 255).astype(np.uint8)
@@ -182,6 +193,9 @@ with block:
182
  tile_diffusion = gr.Checkbox(label="Tile diffusion", value=False)
183
  tile_diffusion_size = gr.Slider(label="Tile diffusion size", minimum=512, maximum=1024, value=512, step=256)
184
  tile_diffusion_stride = gr.Slider(label="Tile diffusion stride", minimum=256, maximum=512, value=256, step=128)
 
 
 
185
 
186
  with gr.Column():
187
  result_gallery = gr.Gallery(label="Output", show_label=False, elem_id="gallery")
@@ -200,6 +214,9 @@ with block:
200
  tile_diffusion,
201
  tile_diffusion_size,
202
  tile_diffusion_stride,
 
 
 
203
  ]
204
  run_button.click(fn=process, inputs=inputs, outputs=[result_gallery])
205
 
 
8
  import subprocess
9
  from tqdm import tqdm
10
  import requests
 
11
  import einops
12
  import math
13
  import random
 
61
  model.freeze()
62
  model.to("cuda")
63
 
 
 
 
64
  @torch.no_grad()
65
  def process(
66
  control_img: Image.Image,
 
75
  seed: int,
76
  tile_diffusion: bool,
77
  tile_diffusion_size: int,
78
+ tile_diffusion_stride: int,
79
+ tile_vae: bool,
80
+ vae_encoder_tile_size: int,
81
+ vae_decoder_tile_size: int
82
  ):
83
  print(
84
  f"control image shape={control_img.size}\n"
 
87
  f"cdf scale={cfg_scale}, steps={steps}, use_color_fix={use_color_fix}\n"
88
  f"seed={seed}\n"
89
  f"tile_diffusion={tile_diffusion}, tile_diffusion_size={tile_diffusion_size}, tile_diffusion_stride={tile_diffusion_stride}"
90
+ f"tile_vae={tile_vae}, vae_encoder_tile_size={vae_encoder_tile_size}, vae_decoder_tile_size={vae_decoder_tile_size}"
91
  )
92
  pl.seed_everything(seed)
93
 
 
118
  height, width = control.size(-2), control.size(-1)
119
  model.control_scales = [strength] * 13
120
 
121
+ sampler = SpacedSampler(model, var_type="fixed_small")
122
  preds = []
123
  for _ in tqdm(range(num_samples)):
124
  shape = (1, 4, height // 8, width // 8)
125
  x_T = torch.randn(shape, device=model.device, dtype=torch.float32)
126
 
127
+ if not tile_diffusion and not tile_vae:
128
  samples = sampler.sample_ccsr(
129
  steps=steps, t_max=0.6667, t_min=0.3333, shape=shape, cond_img=control,
130
  positive_prompt=positive_prompt, negative_prompt=negative_prompt, x_T=x_T,
 
132
  color_fix_type="adain" if use_color_fix else "none"
133
  )
134
  else:
135
+ if tile_vae:
136
+ model._init_tiled_vae(encoder_tile_size=vae_encoder_tile_size, decoder_tile_size=vae_decoder_tile_size)
137
+ if tile_diffusion:
138
+ samples = sampler.sample_with_tile_ccsr(
139
+ tile_size=tile_diffusion_size, tile_stride=tile_diffusion_stride,
140
+ steps=steps, t_max=0.6667, t_min=0.3333, shape=shape, cond_img=control,
141
+ positive_prompt=positive_prompt, negative_prompt=negative_prompt, x_T=x_T,
142
+ cfg_scale=cfg_scale,
143
+ color_fix_type="adain" if use_color_fix else "none"
144
+ )
145
+ else:
146
+ samples = sampler.sample_ccsr(
147
+ steps=steps, t_max=0.6667, t_min=0.3333, shape=shape, cond_img=control,
148
+ positive_prompt=positive_prompt, negative_prompt=negative_prompt, x_T=x_T,
149
+ cfg_scale=cfg_scale,
150
+ color_fix_type="adain" if use_color_fix else "none"
151
+ )
152
 
153
  x_samples = samples.clamp(0, 1)
154
  x_samples = (einops.rearrange(x_samples, "b c h w -> b h w c") * 255).cpu().numpy().clip(0, 255).astype(np.uint8)
 
193
  tile_diffusion = gr.Checkbox(label="Tile diffusion", value=False)
194
  tile_diffusion_size = gr.Slider(label="Tile diffusion size", minimum=512, maximum=1024, value=512, step=256)
195
  tile_diffusion_stride = gr.Slider(label="Tile diffusion stride", minimum=256, maximum=512, value=256, step=128)
196
+ tile_vae = gr.Checkbox(label="Tile VAE", value=True)
197
+ vae_encoder_tile_size = gr.Slider(label="Encoder tile size", minimum=512, maximum=5000, value=1024, step=256)
198
+ vae_decoder_tile_size = gr.Slider(label="Decoder tile size", minimum=64, maximum=512, value=224, step=128)
199
 
200
  with gr.Column():
201
  result_gallery = gr.Gallery(label="Output", show_label=False, elem_id="gallery")
 
214
  tile_diffusion,
215
  tile_diffusion_size,
216
  tile_diffusion_stride,
217
+ tile_vae,
218
+ vae_encoder_tile_size,
219
+ vae_decoder_tile_size,
220
  ]
221
  run_button.click(fn=process, inputs=inputs, outputs=[result_gallery])
222