zzl commited on
Commit
2bbc3ee
·
1 Parent(s): 9fd9429
Files changed (4) hide show
  1. app.py +6 -7
  2. demo_img.py +34 -10
  3. demo_vid.py +20 -2
  4. utils.py +18 -0
app.py CHANGED
@@ -14,19 +14,18 @@ with gr.Blocks(css='style.css') as demo:
14
  <h2 style="font-weight: 450; font-size: 1rem; margin: 0rem">
15
  <a href="https://paper99.github.io" style="color:blue;">Zhen Li</a><sup>1*</sup>,
16
  <a href="https://github.com/NK-CS-ZZL" style="color:blue;">Zuo-Liang Zhu</a><sup>1*</sup>,
17
- <a href="https://github.com/hlh981029" style="color:blue;">Ling-Hao Han</a><sup>1*</sup>,
18
- <a href="https://houqb.github.io" style="color:blue;">Qibin Hou</a><sup>1*</sup>,
19
- <a href="https://github.com" style="color:blue;">Chun-Le Guo</a><sup>1*</sup>,
20
- <a href="https://mmcheng.net" style="color:blue;">Ming-Ming Cheng</a><sup>1*</sup>,
21
  </h2>
22
  <h2 style="font-weight: 450; font-size: 1rem; margin: 0rem">
23
- <sup>1</sup>Nankai University <sup>*</sup> represents the equal contribution and <sup>#</sup> represents the corresponding author.
24
  </h2>
25
  <h2 style="font-weight: 450; font-size: 1rem; margin: 0rem">
26
  [<a href="https://arxiv.org/abs/2303.13439" style="color:blue;">arXiv</a>]
27
  [<a href="https://github.com/MCG-NKU/AMT" style="color:blue;">GitHub</a>]
28
- [<a href="https://github.com/MCG-NKU/AMT" style="color:blue;">Colab</a>]
29
- [<a href="https://github.com/MCG-NKU/AMT" style="color:blue;">Replicate</a>]
30
  </h2>
31
  <p>For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings.
32
  """
 
14
  <h2 style="font-weight: 450; font-size: 1rem; margin: 0rem">
15
  <a href="https://paper99.github.io" style="color:blue;">Zhen Li</a><sup>1*</sup>,
16
  <a href="https://github.com/NK-CS-ZZL" style="color:blue;">Zuo-Liang Zhu</a><sup>1*</sup>,
17
+ <a href="https://github.com/hlh981029" style="color:blue;">Ling-Hao Han</a><sup>1</sup>,
18
+ <a href="https://houqb.github.io" style="color:blue;">Qibin Hou</a><sup>1</sup>,
19
+ <a href="https://github.com" style="color:blue;">Chun-Le Guo</a><sup>1</sup>,
20
+ <a href="https://mmcheng.net" style="color:blue;">Ming-Ming Cheng</a><sup>1</sup>,
21
  </h2>
22
  <h2 style="font-weight: 450; font-size: 1rem; margin: 0rem">
23
+ <sup>1</sup>Nankai University <sup>*</sup> represents the equal contribution.
24
  </h2>
25
  <h2 style="font-weight: 450; font-size: 1rem; margin: 0rem">
26
  [<a href="https://arxiv.org/abs/2303.13439" style="color:blue;">arXiv</a>]
27
  [<a href="https://github.com/MCG-NKU/AMT" style="color:blue;">GitHub</a>]
28
+ [<a href="https://colab.research.google.com/drive/1IeVO5BmLouhRh6fL2z_y18kgubotoaBq?usp=sharing" style="color:blue;">Colab</a>]
 
29
  </h2>
30
  <p>For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings.
31
  """
demo_img.py CHANGED
@@ -7,8 +7,11 @@ from huggingface_hub import hf_hub_download
7
  from networks.amts import Model as AMTS
8
  from networks.amtl import Model as AMTL
9
  from networks.amtg import Model as AMTG
10
- from utils import img2tensor, tensor2img, InputPadder
11
-
 
 
 
12
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
13
  model_dict = {
14
  'AMT-S': AMTS, 'AMT-L': AMTL, 'AMT-G': AMTG
@@ -23,22 +26,43 @@ def img2vid(model_type, img0, img1, frame_ratio, iters):
23
  model.eval()
24
  img0_t = img2tensor(img0).to(device)
25
  img1_t = img2tensor(img1).to(device)
26
- padder = InputPadder(img0_t.shape, 16)
27
- img0_t, img1_t = padder.pad(img0_t, img1_t)
28
  inputs = [img0_t, img1_t]
 
 
 
 
 
 
 
 
 
 
 
 
29
  embt = torch.tensor(1/2).float().view(1, 1, 1, 1).to(device)
30
 
 
 
 
 
 
 
 
 
 
 
 
31
  for i in range(iters):
32
  print(f'Iter {i+1}. input_frames={len(inputs)} output_frames={2*len(inputs)-1}')
33
- outputs = [img0_t]
34
  for in_0, in_1 in zip(inputs[:-1], inputs[1:]):
 
 
35
  with torch.no_grad():
36
- imgt_pred = model(in_0, in_1, embt, eval=True)['imgt_pred']
37
- imgt_pred = padder.unpad(imgt_pred)
38
- in_1 = padder.unpad(in_1)
39
- outputs += [imgt_pred, in_1]
40
  inputs = outputs
41
-
42
  out_path = 'results'
43
  size = outputs[0].shape[2:][::-1]
44
  writer = cv2.VideoWriter(f'{out_path}/demo.mp4', cv2.VideoWriter_fourcc(*'mp4v'), frame_ratio, size)
 
7
  from networks.amts import Model as AMTS
8
  from networks.amtl import Model as AMTL
9
  from networks.amtg import Model as AMTG
10
+ from utils import (
11
+ img2tensor, tensor2img,
12
+ InputPadder,
13
+ check_dim_and_resize
14
+ )
15
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
16
  model_dict = {
17
  'AMT-S': AMTS, 'AMT-L': AMTL, 'AMT-G': AMTG
 
26
  model.eval()
27
  img0_t = img2tensor(img0).to(device)
28
  img1_t = img2tensor(img1).to(device)
 
 
29
  inputs = [img0_t, img1_t]
30
+
31
+ if device == 'cpu':
32
+ # Do not resize in cpu mode
33
+ anchor_resolution = 8192*8192
34
+ anchor_memory = 1
35
+ anchor_memory_bias = 0
36
+ vram_avail = 1
37
+ elif device == 'cuda':
38
+ anchor_resolution = 1024 * 512
39
+ anchor_memory = 1500 * 1024**2
40
+ anchor_memory_bias = 2500 * 1024**2
41
+ vram_avail = torch.cuda.get_device_properties(device).total_memory
42
  embt = torch.tensor(1/2).float().view(1, 1, 1, 1).to(device)
43
 
44
+ inputs = check_dim_and_resize(inputs)
45
+ h, w = inputs[0].shape[-2:]
46
+ scale = anchor_resolution / (h * w) * np.sqrt((vram_avail - anchor_memory_bias) / anchor_memory)
47
+ scale = 1 if scale > 1 else scale
48
+ scale = 1 / np.floor(1 / np.sqrt(scale) * 16) * 16
49
+ if scale < 1:
50
+ print(f"Due to the limited VRAM, the video will be scaled by {scale:.2f}")
51
+ padding = int(16 / scale)
52
+ padder = InputPadder(inputs[0].shape, padding)
53
+ inputs = padder.pad(*inputs)
54
+
55
  for i in range(iters):
56
  print(f'Iter {i+1}. input_frames={len(inputs)} output_frames={2*len(inputs)-1}')
57
+ outputs = [inputs[0]]
58
  for in_0, in_1 in zip(inputs[:-1], inputs[1:]):
59
+ in_0 = in_0.to(device)
60
+ in_1 = in_1.to(device)
61
  with torch.no_grad():
62
+ imgt_pred = model(in_0, in_1, embt, scale_factor=scale, eval=True)['imgt_pred']
63
+ outputs += [imgt_pred.cpu(), in_1.cpu()]
 
 
64
  inputs = outputs
65
+ outputs = padder.unpad(*outputs)
66
  out_path = 'results'
67
  size = outputs[0].shape[2:][::-1]
68
  writer = cv2.VideoWriter(f'{out_path}/demo.mp4', cv2.VideoWriter_fourcc(*'mp4v'), frame_ratio, size)
demo_vid.py CHANGED
@@ -27,7 +27,25 @@ def vid2vid(model_type, video, iters):
27
  inputs = []
28
  h = int(vcap.get(cv2.CAP_PROP_FRAME_WIDTH))
29
  w = int(vcap.get(cv2.CAP_PROP_FRAME_HEIGHT))
30
- padder = InputPadder((h, w), 16)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  while True:
32
  ret, frame = vcap.read()
33
  if ret is False:
@@ -43,7 +61,7 @@ def vid2vid(model_type, video, iters):
43
  outputs = [inputs[0]]
44
  for in_0, in_1 in zip(inputs[:-1], inputs[1:]):
45
  with torch.no_grad():
46
- imgt_pred = model(in_0, in_1, embt, eval=True)['imgt_pred']
47
  imgt_pred = padder.unpad(imgt_pred)
48
  in_1 = padder.unpad(in_1)
49
  outputs += [imgt_pred, in_1]
 
27
  inputs = []
28
  h = int(vcap.get(cv2.CAP_PROP_FRAME_WIDTH))
29
  w = int(vcap.get(cv2.CAP_PROP_FRAME_HEIGHT))
30
+ if device == 'cpu':
31
+ # Do not resize in cpu mode
32
+ anchor_resolution = 8192*8192
33
+ anchor_memory = 1
34
+ anchor_memory_bias = 0
35
+ vram_avail = 1
36
+ elif device == 'cuda':
37
+ anchor_resolution = 1024 * 512
38
+ anchor_memory = 1500 * 1024**2
39
+ anchor_memory_bias = 2500 * 1024**2
40
+ vram_avail = torch.cuda.get_device_properties(device).total_memory
41
+
42
+ scale = anchor_resolution / (h * w) * np.sqrt((vram_avail - anchor_memory_bias) / anchor_memory)
43
+ scale = 1 if scale > 1 else scale
44
+ scale = 1 / np.floor(1 / np.sqrt(scale) * 16) * 16
45
+ if scale < 1:
46
+ print(f"Due to the limited VRAM, the video will be scaled by {scale:.2f}")
47
+ padding = int(16 / scale)
48
+ padder = InputPadder(inputs[0].shape, padding)
49
  while True:
50
  ret, frame = vcap.read()
51
  if ret is False:
 
61
  outputs = [inputs[0]]
62
  for in_0, in_1 in zip(inputs[:-1], inputs[1:]):
63
  with torch.no_grad():
64
+ imgt_pred = model(in_0, in_1, embt, scale_factor=scale, eval=True)['imgt_pred']
65
  imgt_pred = padder.unpad(imgt_pred)
66
  in_1 = padder.unpad(in_1)
67
  outputs += [imgt_pred, in_1]
utils.py CHANGED
@@ -227,3 +227,21 @@ def warp(img, flow):
227
  grid_ = (grid + flow_).permute(0, 2, 3, 1)
228
  output = F.grid_sample(input=img, grid=grid_, mode='bilinear', padding_mode='border', align_corners=True)
229
  return output
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
227
  grid_ = (grid + flow_).permute(0, 2, 3, 1)
228
  output = F.grid_sample(input=img, grid=grid_, mode='bilinear', padding_mode='border', align_corners=True)
229
  return output
230
+
231
+ def check_dim_and_resize(tensor_list):
232
+ shape_list = []
233
+ for t in tensor_list:
234
+ shape_list.append(t.shape[2:])
235
+
236
+ if len(set(shape_list)) > 1:
237
+ desired_shape = shape_list[0]
238
+ print(f'Inconsistent size of input video frames. All frames will be resized to {desired_shape}')
239
+
240
+ resize_tensor_list = []
241
+ for t in tensor_list:
242
+ resize_tensor_list.append(torch.nn.functional.interpolate(t, size=tuple(desired_shape), mode='bilinear'))
243
+
244
+ tensor_list = resize_tensor_list
245
+
246
+ return tensor_list
247
+