owiedotch committed
Commit ba5770d
1 Parent(s): 9fd6710

Update app.py

Files changed (1)
  1. app.py +154 -36
app.py CHANGED
@@ -9,7 +9,9 @@ import subprocess
 from tqdm import tqdm
 import requests
 import spaces
-import einops  # Import einops to fix the NameError
+import einops
+import math
+import random
 
 def download_file(url, filename):
     response = requests.get(url, stream=True)
@@ -51,6 +53,7 @@ from ldm.xformers_state import disable_xformers
 from model.q_sampler import SpacedSampler
 from model.ccsr_stage1 import ControlLDM
 from utils.common import instantiate_from_config, load_state_dict
+from utils.image import auto_resize
 
 config = OmegaConf.load("configs/model/ccsr_stage2.yaml")
 model = instantiate_from_config(config)
@@ -59,46 +62,161 @@ load_state_dict(model, ckpt, strict=True)
 model.freeze()
 model.to("cuda")
 
+sampler = SpacedSampler(model, var_type="fixed_small")
+
 @spaces.GPU
 @torch.no_grad()
-def process(image, steps, t_max, t_min, color_fix_type, scale):
-    image = Image.open(image).convert("RGB")
-    image = image.resize((256, 256), Image.LANCZOS)
-    image = np.array(image)
-
-    sampler = SpacedSampler(model, var_type="fixed_small")
-    control = torch.tensor(np.stack([image]) / 255.0, dtype=torch.float32, device=model.device).clamp_(0, 1)
-    control = einops.rearrange(control, "n h w c -> n c h w").contiguous()
-
-    model.control_scales = [scale] * 13  # Use the scale parameter
-
-    height, width = control.size(-2), control.size(-1)
-    shape = (1, 4, height // 8, width // 8)
-    x_T = torch.randn(shape, device=model.device, dtype=torch.float32)
-
-    samples = sampler.sample_ccsr(
-        steps=steps, t_max=t_max, t_min=t_min, shape=shape, cond_img=control,
-        positive_prompt="", negative_prompt="", x_T=x_T,
-        cfg_scale=1.0, color_fix_type=color_fix_type
-    )
-
-    x_samples = samples.clamp(0, 1)
-    x_samples = (einops.rearrange(x_samples, "b c h w -> b h w c") * 255).cpu().numpy().clip(0, 255).astype(np.uint8)
-
-    return Image.fromarray(x_samples[0])
-
-interface = gr.Interface(
-    fn=process,
-    inputs=[
-        gr.Image(type="filepath", label="Input Image"),
-        gr.Slider(minimum=1, maximum=100, step=1, value=45, label="Steps"),
-        gr.Slider(minimum=0, maximum=1, step=0.0001, value=0.6667, label="T Max"),
-        gr.Slider(minimum=0, maximum=1, step=0.0001, value=0.3333, label="T Min"),
-        gr.Dropdown(choices=["adain", "wavelet", "none"], value="adain", label="Color Fix Type"),
-        gr.Slider(minimum=0, maximum=2, step=0.01, value=1.0, label="Scale"),
-    ],
-    outputs=gr.Image(type="pil", label="Output Image"),
-    title="CCSR: Continuous Contrastive Super-Resolution",
-)
-
-interface.launch()
+def process(
+    control_img: Image.Image,
+    num_samples: int,
+    sr_scale: int,
+    strength: float,
+    positive_prompt: str,
+    negative_prompt: str,
+    cfg_scale: float,
+    steps: int,
+    use_color_fix: bool,
+    seed: int,
+    tile_diffusion: bool,
+    tile_diffusion_size: int,
+    tile_diffusion_stride: int,
+    tile_vae: bool,
+    vae_encoder_tile_size: int,
+    vae_decoder_tile_size: int
+):
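+    # Log the full parameter set for this run (handy when reading the Space logs).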
+    print(
+        f"control image size={control_img.size}\n"
+        f"num_samples={num_samples}, sr_scale={sr_scale}, strength={strength}\n"
+        f"positive_prompt='{positive_prompt}', negative_prompt='{negative_prompt}'\n"
+        f"cfg_scale={cfg_scale}, steps={steps}, use_color_fix={use_color_fix}\n"
+        f"seed={seed}\n"
+        f"tile_diffusion={tile_diffusion}, tile_diffusion_size={tile_diffusion_size}, tile_diffusion_stride={tile_diffusion_stride}\n"
+        f"tile_vae={tile_vae}, vae_encoder_tile_size={vae_encoder_tile_size}, vae_decoder_tile_size={vae_decoder_tile_size}"
+    )
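+    # A seed of -1 draws a fresh random seed each run; any other value makes sampling reproducible.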
+    if seed == -1:
+        seed = random.randint(0, 2**32 - 1)
+    torch.manual_seed(seed)
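+
+    # Pre-upscale the low-resolution input by the SR factor; the model then restores detail at the target size.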
+    sr_scale = int(str(sr_scale).rstrip("x"))  # the "SR Scale" dropdown yields a string such as "4x"
+    if sr_scale != 1:
+        control_img = control_img.resize(
+            tuple(math.ceil(x * sr_scale) for x in control_img.size),
+            Image.BICUBIC
+        )
+
+    input_size = control_img.size
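+
+    # auto_resize (from utils.image) brings the image up to the working size: 512 normally, or the tile size when tiled diffusion is enabled.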
+    if not tile_diffusion:
+        control_img = auto_resize(control_img, 512)
+    else:
+        control_img = auto_resize(control_img, tile_diffusion_size)
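+
+    # Round each side up to the next multiple of 64 so the 8x-downsampled latent matches the UNet's stride.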
+    control_img = control_img.resize(
+        tuple((s // 64 + 1) * 64 for s in control_img.size), Image.LANCZOS
+    )
+    control_img = np.array(control_img)
+
+    control = torch.tensor(control_img[None] / 255.0, dtype=torch.float32, device=model.device).clamp_(0, 1)
+    control = einops.rearrange(control, "n h w c -> n c h w").contiguous()
+    height, width = control.size(-2), control.size(-1)
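+    # One scale per ControlNet feature injection point (12 blocks plus the middle block); "strength" scales them uniformly.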
+    model.control_scales = [strength] * 13
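+
+    # Each sample starts from fresh Gaussian noise in the 4-channel latent space at 1/8 of the image resolution.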
+    preds = []
+    for _ in tqdm(range(num_samples)):
+        shape = (1, 4, height // 8, width // 8)
+        x_T = torch.randn(shape, device=model.device, dtype=torch.float32)
+
+        if not tile_diffusion and not tile_vae:
+            samples = sampler.sample_ccsr(
+                steps=steps, t_max=0.6667, t_min=0.3333, shape=shape, cond_img=control,
+                positive_prompt=positive_prompt, negative_prompt=negative_prompt, x_T=x_T,
+                cfg_scale=cfg_scale,
+                color_fix_type="adain" if use_color_fix else "none"
+            )
+        else:
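+            # Tiling trades speed for memory: the VAE and/or the diffusion process run on overlapping crops instead of the full image.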
136
+ if tile_vae:
137
+ model._init_tiled_vae(encoder_tile_size=vae_encoder_tile_size, decoder_tile_size=vae_decoder_tile_size)
138
+ if tile_diffusion:
139
+ samples = sampler.sample_with_tile_ccsr(
140
+ tile_size=tile_diffusion_size, tile_stride=tile_diffusion_stride,
141
+ steps=steps, t_max=0.6667, t_min=0.3333, shape=shape, cond_img=control,
142
+ positive_prompt=positive_prompt, negative_prompt=negative_prompt, x_T=x_T,
143
+ cfg_scale=cfg_scale,
144
+ color_fix_type="adain" if use_color_fix else "none"
145
+ )
146
+ else:
147
+ samples = sampler.sample_ccsr(
148
+ steps=steps, t_max=0.6667, t_min=0.3333, shape=shape, cond_img=control,
149
+ positive_prompt=positive_prompt, negative_prompt=negative_prompt, x_T=x_T,
150
+ cfg_scale=cfg_scale,
151
+ color_fix_type="adain" if use_color_fix else "none"
152
+ )
153
+
154
+ x_samples = samples.clamp(0, 1)
155
+ x_samples = (einops.rearrange(x_samples, "b c h w -> b h w c") * 255).cpu().numpy().clip(0, 255).astype(np.uint8)
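+
+        # Map the sample back to the output size recorded before auto_resize and the 64-padding.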
+        img = Image.fromarray(x_samples[0, ...]).resize(input_size, Image.LANCZOS)
+        preds.append(np.array(img))
+
+    return preds
+
+MARKDOWN = \
+"""
+## Improving the Stability of Diffusion Models for Content Consistent Super-Resolution
+
+[GitHub](https://github.com/csslc/CCSR) | [Paper](https://arxiv.org/pdf/2401.00877.pdf) | [Project Page](https://csslc.github.io/project-CCSR/)
+
+If CCSR is helpful for you, please help star the GitHub Repo. Thanks!
+"""
+
+block = gr.Blocks().queue()
+with block:
+    with gr.Row():
+        gr.Markdown(MARKDOWN)
+    with gr.Row():
+        with gr.Column():
+            input_image = gr.Image(source="upload", type="pil", label="Input Image")
+            run_button = gr.Button(label="Run")
+            with gr.Accordion("Options", open=True):
+                num_samples = gr.Slider(label="Number Of Samples", minimum=1, maximum=12, value=1, step=1, info="Number of output images to generate.")
+                sr_scale = gr.Dropdown(label="SR Scale", choices=["2x", "4x", "8x"], value="4x", info="Super-resolution scale factor.")
+                strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01, info="Strength of the control signal.")
+                positive_prompt = gr.Textbox(label="Positive Prompt", value="", info="Positive text prompt to guide the image generation.")
+                negative_prompt = gr.Textbox(
+                    label="Negative Prompt",
+                    value="longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality",
+                    info="Negative text prompt to avoid undesirable features."
+                )
+                cfg_scale = gr.Slider(label="Classifier Free Guidance Scale", minimum=0.1, maximum=30.0, value=1.0, step=0.1, info="Scale for classifier-free guidance.")
+                steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=45, step=1, info="Number of diffusion steps.")
+                use_color_fix = gr.Checkbox(label="Use Color Correction", value=True, info="Apply color correction to the output image.")
+                seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, value=231, info="Random seed for reproducibility. Set to -1 for a random seed.")
+                tile_diffusion = gr.Checkbox(label="Tile diffusion", value=False, info="Enable tiled diffusion for large images.")
+                tile_diffusion_size = gr.Slider(label="Tile diffusion size", minimum=512, maximum=1024, value=512, step=256, info="Size of each tile for tiled diffusion.")
+                tile_diffusion_stride = gr.Slider(label="Tile diffusion stride", minimum=256, maximum=512, value=256, step=128, info="Stride between tiles for tiled diffusion.")
+                tile_vae = gr.Checkbox(label="Tile VAE", value=True, info="Enable tiled VAE for large images.")
+                vae_encoder_tile_size = gr.Slider(label="Encoder tile size", minimum=512, maximum=5000, value=1024, step=256, info="Size of each tile for the VAE encoder.")
+                vae_decoder_tile_size = gr.Slider(label="Decoder tile size", minimum=64, maximum=512, value=224, step=128, info="Size of each tile for the VAE decoder.")
+        with gr.Column():
+            result_gallery = gr.Gallery(label="Output", show_label=False, elem_id="gallery").style(grid=2, height="auto")
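+
+    # Keep this list in the same order as process()'s parameters; Gradio passes the component values positionally.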
+    inputs = [
+        input_image,
+        num_samples,
+        sr_scale,
+        strength,
+        positive_prompt,
+        negative_prompt,
+        cfg_scale,
+        steps,
+        use_color_fix,
+        seed,
+        tile_diffusion,
+        tile_diffusion_size,
+        tile_diffusion_stride,
+        tile_vae,
+        vae_encoder_tile_size,
+        vae_decoder_tile_size,
+    ]
+    run_button.click(fn=process, inputs=inputs, outputs=[result_gallery])
 
+block.launch()