"""
Heart of most evaluation scripts (DAVIS semi-sup/interactive, GUI)
Handles propagation and fusion
See eval_semi_davis.py / eval_interactive_davis.py for examples
"""
import numpy as np
import torch
from lib.MiVOS_STCN.model.aggregate import aggregate_wbg
from lib.MiVOS_STCN.model.fusion_net import FusionNet
from lib.MiVOS_STCN.model.propagation.prop_net import PropagationNetwork
from lib.MiVOS_STCN.util.tensor_util import pad_divide_by


class InferenceCore:
    """
    images - Leave them at their original size (unpadded), but normalize them.
             Should be CPU tensors of shape B*T*3*H*W.

    mem_profile - How liberally GPU memory may be used.
        More memory usually means faster speed, but the exact relationship
        has not been measured.
        0 - Use the most memory
        1 - Intermediate, larger buffer
        2 - Intermediate, smaller buffer
        3 - Use the minimum amount of GPU memory
        Note that *none* of these options affects accuracy:
        they trade memory for speed, not memory for performance.

    mem_freq - Period at which new memory frames are added to the bank.
        A higher number means less memory usage, but unlike mem_profile
        this *is* a space-performance tradeoff: it can affect accuracy.

    A hypothetical usage sketch is appended at the bottom of this file.
    """

    def __init__(self, prop_net: PropagationNetwork, fuse_net: FusionNet, images, num_objects,
                 mem_profile=0, mem_freq=5, device='cuda:0'):
        self.prop_net = prop_net.to(device, non_blocking=True)
        if fuse_net is not None:
            self.fuse_net = fuse_net.to(device, non_blocking=True)
        else:
            # Allowed when fusion is never needed (e.g. a single interaction)
            self.fuse_net = None

        self.mem_profile = mem_profile
        self.mem_freq = mem_freq
        self.device = device

        if mem_profile == 0:
            self.data_dev = device
            self.result_dev = device
            self.k_buf_size = 105
            self.i_buf_size = -1  # no need to buffer the images
        elif mem_profile == 1:
            self.data_dev = 'cpu'
            self.result_dev = device
            self.k_buf_size = 105
            self.i_buf_size = 105
        elif mem_profile == 2:
            self.data_dev = 'cpu'
            self.result_dev = 'cpu'
            self.k_buf_size = 3
            self.i_buf_size = 3
        else:
            self.data_dev = 'cpu'
            self.result_dev = 'cpu'
            self.k_buf_size = 1
            self.i_buf_size = 1
        # True dimensions
        t = images.shape[1]
        h, w = images.shape[-2:]
        self.k = num_objects

        # Pad each side to a multiple of 16
        self.images, self.pad = pad_divide_by(images, 16, images.shape[-2:])
        # Padded dimensions
        nh, nw = self.images.shape[-2:]
        self.images = self.images.to(self.data_dev, non_blocking=False)

        # These two store the same information in different formats
        self.masks = torch.zeros((t, 1, nh, nw), dtype=torch.uint8, device=self.result_dev)
        self.np_masks = np.zeros((t, h, w), dtype=np.uint8)

        # Object probabilities, background included
        self.prob = torch.zeros((self.k+1, t, 1, nh, nw), dtype=torch.float32, device=self.result_dev)
        self.prob[0] = 1e-7  # tiny background bias so argmax defaults to background

        self.t, self.h, self.w = t, h, w
        self.nh, self.nw = nh, nw
        self.kh = self.nh//16
        self.kw = self.nw//16

        self.key_buf = {}
        self.image_buf = {}

        self.interacted = set()
        self.certain_mem_k = None
        self.certain_mem_v = None

    def get_image_buffered(self, idx):
        if self.data_dev == self.device:
            return self.images[:, idx]

        # Buffer the host-to-device (.cuda()) copies
        if idx not in self.image_buf:
            # Flush the buffer when it overflows
            if len(self.image_buf) > self.i_buf_size:
                self.image_buf = {}
            self.image_buf[idx] = self.images[:, idx].to(self.device)

        return self.image_buf[idx]

    def get_key_feat_buffered(self, idx):
        if idx not in self.key_buf:
            # Flush the buffer when it overflows
            if len(self.key_buf) > self.k_buf_size:
                self.key_buf = {}
            self.key_buf[idx] = self.prop_net.encode_key(self.get_image_buffered(idx))

        return self.key_buf[idx]

    def do_pass(self, key_k, key_v, idx, forward=True, step_cb=None):
        """
        Do a complete pass that includes propagation and fusion.
        key_k/key_v - memory features of the starting frame
        idx - frame index of the starting frame
        forward - forward or backward propagation
        step_cb - callback function, used only for the GUI progress bar
        """
        # Pointer into the memory bank
        num_certain_keys = self.certain_mem_k.shape[2]
        m_front = num_certain_keys

        # Determine the required size of the memory bank
        if forward:
            closest_ti = min([ti for ti in self.interacted if ti > idx] + [self.t])
            total_m = (closest_ti - idx - 1)//self.mem_freq + 1 + num_certain_keys
        else:
            closest_ti = max([ti for ti in self.interacted if ti < idx] + [-1])
            total_m = (idx - closest_ti - 1)//self.mem_freq + 1 + num_certain_keys
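        # Worked example with hypothetical numbers: a forward pass from idx=0
        # with closest_ti=60, mem_freq=5 and one certain key pre-allocates
        # total_m = (60-0-1)//5 + 1 + 1 = 13 memory slots. This is an upper
        # bound; m_front tracks how many slots are actually filled.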

        _, CK, _, H, W = key_k.shape
        K, CV, _, _, _ = key_v.shape

        # Pre-allocate the key/value memory
        keys = torch.empty((1, CK, total_m, H, W), dtype=torch.float32, device=self.device)
        values = torch.empty((K, CV, total_m, H, W), dtype=torch.float32, device=self.device)

        # Initial key/value passed in
        keys[:, :, 0:num_certain_keys] = self.certain_mem_k
        values[:, :, 0:num_certain_keys] = self.certain_mem_v

        last_ti = idx

        # Note that we never reach closest_ti, just the frame before it
        if forward:
            this_range = range(idx+1, closest_ti)
            end = closest_ti - 1
        else:
            this_range = range(idx-1, closest_ti, -1)
            end = closest_ti + 1

        for ti in this_range:
            this_k = keys[:, :, :m_front]
            this_v = values[:, :, :m_front]
            k16, qv16, qf16, qf8, qf4 = self.get_key_feat_buffered(ti)
            out_mask = self.prop_net.segment_with_query(this_k, this_v, qf8, qf4, k16, qv16)
            out_mask = aggregate_wbg(out_mask, keep_bg=True)

            if ti != end and abs(ti - last_ti) >= self.mem_freq:
                keys[:, :, m_front:m_front+1] = k16.unsqueeze(2)
                values[:, :, m_front:m_front+1] = self.prop_net.encode_value(
                    self.get_image_buffered(ti), qf16, out_mask[1:])
                m_front += 1
                last_ti = ti

            # Fuse in place; this maximizes reuse of the query buffer,
            # especially for long sequences where the buffer gets flushed
            if (closest_ti != self.t) and (closest_ti != -1):
                self.prob[:, ti] = self.fuse_one_frame(closest_ti, idx, ti, self.prob[:, ti],
                                                       out_mask, key_k, k16).to(self.result_dev)
            else:
                self.prob[:, ti] = out_mask.to(self.result_dev)

            # Callback for the GUI progress bar
            if step_cb is not None:
                step_cb()

        return closest_ti

    def fuse_one_frame(self, tc, tr, ti, prev_mask, curr_mask, mk16, qk16):
        assert tc < ti < tr or tr < ti < tc

        prob = torch.zeros((self.k, 1, self.nh, self.nw), dtype=torch.float32, device=self.device)

        # Compute linear coefficients: normalized distances of this frame
        # to the two source frames
        nc = abs(tc-ti) / abs(tc-tr)
        nr = abs(tr-ti) / abs(tc-tr)
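        # e.g. tc=20, tr=0, ti=5 gives nc = 15/20 = 0.75 and nr = 5/20 = 0.25,
        # telling the fusion net how far this frame is from each source.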
        dist = torch.FloatTensor([nc, nr]).to(self.device).unsqueeze(0)
        attn_map = self.prop_net.get_attention(mk16, self.pos_mask_diff, self.neg_mask_diff, qk16)
        for k in range(1, self.k+1):
            w = torch.sigmoid(self.fuse_net(self.get_image_buffered(ti),
                                            prev_mask[k:k+1].to(self.device),
                                            curr_mask[k:k+1].to(self.device),
                                            attn_map[k:k+1], dist))
            prob[k-1] = w
        return aggregate_wbg(prob, keep_bg=True)

    def interact(self, mask, idx, total_cb=None, step_cb=None):
        """
        Interact -> propagate -> fuse.
        mask - one-hot mask of the interacted frame, background included
        idx - frame index of the interacted frame
        total_cb, step_cb - callback functions for the GUI
        Returns all mask results in NumPy format for DAVIS evaluation.
        """
        self.interacted.add(idx)

        mask = mask.to(self.device)
        mask, _ = pad_divide_by(mask, 16, mask.shape[-2:])

        # Positive/negative parts of the change introduced by this
        # interaction; they drive the attention map used during fusion
        self.mask_diff = mask - self.prob[:, idx].to(self.device)
        self.pos_mask_diff = self.mask_diff.clamp(0, 1)
        self.neg_mask_diff = (-self.mask_diff).clamp(0, 1)

        self.prob[:, idx] = mask
        key_k, _, qf16, _, _ = self.get_key_feat_buffered(idx)
        key_k = key_k.unsqueeze(2)
        key_v = self.prop_net.encode_value(self.get_image_buffered(idx), qf16, mask[1:])

        # Interacted frames are remembered permanently ("certain" memory)
        if self.certain_mem_k is None:
            self.certain_mem_k = key_k
            self.certain_mem_v = key_v
        else:
            self.certain_mem_k = torch.cat([self.certain_mem_k, key_k], 2)
            self.certain_mem_v = torch.cat([self.certain_mem_v, key_v], 2)

        if total_cb is not None:
            # Find the total number of frames to process
            front_limit = min([ti for ti in self.interacted if ti > idx] + [self.t])
            back_limit = max([ti for ti in self.interacted if ti < idx] + [-1])
            total_num = front_limit - back_limit - 2  # -1 for shift, -1 for the center frame
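            # e.g. for the first interaction at idx=0 on a t-frame video,
            # front_limit = t and back_limit = -1, so total_num = t - 1:
            # every frame except the interacted one gets processed.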
            if total_num > 0:
                total_cb(total_num)

        self.do_pass(key_k, key_v, idx, True, step_cb=step_cb)
        self.do_pass(key_k, key_v, idx, False, step_cb=step_cb)

        # Argmax frame-by-frame; this is more memory-efficient than a
        # single argmax over the whole video
        for ti in range(self.t):
            self.masks[ti] = torch.argmax(self.prob[:, ti], dim=0)

        out_masks = self.masks

        # Trim the padding
        if self.pad[2]+self.pad[3] > 0:
            out_masks = out_masks[:, :, self.pad[2]:-self.pad[3], :]
        if self.pad[0]+self.pad[1] > 0:
            out_masks = out_masks[:, :, :, self.pad[0]:-self.pad[1]]

        self.np_masks = (out_masks.detach().cpu().numpy()[:, 0]).astype(np.uint8)

        return self.np_masks

    def update_mask_only(self, prob_mask, idx):
        """
        Interaction only, no propagation/fusion.
        prob_mask - mask of the interacted frame, background included
        idx - frame index of the interacted frame
        Returns all mask results in NumPy format for DAVIS evaluation.
        """
        mask = torch.argmax(prob_mask, 0)
        self.masks[idx] = mask

        # mask - 1 * H * W
        if self.pad[2]+self.pad[3] > 0:
            mask = mask[:, self.pad[2]:-self.pad[3], :]
        if self.pad[0]+self.pad[1] > 0:
            mask = mask[:, :, self.pad[0]:-self.pad[1]]

        mask = (mask.detach().cpu().numpy()[0]).astype(np.uint8)
        self.np_masks[idx] = mask

        return self.np_masks
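

if __name__ == '__main__':
    # Minimal usage sketch, not part of the original evaluation pipeline.
    # Assumptions: the checkpoint paths below are hypothetical, the
    # PropagationNetwork/FusionNet constructors may take extra arguments
    # depending on the checkpoint, and random tensors stand in for real
    # normalized frames. See eval_semi_davis.py / eval_interactive_davis.py
    # for the actual entry points.
    torch.set_grad_enabled(False)

    prop_net = PropagationNetwork().cuda().eval()
    prop_net.load_state_dict(torch.load('saves/propagation_model.pth'))  # hypothetical path
    fuse_net = FusionNet().cuda().eval()
    fuse_net.load_state_dict(torch.load('saves/fusion.pth'))  # hypothetical path

    # B*T*3*H*W normalized CPU frames; H and W need not be multiples of 16
    # since InferenceCore pads them itself
    images = torch.randn(1, 10, 3, 240, 432)
    num_objects = 2

    processor = InferenceCore(prop_net, fuse_net, images, num_objects,
                              mem_profile=0, mem_freq=5)

    # One-hot mask of the interacted frame, background included:
    # (num_objects+1) * 1 * H * W. Here every pixel is labeled background.
    mask = torch.zeros(num_objects+1, 1, 240, 432)
    mask[0] = 1
    np_masks = processor.interact(mask, idx=0)
    print(np_masks.shape)  # (T, H, W) uint8 object indices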