Zhenyu Li commited on
Commit
72075ba
Β·
1 Parent(s): dd94046
.gitignore CHANGED
@@ -153,4 +153,6 @@ dcgm/
153
 
154
  .err
155
  .out
156
- script_run.sh
 
 
 
153
 
154
  .err
155
  .out
156
+ script_run.sh
157
+
158
+ gradio_cached_examples
examples/example_1.jpeg β†’ ControlNet/ldm/modules/image_degradation/utils/test.png RENAMED
File without changes
app.py CHANGED
@@ -46,6 +46,7 @@ from ui_prediction import predict_depth
46
  import torch.nn.functional as F
47
 
48
  from huggingface_hub import hf_hub_download
 
49
 
50
  DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
51
 
@@ -106,13 +107,52 @@ model.load_state_dict(load_state_dict(controlnet_ckp, location=DEVICE), strict=F
106
  model = model.to(DEVICE)
107
  ddim_sampler = DDIMSampler(model)
108
 
109
- # controlnet
110
- title = "# PatchFusion"
111
- description = """Official demo for **PatchFusion: An End-to-End Tile-Based Framework for High-Resolution Monocular Metric Depth Estimation**.
 
112
 
113
- PatchFusion is a deep learning model for high-resolution metric depth estimation from a single image.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114
 
115
- Please refer to our [paper](???) or [github](???) for more details."""
 
 
116
 
117
  def rescale(A, lbound=-1, ubound=1):
118
  """
@@ -142,9 +182,10 @@ def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resoluti
142
  torch.cuda.empty_cache()
143
 
144
  detected_map = F.interpolate(torch.from_numpy(detected_map).unsqueeze(dim=0).unsqueeze(dim=0), (image_resolution, image_resolution), mode='bicubic', align_corners=True).squeeze().numpy()
 
145
 
146
  H, W = detected_map.shape
147
- detected_map_temp = ((1 - detected_map / np.max(detected_map)) * 255)
148
  detected_map = detected_map_temp.astype("uint8")
149
 
150
  detected_map_temp = detected_map_temp[:, :, None]
@@ -184,11 +225,15 @@ def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resoluti
184
 
185
  results = [x_samples[i] for i in range(num_samples)]
186
 
187
- return_list = [detected_map_temp] + results
188
  update_return_list = []
189
- for r in return_list:
190
- t_r = torch.from_numpy(r).unsqueeze(dim=0).permute(0, 3, 1, 2)
191
- t_r = F.interpolate(t_r, (h, w), mode='bicubic', align_corners=True).squeeze().permute(1, 2, 0).numpy().astype(np.uint8)
 
 
 
 
192
  update_return_list.append(t_r)
193
 
194
  return update_return_list
@@ -202,17 +247,13 @@ Please refer to our [project webpage](https://zhyever.github.io/patchfusion), [p
202
 
203
  # Advanced tips
204
 
205
- I know people don't like reading introductions, so you could run this demo without any extra modifications.
206
-
207
- But for people want to do some crazy things, I recommand to read the following texts to better under stand how this demo work.
208
-
209
  The overall pipeline: image --> (PatchFusion) --> depth --> (controlnet) --> generated image.
210
 
211
- As for the PatchFusion, it works on default 4k (2160x3840) resolution. All input images will be resized to 4k before passing through PatchFusion as default. It means if you have a higher resolution image, you can increase the resolution in the advanced option.
212
 
213
- For ControlNet, it works on default 896x896 resolution. Again, all input images will be resized to 896x896 before passing through ControlNet as default. You might be not happy because the 4K->896x896 downsampling, but limited by the GPU resource, this demo could only achieve this.
214
 
215
- We provide some tips might be helpful: (1) Try our experimental demo (see our project website) running on a local 80G gpu. But of course, it would be expired soon (in two days maybe); (2) Clone our code repo, and look for a gpu with more than 24G memory; (3) Clone our code repo, run the depth estimation (there are another demos for depth estimation and image-to-3D), and search for another guided high-resolution image generation strategy; (4) Some kind people give this space a stronger gpu support.
216
  """
217
 
218
  with gr.Blocks() as demo:
@@ -225,33 +266,56 @@ with gr.Blocks() as demo:
225
  with gr.Column():
226
  # input_image = gr.Image(source='upload', type="pil")
227
  input_image = gr.Image(label="Input Image", type='pil')
228
- prompt = gr.Textbox(label="Prompt (input your description)", value='An evening scene with the Eiffel Tower, the bridge under the glow of street lamps and a twilight sky')
229
  run_button = gr.Button("Run")
230
- with gr.Accordion("Advanced options", open=False):
231
- # mode = gr.Radio(["P49", "R"], label="Tiling mode", info="We recommand using P49 for fast evaluation and R with 1024 patches for best visualization results, respectively", elem_id='mode', value='R'),
232
- mode = gr.Radio(["P49", "R"], label="Tiling mode", info="We recommand using P49 for fast evaluation and R with 1024 patches for best visualization results, respectively", elem_id='mode', value='P49'),
233
- patch_number = gr.Slider(1, 1024, label="Please decide the number of random patches (Only useful in mode=R)", step=1, value=256)
234
- resolution = gr.Textbox(label="(PatchFusion) Proccessing resolution (Default 4K. Use 'x' to split height and width.)", elem_id='mode', value='2160x3840')
235
- patch_size = gr.Textbox(label="(PatchFusion) Patch size (Default 1/4 of image resolution. Use 'x' to split height and width.)", elem_id='mode', value='540x960')
236
-
237
- num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
238
- image_resolution = gr.Slider(label="ControlNet image resolution (higher resolution will lead to OOM)", minimum=256, maximum=1024, value=896, step=64)
239
- strength = gr.Slider(label="Control strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
240
- guess_mode = gr.Checkbox(label='Guess Mode', value=False)
241
- # detect_resolution = gr.Slider(label="Depth Resolution", minimum=128, maximum=1024, value=384, step=1)
242
- ddim_steps = gr.Slider(label="steps", minimum=1, maximum=100, value=20, step=1)
243
- scale = gr.Slider(label="guidance scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
244
- seed = gr.Slider(label="seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
245
- eta = gr.Number(label="eta (DDIM)", value=0.0)
246
- a_prompt = gr.Textbox(label="Added prompt", value='best quality, extremely detailed')
247
- n_prompt = gr.Textbox(label="Negative prompt", value='worst quality, low quality, lose details')
248
- with gr.Column():
249
- # result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
250
- result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery")
251
- ips = [input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, guess_mode, strength, scale, seed, eta, mode[0], patch_number, resolution, patch_size]
252
- run_button.click(fn=process, inputs=ips, outputs=[result_gallery])
253
- examples = gr.Examples(examples=["examples/example_2.jpeg", "examples/example_4.jpeg", "examples/example_5.jpeg"], inputs=[input_image])
254
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
255
 
256
  if __name__ == '__main__':
257
  demo.queue().launch(share=True)
 
46
  import torch.nn.functional as F
47
 
48
  from huggingface_hub import hf_hub_download
49
+ import matplotlib
50
 
51
  DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
52
 
 
107
  model = model.to(DEVICE)
108
  ddim_sampler = DDIMSampler(model)
109
 
110
+ def colorize_depth_maps(depth_map, min_depth=0, max_depth=0, cmap='Spectral_r', valid_mask=None):
111
+ """
112
+ Colorize depth maps.
113
+ """
114
 
115
+ percentile = 0.03
116
+ min_depth = np.percentile(depth_map, percentile)
117
+ max_depth = np.percentile(depth_map, 100 - percentile)
118
+
119
+ assert len(depth_map.shape) >= 2, "Invalid dimension"
120
+
121
+ if isinstance(depth_map, torch.Tensor):
122
+ depth = depth_map.detach().clone().squeeze().numpy()
123
+ elif isinstance(depth_map, np.ndarray):
124
+ depth = depth_map.copy().squeeze()
125
+ # reshape to [ (B,) H, W ]
126
+ if depth.ndim < 3:
127
+ depth = depth[np.newaxis, :, :]
128
+
129
+ # colorize
130
+ cm = matplotlib.colormaps[cmap]
131
+ depth = ((depth - min_depth) / (max_depth - min_depth)).clip(0, 1)
132
+ img_colored_np = cm(depth, bytes=False)[:,:,:,0:3] # value from 0 to 1
133
+ img_colored_np = np.rollaxis(img_colored_np, 3, 1)
134
+
135
+ if valid_mask is not None:
136
+ if isinstance(depth_map, torch.Tensor):
137
+ valid_mask = valid_mask.detach().numpy()
138
+ valid_mask = valid_mask.squeeze() # [H, W] or [B, H, W]
139
+ if valid_mask.ndim < 3:
140
+ valid_mask = valid_mask[np.newaxis, np.newaxis, :, :]
141
+ else:
142
+ valid_mask = valid_mask[:, np.newaxis, :, :]
143
+ valid_mask = np.repeat(valid_mask, 3, axis=1)
144
+ img_colored_np[~valid_mask] = 0
145
+
146
+ if isinstance(depth_map, torch.Tensor):
147
+ img_colored = torch.from_numpy(img_colored_np).float()
148
+ elif isinstance(depth_map, np.ndarray):
149
+ img_colored = img_colored_np
150
+
151
+ return img_colored
152
 
153
+ def hack_process(path_input, path_depth=None, path_gen=None):
154
+ if path_depth is not None and path_gen is not None:
155
+ return path_input, path_depth, path_gen
156
 
157
  def rescale(A, lbound=-1, ubound=1):
158
  """
 
182
  torch.cuda.empty_cache()
183
 
184
  detected_map = F.interpolate(torch.from_numpy(detected_map).unsqueeze(dim=0).unsqueeze(dim=0), (image_resolution, image_resolution), mode='bicubic', align_corners=True).squeeze().numpy()
185
+ colored_depth = colorize_depth_maps(detected_map) * 255
186
 
187
  H, W = detected_map.shape
188
+ detected_map_temp = ((1 - detected_map / (np.max(detected_map + 1e-3))) * 255)
189
  detected_map = detected_map_temp.astype("uint8")
190
 
191
  detected_map_temp = detected_map_temp[:, :, None]
 
225
 
226
  results = [x_samples[i] for i in range(num_samples)]
227
 
228
+ return_list = [colored_depth] + results
229
  update_return_list = []
230
+ for idx, r in enumerate(return_list):
231
+ if idx == 0:
232
+ t_r = torch.from_numpy(r)
233
+ else:
234
+ t_r = torch.from_numpy(r).unsqueeze(dim=0).permute(0, 3, 1, 2)
235
+ # t_r = F.interpolate(t_r, (h, w), mode='bicubic', align_corners=True).squeeze().permute(1, 2, 0).numpy().astype(np.uint8)
236
+ t_r = t_r.squeeze().permute(1, 2, 0).numpy().astype(np.uint8)
237
  update_return_list.append(t_r)
238
 
239
  return update_return_list
 
247
 
248
  # Advanced tips
249
 
 
 
 
 
250
  The overall pipeline: image --> (PatchFusion) --> depth --> (controlnet) --> generated image.
251
 
252
+ As for the PatchFusion, it works on default 4k (2160x3840) resolution. All input images will be resized to 4k before passing through PatchFusion as default. It means if you have a higher resolution image, you might want to increase the processing resolution in the advanced option (You would also change the patch size to 1/4 image resolution). Because of the tiling strategy, our PatchFusion would not use more memory or time for even higher resolution inputs if properly setting parameters.
253
 
254
+ For ControlNet, it works on default 896x896 resolution. Again, all input images will be resized to 896x896 before passing through ControlNet as default. You might be not happy because the 4K->896x896 downsampling, but limited by the GPU resource, this demo could only achieve this. This is the memory bottleneck. The output is not resized back to the image resolution for fast inference (Well... It's still so slow now... :D)
255
 
256
+ We provide some tips might be helpful: (1) Try our experimental demo (see our project website) running on a local 80G gpu (you could try high-resolution generation there, like the one in our paper). But of course, it would be expired soon (in two days maybe); (2) Clone our code repo, and look for a gpu with more than 24G memory; (3) Clone our code repo, run the depth estimation (there are another demos for depth estimation and image-to-3D), and search for another guided high-resolution image generation strategy; (4) Some kind people give this space a stronger gpu support.
257
  """
258
 
259
  with gr.Blocks() as demo:
 
266
  with gr.Column():
267
  # input_image = gr.Image(source='upload', type="pil")
268
  input_image = gr.Image(label="Input Image", type='pil')
269
+ prompt = gr.Textbox(label="Prompt (input your description)", value='A cozy cottage in an oil painting, with rich textures and vibrant green foliage')
270
  run_button = gr.Button("Run")
271
+
272
+ depth_image = gr.Image(label="Depth Map", elem_id='img-display-output')
273
+ generated_image = gr.Image(label="Generated Map", elem_id='img-display-output')
274
+
275
+ with gr.Row():
276
+ with gr.Accordion("Advanced options", open=False):
277
+ # mode = gr.Radio(["P49", "R"], label="Tiling mode", info="We recommand using P49 for fast evaluation and R with 1024 patches for best visualization results, respectively", elem_id='mode', value='R'),
278
+ mode = gr.Radio(["P49", "R"], label="Tiling mode", info="We recommand using P49 for fast evaluation and R with 1024 patches for best visualization results, respectively", elem_id='mode', value='P49'),
279
+ patch_number = gr.Slider(1, 1024, label="Please decide the number of random patches (Only useful in mode=R)", step=1, value=256)
280
+ resolution = gr.Textbox(label="(PatchFusion) Proccessing resolution (Default 4K. Use 'x' to split height and width.)", elem_id='mode', value='2160x3840')
281
+ patch_size = gr.Textbox(label="(PatchFusion) Patch size (Default 1/4 of image resolution. Use 'x' to split height and width.)", elem_id='mode', value='540x960')
282
+
283
+ num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
284
+ image_resolution = gr.Slider(label="ControlNet image resolution (higher resolution will lead to OOM)", minimum=256, maximum=1024, value=896, step=64)
285
+ strength = gr.Slider(label="Control strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
286
+ guess_mode = gr.Checkbox(label='Guess Mode', value=False)
287
+ # detect_resolution = gr.Slider(label="Depth Resolution", minimum=128, maximum=1024, value=384, step=1)
288
+ ddim_steps = gr.Slider(label="steps", minimum=1, maximum=100, value=20, step=1)
289
+ scale = gr.Slider(label="guidance scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
290
+ seed = gr.Slider(label="seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
291
+ eta = gr.Number(label="eta (DDIM)", value=0.0)
292
+ a_prompt = gr.Textbox(label="Added prompt", value='best quality, extremely detailed')
293
+ n_prompt = gr.Textbox(label="Negative prompt", value='worst quality, low quality, lose details')
 
294
 
295
+ ips = [input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, guess_mode, strength, scale, seed, eta, mode[0], patch_number, resolution, patch_size]
296
+ run_button.click(fn=process, inputs=ips, outputs=[depth_image, generated_image])
297
+ examples = gr.Examples(
298
+ inputs=[input_image, depth_image, generated_image],
299
+ outputs=[input_image, depth_image, generated_image],
300
+ examples=[
301
+ [
302
+ "examples/example_4.jpeg",
303
+ "examples/2_depth.png",
304
+ "examples/2_gen.png",
305
+
306
+ ],
307
+ [
308
+ "examples/example_5.jpeg",
309
+ "examples/3_depth.png",
310
+ "examples/3_gen.png",
311
+ ],
312
+ [
313
+ "examples/example_6.png",
314
+ "examples/4_depth.png",
315
+ "examples/4_gen.png",
316
+ ]],
317
+ cache_examples=True,
318
+ fn=hack_process)
319
 
320
  if __name__ == '__main__':
321
  demo.queue().launch(share=True)
examples/{example_2.jpeg β†’ 2_depth.png} RENAMED
File without changes
examples/{example_3.jpeg β†’ 2_gen.png} RENAMED
File without changes
examples/3_depth.png ADDED

Git LFS Details

  • SHA256: 65683bb4c38c7d714304fec588c54a0cd49cf9662097e331fd02d725781be274
  • Pointer size: 131 Bytes
  • Size of remote file: 222 kB
examples/3_gen.png ADDED

Git LFS Details

  • SHA256: 5d1f56c5006fe8584e064e12e98fd61ea6cdfadbbc5786e4d24b45e07e826579
  • Pointer size: 132 Bytes
  • Size of remote file: 1.12 MB
examples/4_depth.png ADDED

Git LFS Details

  • SHA256: 9e730e0d3efc53cce2e512e493855486f6b7e1b0b5bdf214cedb868b6ac9b68f
  • Pointer size: 131 Bytes
  • Size of remote file: 276 kB
examples/4_gen.png ADDED

Git LFS Details

  • SHA256: 93fc1719a9a3c2b99bbfdc55ceb6d69a12df50ea73e5dfb542cc8ae5f31d71d0
  • Pointer size: 132 Bytes
  • Size of remote file: 1.49 MB
examples/example_6.png ADDED

Git LFS Details

  • SHA256: 6c5ef3c57fc74b7f9cea4ca6075fac88d72b804d076f39bb6f9ee42edf92168a
  • Pointer size: 132 Bytes
  • Size of remote file: 7.58 MB