owiedotch committed
Commit c686a8d
1 Parent(s): f7d9674

Update app.py

Files changed (1):
  1. app.py +43 -60
app.py CHANGED
@@ -57,17 +57,8 @@ from utils.image import auto_resize
 
 config = OmegaConf.load("configs/model/ccsr_stage2.yaml")
 model = instantiate_from_config(config)
-
-# Load the checkpoint without weights_only=True
 ckpt = torch.load("weights/real-world_ccsr.ckpt", map_location="cpu")
-
-# Extract only the model state dict
-if "state_dict" in ckpt:
-    state_dict = ckpt["state_dict"]
-else:
-    state_dict = ckpt
-
-load_state_dict(model, state_dict, strict=True)
+load_state_dict(model, ckpt, strict=True)
 model.freeze()
 
 # Check if CUDA is available, otherwise use CPU
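Note on this hunk: the old code unwrapped a possible "state_dict" key before loading; the new code hands the raw ckpt straight to the repo's load_state_dict helper, which is only equivalent if that helper unwraps Lightning-style checkpoints itself. A minimal defensive sketch that tolerates both layouts (load_ccsr_weights is a hypothetical name, not in the commit):

import torch

def load_ccsr_weights(model, ckpt_path):
    # Hypothetical helper: accept both a bare state dict and a
    # Lightning-style checkpoint nesting weights under "state_dict".
    ckpt = torch.load(ckpt_path, map_location="cpu")
    state_dict = ckpt.get("state_dict", ckpt)
    # load_state_dict is the repo's own helper, already imported in app.py.
    load_state_dict(model, state_dict, strict=True)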
@@ -147,7 +138,9 @@ def process(
         )
     else:
         if tile_vae:
-            model._init_tiled_vae(encoder_tile_size=vae_encoder_tile_size, decoder_tile_size=vae_decoder_tile_size)
+            # Removed: ControlLDM doesn't have an _init_tiled_vae method
+            # model._init_tiled_vae(encoder_tile_size=vae_encoder_tile_size, decoder_tile_size=vae_decoder_tile_size)
+            pass
         if tile_diffusion:
             samples = sampler.sample_with_tile_ccsr(
                 tile_size=tile_diffusion_size, tile_stride=tile_diffusion_stride,
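Worth noting: with this change the Tile VAE checkbox becomes silently inert, since tile_vae is still accepted but the branch does nothing. If the method is merely missing on some model classes, a defensive alternative (a sketch using the same local names as in process) is to probe for the hook instead of deleting the call:

if tile_vae:
    # Call the tiled-VAE hook only if the model class actually provides
    # it (per the commit comment, ControlLDM does not).
    init_tiled_vae = getattr(model, "_init_tiled_vae", None)
    if callable(init_tiled_vae):
        init_tiled_vae(
            encoder_tile_size=vae_encoder_tile_size,
            decoder_tile_size=vae_decoder_tile_size,
        )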
@@ -178,67 +171,57 @@ def update_output_resolution(image, scale):
         return f"Current resolution: {width}x{height}. Output resolution: {int(width*scale)}x{int(height*scale)}"
     return "Upload an image to see the output resolution"
 
-block = gr.Blocks().queue()
-with block:
-    with gr.Row():
-        input_image = gr.Image(type="pil", label="Input Image")
+# Improved UI design
+css = """
+.container {max-width: 1200px; margin: auto; padding: 20px;}
+.input-image {width: 100%; max-height: 500px; object-fit: contain;}
+.output-gallery {display: flex; flex-wrap: wrap; justify-content: center;}
+.output-image {margin: 10px; max-width: 45%; height: auto;}
+"""
+
+with gr.Blocks(css=css) as block:
+    gr.HTML("<h1 style='text-align: center;'>CCSR Upscaler</h1>")
 
     with gr.Row():
-        sr_scale = gr.Slider(label="SR Scale", minimum=1, maximum=8, value=4, step=0.1, info="Super-resolution scale factor.")
-
-        output_resolution = gr.Markdown("Upload an image to see the output resolution")
+        with gr.Column(scale=1):
+            input_image = gr.Image(type="pil", label="Input Image", elem_classes="input-image")
+            sr_scale = gr.Slider(label="SR Scale", minimum=1, maximum=8, value=4, step=0.1, info="Super-resolution scale factor.")
+            output_resolution = gr.Markdown("Upload an image to see the output resolution")
+            run_button = gr.Button(value="Run", variant="primary")
+
+        with gr.Column(scale=1):
+            with gr.Accordion("Advanced Options", open=False):
+                num_samples = gr.Slider(label="Number Of Samples", minimum=1, maximum=12, value=1, step=1)
+                strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
+                positive_prompt = gr.Textbox(label="Positive Prompt", value="")
+                negative_prompt = gr.Textbox(
+                    label="Negative Prompt",
+                    value="longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality"
+                )
+                cfg_scale = gr.Slider(label="Classifier Free Guidance Scale", minimum=0.1, maximum=30.0, value=1.0, step=0.1)
+                steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=45, step=1)
+                use_color_fix = gr.Checkbox(label="Use Color Correction", value=True)
+                seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, value=231)
+                tile_diffusion = gr.Checkbox(label="Tile diffusion", value=False)
+                tile_diffusion_size = gr.Slider(label="Tile diffusion size", minimum=512, maximum=1024, value=512, step=256)
+                tile_diffusion_stride = gr.Slider(label="Tile diffusion stride", minimum=256, maximum=512, value=256, step=128)
+                tile_vae = gr.Checkbox(label="Tile VAE", value=True)
+                vae_encoder_tile_size = gr.Slider(label="Encoder tile size", minimum=512, maximum=5000, value=1024, step=256)
+                vae_decoder_tile_size = gr.Slider(label="Decoder tile size", minimum=64, maximum=512, value=224, step=128)
 
     with gr.Row():
-        run_button = gr.Button(value="Run")
-
-    with gr.Accordion("Options", open=False):
-        with gr.Column():
-            num_samples = gr.Slider(label="Number Of Samples", minimum=1, maximum=12, value=1, step=1)
-            strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
-            positive_prompt = gr.Textbox(label="Positive Prompt", value="")
-            negative_prompt = gr.Textbox(
-                label="Negative Prompt",
-                value="longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality"
-            )
-            cfg_scale = gr.Slider(label="Classifier Free Guidance Scale", minimum=0.1, maximum=30.0, value=1.0, step=0.1)
-            steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=45, step=1)
-            use_color_fix = gr.Checkbox(label="Use Color Correction", value=True)
-            seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, value=231)
-            tile_diffusion = gr.Checkbox(label="Tile diffusion", value=False)
-            tile_diffusion_size = gr.Slider(label="Tile diffusion size", minimum=512, maximum=1024, value=512, step=256)
-            tile_diffusion_stride = gr.Slider(label="Tile diffusion stride", minimum=256, maximum=512, value=256, step=128)
-            tile_vae = gr.Checkbox(label="Tile VAE", value=True)
-            vae_encoder_tile_size = gr.Slider(label="Encoder tile size", minimum=512, maximum=5000, value=1024, step=256)
-            vae_decoder_tile_size = gr.Slider(label="Decoder tile size", minimum=64, maximum=512, value=224, step=128)
-
-        with gr.Column():
-            result_gallery = gr.Gallery(label="Output", show_label=False, elem_id="gallery")
+        result_gallery = gr.Gallery(label="Output", show_label=False, elem_id="gallery", elem_classes="output-gallery")
 
     inputs = [
-        input_image,
-        num_samples,
-        sr_scale,
-        strength,
-        positive_prompt,
-        negative_prompt,
-        cfg_scale,
-        steps,
-        use_color_fix,
-        seed,
-        tile_diffusion,
-        tile_diffusion_size,
-        tile_diffusion_stride,
-        tile_vae,
-        vae_encoder_tile_size,
-        vae_decoder_tile_size,
+        input_image, num_samples, sr_scale, strength, positive_prompt, negative_prompt,
+        cfg_scale, steps, use_color_fix, seed, tile_diffusion, tile_diffusion_size,
+        tile_diffusion_stride, tile_vae, vae_encoder_tile_size, vae_decoder_tile_size,
     ]
     run_button.click(fn=process, inputs=inputs, outputs=[result_gallery])
 
-    # Update output resolution when image is uploaded or SR scale is changed
     input_image.change(update_output_resolution, inputs=[input_image, sr_scale], outputs=[output_resolution])
     sr_scale.change(update_output_resolution, inputs=[input_image, sr_scale], outputs=[output_resolution])
 
-    # Disable SR scale slider when no image is uploaded
     input_image.change(
         lambda x: gr.update(interactive=x is not None),
         inputs=[input_image],
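Two things worth flagging in this UI rewrite. First, the order of inputs must match the parameter order of process, since run_button.click passes them positionally; the flattened list above keeps the original order. Second, the old block = gr.Blocks().queue() enabled request queuing at construction, while the new with gr.Blocks(css=css) as block: does not; if the launch call (outside this hunk) doesn't re-enable it, a sketch like the following would (share=False is an illustrative default, not from the commit):

if __name__ == "__main__":
    # Re-enable queuing, which the old gr.Blocks().queue() call provided.
    block.queue().launch(share=False)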
 