WindVChen committed
Commit e200a3f • 1 Parent(s): 2cfb929

Upload 23 files

Browse files
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: 👋🏃‍♂️
 colorFrom: purple
 colorTo: pink
 sdk: gradio
-sdk_version: 3.24.0
+sdk_version: 3.26.0
 app_file: app.py
 python_version: 3.8.11
 pinned: false
app.py CHANGED
@@ -6,7 +6,6 @@ import gradio as gr
 import numpy as np
 import sys
 import io
-import torch
 
 
 class Logger:
@@ -38,7 +37,7 @@ def read_logs():
 return out
 
 
-with gr.Blocks() as app:
+with gr.Blocks(css=".output-image, .input-image, .image-preview {height: 600px !important}") as app:
 gr.Markdown("""
 # HINet (or INR-Harmonization) - A novel image Harmonization method based on Implicit neural Networks
 ## Harmonize any image you want! Arbitrary resolution, and arbitrary aspect ratio!
@@ -49,6 +48,16 @@ with gr.Blocks() as app:
 * Official Repo: [INR-Harmonization](https://github.com/WindVChen/INR-Harmonization)
 """)
 
+gr.Markdown("""
+## Quick Start
+1. Select the desired `Pretrained Model`.
+2. Select a composite image, and then a mask of the same size.
+3. Select the inference mode (for non-square images, only `Arbitrary Image` is supported). Also note that `Square Image` mode will be much faster than `Arbitrary Image` mode.
+4. Set `Split Resolution` (the patches' resolution) or `Split Number` (how many patches, about N*N) according to the inference mode.
+5. Click `Start` and enjoy it!
+
+""")
+
 valid_checkpoints_dict = {"Resolution_256_iHarmony4": "Resolution_256_iHarmony4.pth",
 "Resolution_1024_HAdobe5K": "Resolution_1024_HAdobe5K.pth",
 "Resolution_2048_HAdobe5K": "Resolution_2048_HAdobe5K.pth",
@@ -61,13 +70,12 @@ with gr.Blocks() as app:
 })
 with gr.Row():
 with gr.Column():
-form_composite_image = gr.Image(label='Input Composite image', type='pil').style(height="auto")
-gr.Examples(examples=[os.path.join("demo", i) for i in os.listdir("demo") if "composite" in i],
+form_composite_image = gr.Image(label='Input Composite image', type='pil').style(height=512)
+gr.Examples(examples=sorted([os.path.join("demo", i) for i in os.listdir("demo") if "composite" in i]),
 label="Composite Examples", inputs=form_composite_image, cache_examples=False)
 with gr.Column():
-form_mask_image = gr.Image(label='Input Mask image', type='pil', interactive=False).style(
-height="auto")
-gr.Examples(examples=[os.path.join("demo", i) for i in os.listdir("demo") if "mask" in i],
+form_mask_image = gr.Image(label='Input Mask image', type='pil', interactive=False).style(height=512)
+gr.Examples(examples=sorted([os.path.join("demo", i) for i in os.listdir("demo") if "mask" in i]),
 label="Mask Examples", inputs=form_mask_image, cache_examples=False)
 with gr.Row():
 with gr.Column(scale=4):
@@ -109,15 +117,14 @@ with gr.Blocks() as app:
 label="Split Resolution",
 )
 form_split_num = gr.Number(
-value=8,
+value=2,
 interactive=False,
 label="Split Number")
 with gr.Row():
 form_log = gr.Textbox(read_logs, label="Logs", interactive=False, type="text", every=1)
 
 with gr.Column(scale=4):
-form_harmonized_image = gr.Image(label='Harmonized Result', type='numpy', interactive=False).style(
-height="auto")
+form_harmonized_image = gr.Image(label='Harmonized Result', type='numpy', interactive=False).style(height=512)
 form_start_btn = gr.Button("Start Harmonization", interactive=False)
 form_reset_btn = gr.Button("Reset", interactive=True)
 form_stop_btn = gr.Button("Stop", interactive=True)
@@ -126,7 +133,7 @@ with gr.Blocks() as app:
 def on_change_form_composite_image(form_composite_image):
 if form_composite_image is None:
 return gr.update(interactive=False, value=None), gr.update(value=None)
-return gr.update(interactive=True), gr.update(value=None)
+return gr.update(interactive=True, value=None), gr.update(value=None)
 
 
 def on_change_form_mask_image(form_composite_image, form_mask_image):
@@ -141,15 +148,15 @@ with gr.Blocks() as app:
 w, h = form_composite_image.size[:2]
 if h != w or (h % 16 != 0):
 return gr.update(value='Arbitrary Image', interactive=False), gr.update(interactive=True), gr.update(
-interactive=True), gr.update(interactive=True), gr.update(interactive=False,
-value=-1), gr.update(value=None)
+interactive=True), gr.update(interactive=True, visible=True), gr.update(interactive=False,
+value=-1, visible=False), gr.update(value=None)
 else:
 return gr.update(value='Square Image', interactive=True), gr.update(interactive=True), gr.update(
-interactive=True), gr.update(interactive=False), gr.update(interactive=True,
-value=h // 16,
+interactive=True), gr.update(interactive=False, visible=False), gr.update(interactive=True,
+value=h // 2,
 maximum=h,
 minimum=h // 16,
-step=h // 16), gr.update(value=None)
+step=h // 16, visible=True), gr.update(value=None)
 
 
 form_composite_image.change(
@@ -185,9 +192,9 @@ with gr.Blocks() as app:
 
 def on_change_form_inference_mode(form_inference_mode):
 if form_inference_mode == "Square Image":
-return gr.update(interactive=True), gr.update(interactive=False)
+return gr.update(interactive=True, visible=True), gr.update(interactive=False, visible=False)
 else:
-return gr.update(interactive=False), gr.update(interactive=True)
+return gr.update(interactive=False, visible=False), gr.update(interactive=True, visible=True)
 
 
 form_inference_mode.change(on_change_form_inference_mode, inputs=[form_inference_mode],
@@ -197,6 +204,7 @@ with gr.Blocks() as app:
 def on_click_form_start_btn(form_composite_image, form_mask_image, form_pretrained_dropdown, form_inference_mode,
 form_split_res, form_split_num):
 log.log = io.BytesIO()
+print(f"Harmonizing image with {form_composite_image.size[1]}*{form_composite_image.size[0]}...")
 if form_inference_mode == "Square Image":
 from efficient_inference_for_square_image import parse_args, main_process, global_state
 global_state[0] = 1
@@ -287,15 +295,6 @@ with gr.Blocks() as app:
 inputs=[form_inference_mode],
 outputs=[form_log, form_composite_image, form_mask_image, form_start_btn], cancels=generate)
 
-gr.Markdown("""
-## Quick Start
-1. Select desired `Pretrained Model`.
-2. Select a composite image, and then a mask with the same size.
-3. Select the inference mode (for non-square image, only `Arbitrary Image` support).
-4. Set `Split Resolution` (Patches' resolution) or `Split Number` (How many patches, about N*N) according to the inference mode.
-3. Click `Start` and enjoy it!
-
-""")
 gr.HTML("""
 <style>
 .container {
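The new `form_log` textbox polls `read_logs` every second (`every=1`) and shows whatever the `Logger` instance has captured since `log.log = io.BytesIO()` reset it at the start of a run. The `Logger` body is not part of this diff; below is a minimal sketch, assuming it simply tees `sys.stdout` into that in-memory buffer:

```python
# Hedged sketch of the stdout-capture pattern implied by the diff; the real Logger in app.py
# is not shown here, so names and details beyond log.log / read_logs are assumptions.
import io
import sys

class Logger:
    def __init__(self):
        self.terminal = sys.stdout   # keep the real stdout so messages still reach the console
        self.log = io.BytesIO()      # in-memory buffer that the UI polls

    def write(self, message):
        self.terminal.write(message)
        self.log.write(message.encode())

    def flush(self):
        self.terminal.flush()

log = Logger()
sys.stdout = log                     # every print() now also lands in log.log

def read_logs():
    sys.stdout.flush()
    return log.log.getvalue().decode()   # value rendered by gr.Textbox(read_logs, every=1)
```

Resetting `log.log` to a fresh `io.BytesIO()` inside `on_click_form_start_btn` then clears the panel at the start of each harmonization run.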
demo/demo_1k_composite_2.jpg ADDED
demo/demo_1k_composite_3.jpg ADDED
demo/demo_1k_mask_2.jpg ADDED
demo/demo_1k_mask_3.jpg ADDED
demo/demo_composite.jpg ADDED
demo/demo_composite_1.jpg ADDED
demo/demo_composite_2.jpg ADDED
demo/demo_composite_3.jpg ADDED
demo/demo_composite_4.jpg ADDED
demo/demo_composite_5.jpg ADDED
demo/demo_composite_6.jpg ADDED
demo/demo_mask.png ADDED
demo/demo_mask_1.png ADDED
demo/demo_mask_2.png ADDED
demo/demo_mask_3.png ADDED
demo/demo_mask_4.jpg ADDED
demo/demo_mask_5.jpg ADDED
demo/demo_mask_6.jpg ADDED
efficient_inference_for_square_image.py CHANGED
@@ -284,6 +284,7 @@ def inference(model, opt, composite_image=None, mask=None):
 mask,
 fg_INR_coordinates, start_proportion[0]
 )
+print("Ready for harmonization...")
 if opt.device == "cuda":
 torch.cuda.reset_max_memory_allocated()
 torch.cuda.reset_max_memory_cached()
@@ -333,12 +334,11 @@ def inference(model, opt, composite_image=None, mask=None):
 def main_process(opt, composite_image=None, mask=None):
 cudnn.benchmark = True
 
+print("Preparing model...")
 model = build_model(opt).to(opt.device)
 
 load_dict = torch.load(opt.pretrained, map_location='cpu')['model']
-for k in load_dict.keys():
-if k not in model.state_dict().keys():
-print(f"Skip {k}")
+
 model.load_state_dict(load_dict, strict=False)
 
 return inference(model, opt, composite_image, mask)
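With the `Skip {k}` loop removed, key mismatches are no longer reported during loading. If that diagnostic is still wanted, `load_state_dict(..., strict=False)` already returns the incompatible keys; a short sketch (not part of this commit) of how the same report could be recovered:

```python
# Hedged sketch, not in the commit: with strict=False, load_state_dict returns a named tuple
# of missing_keys / unexpected_keys, which covers what the removed loop printed.
incompatible = model.load_state_dict(load_dict, strict=False)
for k in incompatible.unexpected_keys:   # checkpoint entries absent from the model
    print(f"Skip {k}")
for k in incompatible.missing_keys:      # model parameters absent from the checkpoint
    print(f"Not initialized from checkpoint: {k}")
```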
hrnet_ocr.py ADDED
@@ -0,0 +1,400 @@
1
+ import os
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ import torch._utils
7
+
8
+ from .ocr import SpatialOCR_Module, SpatialGather_Module
9
+ from .resnetv1b import BasicBlockV1b, BottleneckV1b
10
+
11
+ relu_inplace = True
12
+
13
+
14
+ class HighResolutionModule(nn.Module):
15
+ def __init__(self, num_branches, blocks, num_blocks, num_inchannels,
16
+ num_channels, fuse_method,multi_scale_output=True,
17
+ norm_layer=nn.BatchNorm2d, align_corners=True):
18
+ super(HighResolutionModule, self).__init__()
19
+ self._check_branches(num_branches, num_blocks, num_inchannels, num_channels)
20
+
21
+ self.num_inchannels = num_inchannels
22
+ self.fuse_method = fuse_method
23
+ self.num_branches = num_branches
24
+ self.norm_layer = norm_layer
25
+ self.align_corners = align_corners
26
+
27
+ self.multi_scale_output = multi_scale_output
28
+
29
+ self.branches = self._make_branches(
30
+ num_branches, blocks, num_blocks, num_channels)
31
+ self.fuse_layers = self._make_fuse_layers()
32
+ self.relu = nn.ReLU(inplace=relu_inplace)
33
+
34
+ def _check_branches(self, num_branches, num_blocks, num_inchannels, num_channels):
35
+ if num_branches != len(num_blocks):
36
+ error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format(
37
+ num_branches, len(num_blocks))
38
+ raise ValueError(error_msg)
39
+
40
+ if num_branches != len(num_channels):
41
+ error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format(
42
+ num_branches, len(num_channels))
43
+ raise ValueError(error_msg)
44
+
45
+ if num_branches != len(num_inchannels):
46
+ error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format(
47
+ num_branches, len(num_inchannels))
48
+ raise ValueError(error_msg)
49
+
50
+ def _make_one_branch(self, branch_index, block, num_blocks, num_channels,
51
+ stride=1):
52
+ downsample = None
53
+ if stride != 1 or \
54
+ self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion:
55
+ downsample = nn.Sequential(
56
+ nn.Conv2d(self.num_inchannels[branch_index],
57
+ num_channels[branch_index] * block.expansion,
58
+ kernel_size=1, stride=stride, bias=False),
59
+ self.norm_layer(num_channels[branch_index] * block.expansion),
60
+ )
61
+
62
+ layers = []
63
+ layers.append(block(self.num_inchannels[branch_index],
64
+ num_channels[branch_index], stride,
65
+ downsample=downsample, norm_layer=self.norm_layer))
66
+ self.num_inchannels[branch_index] = \
67
+ num_channels[branch_index] * block.expansion
68
+ for i in range(1, num_blocks[branch_index]):
69
+ layers.append(block(self.num_inchannels[branch_index],
70
+ num_channels[branch_index],
71
+ norm_layer=self.norm_layer))
72
+
73
+ return nn.Sequential(*layers)
74
+
75
+ def _make_branches(self, num_branches, block, num_blocks, num_channels):
76
+ branches = []
77
+
78
+ for i in range(num_branches):
79
+ branches.append(
80
+ self._make_one_branch(i, block, num_blocks, num_channels))
81
+
82
+ return nn.ModuleList(branches)
83
+
84
+ def _make_fuse_layers(self):
85
+ if self.num_branches == 1:
86
+ return None
87
+
88
+ num_branches = self.num_branches
89
+ num_inchannels = self.num_inchannels
90
+ fuse_layers = []
91
+ for i in range(num_branches if self.multi_scale_output else 1):
92
+ fuse_layer = []
93
+ for j in range(num_branches):
94
+ if j > i:
95
+ fuse_layer.append(nn.Sequential(
96
+ nn.Conv2d(in_channels=num_inchannels[j],
97
+ out_channels=num_inchannels[i],
98
+ kernel_size=1,
99
+ bias=False),
100
+ self.norm_layer(num_inchannels[i])))
101
+ elif j == i:
102
+ fuse_layer.append(None)
103
+ else:
104
+ conv3x3s = []
105
+ for k in range(i - j):
106
+ if k == i - j - 1:
107
+ num_outchannels_conv3x3 = num_inchannels[i]
108
+ conv3x3s.append(nn.Sequential(
109
+ nn.Conv2d(num_inchannels[j],
110
+ num_outchannels_conv3x3,
111
+ kernel_size=3, stride=2, padding=1, bias=False),
112
+ self.norm_layer(num_outchannels_conv3x3)))
113
+ else:
114
+ num_outchannels_conv3x3 = num_inchannels[j]
115
+ conv3x3s.append(nn.Sequential(
116
+ nn.Conv2d(num_inchannels[j],
117
+ num_outchannels_conv3x3,
118
+ kernel_size=3, stride=2, padding=1, bias=False),
119
+ self.norm_layer(num_outchannels_conv3x3),
120
+ nn.ReLU(inplace=relu_inplace)))
121
+ fuse_layer.append(nn.Sequential(*conv3x3s))
122
+ fuse_layers.append(nn.ModuleList(fuse_layer))
123
+
124
+ return nn.ModuleList(fuse_layers)
125
+
126
+ def get_num_inchannels(self):
127
+ return self.num_inchannels
128
+
129
+ def forward(self, x):
130
+ if self.num_branches == 1:
131
+ return [self.branches[0](x[0])]
132
+
133
+ for i in range(self.num_branches):
134
+ x[i] = self.branches[i](x[i])
135
+
136
+ x_fuse = []
137
+ for i in range(len(self.fuse_layers)):
138
+ y = x[0] if i == 0 else self.fuse_layers[i][0](x[0])
139
+ for j in range(1, self.num_branches):
140
+ if i == j:
141
+ y = y + x[j]
142
+ elif j > i:
143
+ width_output = x[i].shape[-1]
144
+ height_output = x[i].shape[-2]
145
+ y = y + F.interpolate(
146
+ self.fuse_layers[i][j](x[j]),
147
+ size=[height_output, width_output],
148
+ mode='bilinear', align_corners=self.align_corners)
149
+ else:
150
+ y = y + self.fuse_layers[i][j](x[j])
151
+ x_fuse.append(self.relu(y))
152
+
153
+ return x_fuse
154
+
155
+
156
+ class HighResolutionNet(nn.Module):
157
+ def __init__(self, width, num_classes, ocr_width=256, small=False,
158
+ norm_layer=nn.BatchNorm2d, align_corners=True, opt=None):
159
+ super(HighResolutionNet, self).__init__()
160
+ self.opt = opt
161
+ self.norm_layer = norm_layer
162
+ self.width = width
163
+ self.ocr_width = ocr_width
164
+ self.ocr_on = ocr_width > 0
165
+ self.align_corners = align_corners
166
+
167
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=False)
168
+ self.bn1 = norm_layer(64)
169
+ self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, bias=False)
170
+ self.bn2 = norm_layer(64)
171
+ self.relu = nn.ReLU(inplace=relu_inplace)
172
+
173
+ num_blocks = 2 if small else 4
174
+
175
+ stage1_num_channels = 64
176
+ self.layer1 = self._make_layer(BottleneckV1b, 64, stage1_num_channels, blocks=num_blocks)
177
+ stage1_out_channel = BottleneckV1b.expansion * stage1_num_channels
178
+
179
+ self.stage2_num_branches = 2
180
+ num_channels = [width, 2 * width]
181
+ num_inchannels = [
182
+ num_channels[i] * BasicBlockV1b.expansion for i in range(len(num_channels))]
183
+ self.transition1 = self._make_transition_layer(
184
+ [stage1_out_channel], num_inchannels)
185
+ self.stage2, pre_stage_channels = self._make_stage(
186
+ BasicBlockV1b, num_inchannels=num_inchannels, num_modules=1, num_branches=self.stage2_num_branches,
187
+ num_blocks=2 * [num_blocks], num_channels=num_channels)
188
+
189
+ self.stage3_num_branches = 3
190
+ num_channels = [width, 2 * width, 4 * width]
191
+ num_inchannels = [
192
+ num_channels[i] * BasicBlockV1b.expansion for i in range(len(num_channels))]
193
+ self.transition2 = self._make_transition_layer(
194
+ pre_stage_channels, num_inchannels)
195
+ self.stage3, pre_stage_channels = self._make_stage(
196
+ BasicBlockV1b, num_inchannels=num_inchannels,
197
+ num_modules=3 if small else 4, num_branches=self.stage3_num_branches,
198
+ num_blocks=3 * [num_blocks], num_channels=num_channels)
199
+
200
+ self.stage4_num_branches = 4
201
+ num_channels = [width, 2 * width, 4 * width, 8 * width]
202
+ num_inchannels = [
203
+ num_channels[i] * BasicBlockV1b.expansion for i in range(len(num_channels))]
204
+ self.transition3 = self._make_transition_layer(
205
+ pre_stage_channels, num_inchannels)
206
+ self.stage4, pre_stage_channels = self._make_stage(
207
+ BasicBlockV1b, num_inchannels=num_inchannels, num_modules=2 if small else 3,
208
+ num_branches=self.stage4_num_branches,
209
+ num_blocks=4 * [num_blocks], num_channels=num_channels)
210
+
211
+ if self.ocr_on:
212
+ last_inp_channels = int(np.sum(pre_stage_channels))
213
+ ocr_mid_channels = 2 * ocr_width
214
+ ocr_key_channels = ocr_width
215
+
216
+ self.conv3x3_ocr = nn.Sequential(
217
+ nn.Conv2d(last_inp_channels, ocr_mid_channels,
218
+ kernel_size=3, stride=1, padding=1),
219
+ norm_layer(ocr_mid_channels),
220
+ nn.ReLU(inplace=relu_inplace),
221
+ )
222
+ self.ocr_gather_head = SpatialGather_Module(num_classes)
223
+
224
+ self.ocr_distri_head = SpatialOCR_Module(in_channels=ocr_mid_channels,
225
+ key_channels=ocr_key_channels,
226
+ out_channels=ocr_mid_channels,
227
+ scale=1,
228
+ dropout=0.05,
229
+ norm_layer=norm_layer,
230
+ align_corners=align_corners, opt=opt)
231
+
232
+ def _make_transition_layer(
233
+ self, num_channels_pre_layer, num_channels_cur_layer):
234
+ num_branches_cur = len(num_channels_cur_layer)
235
+ num_branches_pre = len(num_channels_pre_layer)
236
+
237
+ transition_layers = []
238
+ for i in range(num_branches_cur):
239
+ if i < num_branches_pre:
240
+ if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
241
+ transition_layers.append(nn.Sequential(
242
+ nn.Conv2d(num_channels_pre_layer[i],
243
+ num_channels_cur_layer[i],
244
+ kernel_size=3,
245
+ stride=1,
246
+ padding=1,
247
+ bias=False),
248
+ self.norm_layer(num_channels_cur_layer[i]),
249
+ nn.ReLU(inplace=relu_inplace)))
250
+ else:
251
+ transition_layers.append(None)
252
+ else:
253
+ conv3x3s = []
254
+ for j in range(i + 1 - num_branches_pre):
255
+ inchannels = num_channels_pre_layer[-1]
256
+ outchannels = num_channels_cur_layer[i] \
257
+ if j == i - num_branches_pre else inchannels
258
+ conv3x3s.append(nn.Sequential(
259
+ nn.Conv2d(inchannels, outchannels,
260
+ kernel_size=3, stride=2, padding=1, bias=False),
261
+ self.norm_layer(outchannels),
262
+ nn.ReLU(inplace=relu_inplace)))
263
+ transition_layers.append(nn.Sequential(*conv3x3s))
264
+
265
+ return nn.ModuleList(transition_layers)
266
+
267
+ def _make_layer(self, block, inplanes, planes, blocks, stride=1):
268
+ downsample = None
269
+ if stride != 1 or inplanes != planes * block.expansion:
270
+ downsample = nn.Sequential(
271
+ nn.Conv2d(inplanes, planes * block.expansion,
272
+ kernel_size=1, stride=stride, bias=False),
273
+ self.norm_layer(planes * block.expansion),
274
+ )
275
+
276
+ layers = []
277
+ layers.append(block(inplanes, planes, stride,
278
+ downsample=downsample, norm_layer=self.norm_layer))
279
+ inplanes = planes * block.expansion
280
+ for i in range(1, blocks):
281
+ layers.append(block(inplanes, planes, norm_layer=self.norm_layer))
282
+
283
+ return nn.Sequential(*layers)
284
+
285
+ def _make_stage(self, block, num_inchannels,
286
+ num_modules, num_branches, num_blocks, num_channels,
287
+ fuse_method='SUM',
288
+ multi_scale_output=True):
289
+ modules = []
290
+ for i in range(num_modules):
291
+ # multi_scale_output is only used last module
292
+ if not multi_scale_output and i == num_modules - 1:
293
+ reset_multi_scale_output = False
294
+ else:
295
+ reset_multi_scale_output = True
296
+ modules.append(
297
+ HighResolutionModule(num_branches,
298
+ block,
299
+ num_blocks,
300
+ num_inchannels,
301
+ num_channels,
302
+ fuse_method,
303
+ reset_multi_scale_output,
304
+ norm_layer=self.norm_layer,
305
+ align_corners=self.align_corners)
306
+ )
307
+ num_inchannels = modules[-1].get_num_inchannels()
308
+
309
+ return nn.Sequential(*modules), num_inchannels
310
+
311
+ def forward(self, x, mask=None, additional_features=None):
312
+ hrnet_feats = self.compute_hrnet_feats(x, additional_features)
313
+ if not self.ocr_on:
314
+ return hrnet_feats,
315
+
316
+ ocr_feats = self.conv3x3_ocr(hrnet_feats)
317
+ mask = nn.functional.interpolate(mask, size=ocr_feats.size()[2:], mode='bilinear', align_corners=True)
318
+ context = self.ocr_gather_head(ocr_feats, mask)
319
+ ocr_feats = self.ocr_distri_head(ocr_feats, context)
320
+ return ocr_feats,
321
+
322
+ def compute_hrnet_feats(self, x, additional_features, return_list=False):
323
+ x = self.compute_pre_stage_features(x, additional_features)
324
+ x = self.layer1(x)
325
+
326
+ x_list = []
327
+ for i in range(self.stage2_num_branches):
328
+ if self.transition1[i] is not None:
329
+ x_list.append(self.transition1[i](x))
330
+ else:
331
+ x_list.append(x)
332
+ y_list = self.stage2(x_list)
333
+
334
+ x_list = []
335
+ for i in range(self.stage3_num_branches):
336
+ if self.transition2[i] is not None:
337
+ if i < self.stage2_num_branches:
338
+ x_list.append(self.transition2[i](y_list[i]))
339
+ else:
340
+ x_list.append(self.transition2[i](y_list[-1]))
341
+ else:
342
+ x_list.append(y_list[i])
343
+ y_list = self.stage3(x_list)
344
+
345
+ x_list = []
346
+ for i in range(self.stage4_num_branches):
347
+ if self.transition3[i] is not None:
348
+ if i < self.stage3_num_branches:
349
+ x_list.append(self.transition3[i](y_list[i]))
350
+ else:
351
+ x_list.append(self.transition3[i](y_list[-1]))
352
+ else:
353
+ x_list.append(y_list[i])
354
+ x = self.stage4(x_list)
355
+
356
+ if return_list:
357
+ return x
358
+
359
+ # Upsampling
360
+ x0_h, x0_w = x[0].size(2), x[0].size(3)
361
+ x1 = F.interpolate(x[1], size=(x0_h, x0_w),
362
+ mode='bilinear', align_corners=self.align_corners)
363
+ x2 = F.interpolate(x[2], size=(x0_h, x0_w),
364
+ mode='bilinear', align_corners=self.align_corners)
365
+ x3 = F.interpolate(x[3], size=(x0_h, x0_w),
366
+ mode='bilinear', align_corners=self.align_corners)
367
+
368
+ return torch.cat([x[0], x1, x2, x3], 1)
369
+
370
+ def compute_pre_stage_features(self, x, additional_features):
371
+ x = self.conv1(x)
372
+ x = self.bn1(x)
373
+ x = self.relu(x)
374
+ if additional_features is not None:
375
+ x = x + additional_features
376
+ x = self.conv2(x)
377
+ x = self.bn2(x)
378
+ return self.relu(x)
379
+
380
+ def load_pretrained_weights(self, pretrained_path=''):
381
+ model_dict = self.state_dict()
382
+
383
+ if not os.path.exists(pretrained_path):
384
+ print(f'\nFile "{pretrained_path}" does not exist.')
385
+ print('You need to specify the correct path to the pre-trained weights.\n'
386
+ 'You can download the weights for HRNet from the repository:\n'
387
+ 'https://github.com/HRNet/HRNet-Image-Classification')
388
+ exit(1)
389
+ pretrained_dict = torch.load(pretrained_path, map_location={'cuda:0': 'cpu'})
390
+ pretrained_dict = {k.replace('last_layer', 'aux_head').replace('model.', ''): v for k, v in
391
+ pretrained_dict.items()}
392
+ params_count = len(pretrained_dict)
393
+
394
+ pretrained_dict = {k: v for k, v in pretrained_dict.items()
395
+ if k in model_dict.keys()}
396
+
397
+ # print(f'Loaded {len(pretrained_dict)} of {params_count} pretrained parameters for HRNet')
398
+
399
+ model_dict.update(pretrained_dict)
400
+ self.load_state_dict(model_dict)
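`hrnet_ocr.py` adds an HRNet backbone with an optional OCR head: `ocr_width > 0` enables the head, in which case `forward` needs a `mask` for `SpatialGather_Module`, while `ocr_width=0` skips it and returns the concatenated multi-scale HRNet features. A hypothetical smoke test follows (not in the commit; `width=18` and the input size are illustrative, and the relative `.ocr` / `.resnetv1b` imports must resolve in your package layout for the import to work):

```python
# Hypothetical usage sketch; module path, width, and input size are assumptions, not repo config values.
import torch
from hrnet_ocr import HighResolutionNet  # works only if .ocr / .resnetv1b are importable siblings

net = HighResolutionNet(width=18, num_classes=1, ocr_width=0, small=True)  # ocr_width=0: no mask needed
net.eval()

x = torch.randn(1, 3, 256, 256)   # dummy RGB composite
with torch.no_grad():
    feats, = net(x)               # forward returns a one-element tuple of features
print(feats.shape)                # multi-scale features at 1/4 of the input resolution
```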
inference_for_arbitrary_resolution_image.py CHANGED
@@ -276,6 +276,8 @@ def inference(model, opt, composite_image=None, mask=None):
 mask,
 fg_INR_coordinates,
 )
+print("Ready for harmonization...")
+
 if opt.device == "cuda":
 torch.cuda.reset_max_memory_allocated()
 torch.cuda.reset_max_memory_cached()
@@ -325,12 +327,11 @@ def inference(model, opt, composite_image=None, mask=None):
 def main_process(opt, composite_image=None, mask=None):
 cudnn.benchmark = True
 
+print("Preparing model...")
 model = build_model(opt).to(opt.device)
 
 load_dict = torch.load(opt.pretrained, map_location='cpu')['model']
-for k in load_dict.keys():
-if k not in model.state_dict().keys():
-print(f"Skip {k}")
+
 model.load_state_dict(load_dict, strict=False)
 
 return inference(model, opt, composite_image, mask)
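A side note on the unchanged context above: `torch.cuda.reset_max_memory_allocated()` and `torch.cuda.reset_max_memory_cached()` are deprecated aliases on recent PyTorch releases. A hedged sketch (not part of the commit) of an equivalent peak-memory probe built on the non-deprecated API:

```python
# Sketch of a peak-GPU-memory probe; the helper name and print format are illustrative.
import contextlib
import torch

@contextlib.contextmanager
def track_peak_cuda_memory(tag="harmonization"):
    """Report the peak GPU memory allocated inside the with-block."""
    if torch.cuda.is_available():
        torch.cuda.reset_peak_memory_stats()
    try:
        yield
    finally:
        if torch.cuda.is_available():
            peak_mb = torch.cuda.max_memory_allocated() / 1024 ** 2
            print(f"[{tag}] peak CUDA memory: {peak_mb:.1f} MB")
```

Wrapping the model call in `with track_peak_cuda_memory(): ...` prints the peak figure once the block finishes.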