Shane922 committed
Commit 1532f4e
Parent(s): 9ee2acc

Limit video inference memory usage

video_depth_anything/video_depth.py CHANGED
@@ -65,6 +65,12 @@ class VideoDepthAnything(nn.Module):
        return depth.squeeze(1).unflatten(0, (B, T)) # return shape [B, T, H, W]

    def infer_video_depth(self, frames, target_fps, input_size=518, device='cuda'):
+        frame_height, frame_width = frames[0].shape[:2]
+        ratio = max(frame_height, frame_width) / min(frame_height, frame_width)
+        if ratio > 1.78:
+            input_size = int(input_size * 1.78 / ratio)
+            input_size = round(input_size / 14) * 14
+
        transform = Compose([
            Resize(
                width=input_size,
@@ -79,7 +85,6 @@ class VideoDepthAnything(nn.Module):
            PrepareForNet(),
        ])

-        frame_size = frames[0].shape[:2]
        frame_list = [frames[i] for i in range(frames.shape[0])]
        frame_step = INFER_LEN - OVERLAP
        org_video_len = len(frame_list)
@@ -99,7 +104,7 @@ class VideoDepthAnything(nn.Module):
            with torch.no_grad():
                depth = self.forward(cur_input) # depth shape: [1, T, H, W]

-            depth = F.interpolate(depth.flatten(0,1).unsqueeze(1), size=frame_size, mode='bilinear', align_corners=True)
+            depth = F.interpolate(depth.flatten(0,1).unsqueeze(1), size=(frame_height, frame_width), mode='bilinear', align_corners=True)
            depth_list += [depth[i][0].cpu().numpy() for i in range(depth.shape[0])]

            pre_input = cur_input
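
The change shrinks input_size for frames wider (or taller) than 16:9 so the model input stays bounded, then restores the original resolution afterwards via the F.interpolate back to (frame_height, frame_width). A minimal, self-contained sketch of the arithmetic (the helper name cap_input_size is hypothetical, not part of the commit; the 518 default, the 1.78 threshold, and the snap to a multiple of 14 are taken from the diff):

def cap_input_size(frame_height, frame_width, input_size=518):
    # Aspect ratio of the frame, always >= 1.
    ratio = max(frame_height, frame_width) / min(frame_height, frame_width)
    if ratio > 1.78:
        # Scale the inference resolution down for very wide/tall frames,
        # then snap to a multiple of 14 (presumably the backbone's patch size).
        input_size = int(input_size * 1.78 / ratio)
        input_size = round(input_size / 14) * 14
    return input_size

print(cap_input_size(1080, 1920))  # 16:9 -> ratio ~1.778, under the threshold: stays 518
print(cap_input_size(1080, 3840))  # 32:9 -> ratio ~3.556: int(518 * 1.78 / ratio) = 259, snapped to 252

Since the resize keeps aspect ratio (the longer side scales with ratio) and activation memory grows with the number of input tokens, capping input_size in inverse proportion to ratio should keep per-clip memory at roughly the 16:9 level or below, which matches the commit's stated goal.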