""" | |
Heart of most evaluation scripts (DAVIS semi-sup/interactive, GUI) | |
Handles propagation and fusion | |
See eval_semi_davis.py / eval_interactive_davis.py for examples | |
""" | |
import numpy as np | |
import torch | |
from lib.MiVOS_STCN.model.aggregate import aggregate_wbg | |
from lib.MiVOS_STCN.model.fusion_net import FusionNet | |
from lib.MiVOS_STCN.model.propagation.prop_net import PropagationNetwork | |
from lib.MiVOS_STCN.util.tensor_util import pad_divide_by | |
class InferenceCore: | |
""" | |
images - leave them in original dimension (unpadded), but do normalize them. | |
Should be CPU tensors of shape B*T*3*H*W | |
mem_profile - How extravagant I can use the GPU memory. | |
Usually more memory -> faster speed but I have not drawn the exact relation | |
0 - Use the most memory | |
1 - Intermediate, larger buffer | |
2 - Intermediate, small buffer | |
3 - Use the minimal amount of GPU memory | |
Note that *none* of the above options will affect the accuracy | |
This is a space-time tradeoff, not a space-performance one | |
mem_freq - Period at which new memory are put in the bank | |
Higher number -> less memory usage | |
Unlike the last option, this *is* a space-performance tradeoff | |
""" | |
    def __init__(self, prop_net: PropagationNetwork, fuse_net: FusionNet, images, num_objects,
                 mem_profile=0, mem_freq=5, device='cuda:0'):
        self.prop_net = prop_net.to(device, non_blocking=True)
        if fuse_net is not None:
            self.fuse_net = fuse_net.to(device, non_blocking=True)
        self.mem_profile = mem_profile
        self.mem_freq = mem_freq
        self.device = device

        if mem_profile == 0:
            self.data_dev = device
            self.result_dev = device
            self.k_buf_size = 105
            self.i_buf_size = -1  # no need to buffer image
        elif mem_profile == 1:
            self.data_dev = 'cpu'
            self.result_dev = device
            self.k_buf_size = 105
            self.i_buf_size = 105
        elif mem_profile == 2:
            self.data_dev = 'cpu'
            self.result_dev = 'cpu'
            self.k_buf_size = 3
            self.i_buf_size = 3
        else:
            self.data_dev = 'cpu'
            self.result_dev = 'cpu'
            self.k_buf_size = 1
            self.i_buf_size = 1

        # True dimensions
        t = images.shape[1]
        h, w = images.shape[-2:]
        self.k = num_objects

        # Pad each side to multiples of 16
        self.images, self.pad = pad_divide_by(images, 16, images.shape[-2:])
        # Padded dimensions
        nh, nw = self.images.shape[-2:]
        self.images = self.images.to(self.data_dev, non_blocking=False)

        # These two store the same information in different formats
        self.masks = torch.zeros((t, 1, nh, nw), dtype=torch.uint8, device=self.result_dev)
        self.np_masks = np.zeros((t, h, w), dtype=np.uint8)

        # Object probabilities, background included
        self.prob = torch.zeros((self.k+1, t, 1, nh, nw), dtype=torch.float32, device=self.result_dev)
        self.prob[0] = 1e-7

        self.t, self.h, self.w = t, h, w
        self.nh, self.nw = nh, nw
        self.kh = self.nh//16
        self.kw = self.nw//16

        self.key_buf = {}
        self.image_buf = {}
        self.interacted = set()

        self.certain_mem_k = None
        self.certain_mem_v = None

    def get_image_buffered(self, idx):
        if self.data_dev == self.device:
            return self.images[:, idx]

        # buffer the .cuda() calls
        if idx not in self.image_buf:
            # Flush buffer
            if len(self.image_buf) > self.i_buf_size:
                self.image_buf = {}
            self.image_buf[idx] = self.images[:, idx].to(self.device)
        result = self.image_buf[idx]

        return result

    def get_key_feat_buffered(self, idx):
        if idx not in self.key_buf:
            # Flush buffer
            if len(self.key_buf) > self.k_buf_size:
                self.key_buf = {}
            self.key_buf[idx] = self.prop_net.encode_key(self.get_image_buffered(idx))
        result = self.key_buf[idx]

        return result

    def do_pass(self, key_k, key_v, idx, forward=True, step_cb=None):
        """
        Do a complete pass that includes propagation and fusion
        key_k/key_v - memory feature of the starting frame
        idx - Frame index of the starting frame
        forward - forward/backward propagation
        step_cb - Callback function used for GUI (progress bar) only
        """
        # Pointer in the memory bank
        num_certain_keys = self.certain_mem_k.shape[2]
        m_front = num_certain_keys

        # Determine the required size of the memory bank
        if forward:
            closest_ti = min([ti for ti in self.interacted if ti > idx] + [self.t])
            total_m = (closest_ti - idx - 1)//self.mem_freq + 1 + num_certain_keys
        else:
            closest_ti = max([ti for ti in self.interacted if ti < idx] + [-1])
            total_m = (idx - closest_ti - 1)//self.mem_freq + 1 + num_certain_keys
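        # For example, a forward pass from idx=0 over a 50-frame clip with no later
        # interacted frames (closest_ti = 50) and mem_freq=5 needs
        # (50-0-1)//5 + 1 = 10 propagated memory slots plus the certain (interacted) ones.
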
        _, CK, _, H, W = key_k.shape
        K, CV, _, _, _ = key_v.shape

        # Pre-allocate keys/values memory
        keys = torch.empty((1, CK, total_m, H, W), dtype=torch.float32, device=self.device)
        values = torch.empty((K, CV, total_m, H, W), dtype=torch.float32, device=self.device)

        # Initial key/value passed in
        keys[:, :, 0:num_certain_keys] = self.certain_mem_k
        values[:, :, 0:num_certain_keys] = self.certain_mem_v

        last_ti = idx

        # Note that we never reach closest_ti, just the frame before it
        if forward:
            this_range = range(idx+1, closest_ti)
            end = closest_ti - 1
        else:
            this_range = range(idx-1, closest_ti, -1)
            end = closest_ti + 1

        for ti in this_range:
            this_k = keys[:, :, :m_front]
            this_v = values[:, :, :m_front]
            k16, qv16, qf16, qf8, qf4 = self.get_key_feat_buffered(ti)
            out_mask = self.prop_net.segment_with_query(this_k, this_v, qf8, qf4, k16, qv16)
            out_mask = aggregate_wbg(out_mask, keep_bg=True)

            if ti != end and abs(ti - last_ti) >= self.mem_freq:
                keys[:, :, m_front:m_front+1] = k16.unsqueeze(2)
                values[:, :, m_front:m_front+1] = self.prop_net.encode_value(
                    self.get_image_buffered(ti), qf16, out_mask[1:])
                m_front += 1
                last_ti = ti

            # In-place fusion, maximizes the use of the queried buffer,
            # especially for long sequences where the buffer will be flushed
            if (closest_ti != self.t) and (closest_ti != -1):
                self.prob[:, ti] = self.fuse_one_frame(closest_ti, idx, ti, self.prob[:, ti], out_mask,
                                                       key_k, k16).to(self.result_dev)
            else:
                self.prob[:, ti] = out_mask.to(self.result_dev)

            # Callback function for the GUI
            if step_cb is not None:
                step_cb()

        return closest_ti

    def fuse_one_frame(self, tc, tr, ti, prev_mask, curr_mask, mk16, qk16):
        assert(tc < ti < tr or tr < ti < tc)

        prob = torch.zeros((self.k, 1, self.nh, self.nw), dtype=torch.float32, device=self.device)

        # Compute linear coefficients
        nc = abs(tc-ti) / abs(tc-tr)
        nr = abs(tr-ti) / abs(tc-tr)
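        # e.g. with interacted frames tc=10 and tr=0 and the current frame ti=3,
        # nc = |10-3|/10 = 0.7 and nr = |0-3|/10 = 0.3; each coefficient is the
        # normalized temporal distance from ti to that interacted frame.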
        dist = torch.FloatTensor([nc, nr]).to(self.device).unsqueeze(0)

        attn_map = self.prop_net.get_attention(mk16, self.pos_mask_diff, self.neg_mask_diff, qk16)

        for k in range(1, self.k+1):
            w = torch.sigmoid(self.fuse_net(self.get_image_buffered(ti),
                              prev_mask[k:k+1].to(self.device), curr_mask[k:k+1].to(self.device),
                              attn_map[k:k+1], dist))
            prob[k-1] = w

        return aggregate_wbg(prob, keep_bg=True)

    def interact(self, mask, idx, total_cb=None, step_cb=None):
        """
        Interact -> Propagate -> Fuse

        mask - One-hot mask of the interacted frame, background included
        idx - Frame index of the interacted frame
        total_cb, step_cb - Callback functions for the GUI

        Return: all mask results in np format for DAVIS evaluation
        """
        self.interacted.add(idx)

        mask = mask.to(self.device)
        mask, _ = pad_divide_by(mask, 16, mask.shape[-2:])

        self.mask_diff = mask - self.prob[:, idx].to(self.device)
        self.pos_mask_diff = self.mask_diff.clamp(0, 1)
        self.neg_mask_diff = (-self.mask_diff).clamp(0, 1)

        self.prob[:, idx] = mask
        key_k, _, qf16, _, _ = self.get_key_feat_buffered(idx)
        key_k = key_k.unsqueeze(2)
        key_v = self.prop_net.encode_value(self.get_image_buffered(idx), qf16, mask[1:])

        if self.certain_mem_k is None:
            self.certain_mem_k = key_k
            self.certain_mem_v = key_v
        else:
            self.certain_mem_k = torch.cat([self.certain_mem_k, key_k], 2)
            self.certain_mem_v = torch.cat([self.certain_mem_v, key_v], 2)

        if total_cb is not None:
            # Find the total number of frames to process
            front_limit = min([ti for ti in self.interacted if ti > idx] + [self.t])
            back_limit = max([ti for ti in self.interacted if ti < idx] + [-1])
            total_num = front_limit - back_limit - 2  # -1 for shift, -1 for center frame
            if total_num > 0:
                total_cb(total_num)

        self.do_pass(key_k, key_v, idx, True, step_cb=step_cb)
        self.do_pass(key_k, key_v, idx, False, step_cb=step_cb)

        # This is a more memory-efficient argmax
        for ti in range(self.t):
            self.masks[ti] = torch.argmax(self.prob[:, ti], dim=0)
        out_masks = self.masks

        # Trim padding
        if self.pad[2]+self.pad[3] > 0:
            out_masks = out_masks[:, :, self.pad[2]:-self.pad[3], :]
        if self.pad[0]+self.pad[1] > 0:
            out_masks = out_masks[:, :, :, self.pad[0]:-self.pad[1]]

        self.np_masks = (out_masks.detach().cpu().numpy()[:, 0]).astype(np.uint8)

        return self.np_masks

    def update_mask_only(self, prob_mask, idx):
        """
        Interaction only, no propagation/fusion
        prob_mask - mask of the interacted frame, background included
        idx - Frame index of the interacted frame

        Return: all mask results in np format for DAVIS evaluation
        """
        mask = torch.argmax(prob_mask, 0)
        self.masks[idx] = mask

        # Mask - 1 * H * W
        if self.pad[2]+self.pad[3] > 0:
            mask = mask[:, self.pad[2]:-self.pad[3], :]
        if self.pad[0]+self.pad[1] > 0:
            mask = mask[:, :, self.pad[0]:-self.pad[1]]

        mask = (mask.detach().cpu().numpy()[0]).astype(np.uint8)
        self.np_masks[idx] = mask

        return self.np_masks
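

# Example usage (sketch only; the checkpoint paths and the network constructor
# arguments below are assumptions -- see eval_interactive_davis.py for the exact
# loading code):
#
#   prop_net = PropagationNetwork().cuda().eval()
#   prop_net.load_state_dict(torch.load('path/to/propagation_model.pth'))
#   fuse_net = FusionNet().cuda().eval()
#   fuse_net.load_state_dict(torch.load('path/to/fusion_model.pth'))
#
#   # `frames`: normalized 1*T*3*H*W CPU tensor; `first_mask`: (k+1)*1*H*W one-hot mask
#   core = InferenceCore(prop_net, fuse_net, frames, num_objects=k, mem_profile=2)
#   np_masks = core.interact(first_mask, idx=0)   # T*H*W uint8 label maps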