|
import numpy as np |
|
|
|
from utils.image import one_hot_mask |
|
|
|
from networks.layers.basic import seq_to_2d |
|
from networks.engines.aot_engine import AOTEngine, AOTInferEngine |
|
|
|
|
|
class DeAOTEngine(AOTEngine):
    """``AOTEngine`` variant whose short-term memory fuses identity embeddings per layer.

    Each LSTT layer's identity key/value is produced by that layer's
    ``fuse_key_value_id`` before being reshaped with ``seq_to_2d`` and stored.
    """

    def __init__(self,
                 aot_model,
                 gpu_id=0,
                 long_term_mem_gap=9999,
                 short_term_mem_skip=1,
                 layer_loss_scaling_ratio=2.,
                 max_len_long_term=9999):
        """Build the engine.

        Args:
            aot_model: model forwarded to ``AOTEngine``.
            gpu_id: forwarded to ``AOTEngine``.
            long_term_mem_gap: forwarded to ``AOTEngine``; presumably exposed
                as ``self.long_term_mem_gap``, which gates long-term updates
                in ``update_short_term_memory``.
            short_term_mem_skip: forwarded to ``AOTEngine``; presumably exposed
                as ``self.short_term_mem_skip``, the number of recent
                short-term entries retained.
            layer_loss_scaling_ratio: kept as ``self.layer_loss_scaling_ratio``;
                not read in this class (presumably used elsewhere — verify).
            max_len_long_term: forwarded to ``AOTEngine``.
        """
        super().__init__(aot_model,
                         gpu_id,
                         long_term_mem_gap,
                         short_term_mem_skip,
                         max_len_long_term)
        self.layer_loss_scaling_ratio = layer_loss_scaling_ratio

    def update_short_term_memory(self, curr_mask, curr_id_emb=None, skip_long_term_update=False):
        """Store the latest LSTT output as short-term memory, and maybe long-term.

        Args:
            curr_mask: mask of the current frame. Only consulted when
                ``curr_id_emb`` is None, to derive identity embeddings through
                ``self.assign_identity``.
            curr_id_emb: precomputed identity embedding; when given,
                ``curr_mask`` is ignored.
            skip_long_term_update: if True, never call
                ``update_long_term_memory`` here. ``last_mem_step`` is still
                advanced once the gap is reached.

        Side effects:
            * overwrites items 2 and 3 (identity key/value) of every per-layer
              entry in ``self.curr_lstt_output[1]`` with the fused tensors;
            * appends to and trims ``self.short_term_memories_list`` and
              resets ``self.short_term_memories`` to its first element;
            * may call ``self.update_long_term_memory`` and update
              ``self.last_mem_step``.
        """
        if curr_id_emb is None:
            dims = curr_mask.size()
            # NOTE(review): a 3-D mask, or a mask whose FIRST dim is 1, is
            # treated as a label mask and one-hot encoded; everything else is
            # assumed to be one-hot/probabilities already. This keys on the
            # batch dim, not the channel dim — confirm it matches all callers.
            is_label_mask = len(dims) == 3 or dims[0] == 1
            if is_label_mask:
                encoded = one_hot_mask(curr_mask, self.max_obj_num)
            else:
                encoded = curr_mask
            curr_id_emb = self.assign_identity(encoded)

        lstt_curr_memories = self.curr_lstt_output[1]
        size_2d = self.enc_size_2d
        layers_2d = []
        for idx, layer_entry in enumerate(lstt_curr_memories):
            key, value, id_key, id_value = layer_entry
            fuse = self.AOT.LSTT.layers[idx].fuse_key_value_id
            id_key, id_value = fuse(id_key, id_value, curr_id_emb)
            # Write the fused identity tensors back in place so the long-term
            # update below receives them, not the pre-fusion ones.
            layer_entry[2] = id_key
            layer_entry[3] = id_value

            # The identity key may be absent; keep None in that slot.
            id_key_2d = None if id_key is None else seq_to_2d(id_key, size_2d)
            id_value_2d = seq_to_2d(id_value, size_2d)
            layers_2d.append([
                seq_to_2d(key, size_2d),
                seq_to_2d(value, size_2d),
                id_key_2d,
                id_value_2d,
            ])

        # Keep only the newest `short_term_mem_skip` entries; the oldest one
        # retained becomes the active short-term memory.
        history = self.short_term_memories_list
        history.append(layers_2d)
        self.short_term_memories_list = history[-self.short_term_mem_skip:]
        self.short_term_memories = self.short_term_memories_list[0]

        if self.frame_step - self.last_mem_step < self.long_term_mem_gap:
            return
        if not skip_long_term_update:
            self.update_long_term_memory(lstt_curr_memories)
        self.last_mem_step = self.frame_step
|
|
|
|
|
class DeAOTInferEngine(AOTInferEngine):
    """Inference wrapper that drives one ``DeAOTEngine`` per object group.

    The number of sub-engines is ``max(ceil(obj_nums / self.max_aot_obj_num), 1)``;
    how objects are assigned to groups is decided by ``separate_mask``.
    """

    def __init__(self,
                 aot_model,
                 gpu_id=0,
                 long_term_mem_gap=9999,
                 short_term_mem_skip=1,
                 max_aot_obj_num=None,
                 max_len_long_term=9999):
        """Forward every argument, in order, to ``AOTInferEngine``."""
        super().__init__(aot_model,
                         gpu_id,
                         long_term_mem_gap,
                         short_term_mem_skip,
                         max_aot_obj_num,
                         max_len_long_term)

    def add_reference_frame(self, img, mask, obj_nums, frame_step=-1):
        """Register ``img``/``mask`` as a reference, growing the engine pool if needed.

        Args:
            img: current frame, passed to sub-engines that take a full
                reference update.
            mask: mask for ``img``; split into per-engine masks and object
                counts by ``self.separate_mask``.
            obj_nums: total object count, or a list whose first element is it.
            frame_step: forwarded to sub-engines' ``add_reference_frame``.

        A sub-engine with no recorded object count, or one whose count is
        smaller than its new share, gets ``add_reference_frame``; every other
        sub-engine only gets ``update_short_term_memory`` with its mask.
        """
        if isinstance(obj_nums, list):
            obj_nums = obj_nums[0]
        self.obj_nums = obj_nums

        # np.ceil yields a float; it is only compared against len() below.
        required = max(np.ceil(obj_nums / self.max_aot_obj_num), 1)
        while len(self.aot_engines) < required:
            extra = DeAOTEngine(self.AOT,
                                self.gpu_id,
                                self.long_term_mem_gap,
                                self.short_term_mem_skip,
                                max_len_long_term=self.max_len_long_term)
            extra.eval()
            self.aot_engines.append(extra)

        per_engine_masks, per_engine_counts = self.separate_mask(mask, obj_nums)

        # The first engine's image encoding is shared with the later ones.
        # NOTE(review): when that engine takes the short-term-only path, its
        # curr_enc_embs come from whatever it encoded last; this presumably
        # assumes `img` was already propagated before this call — confirm.
        shared_embs = None
        for engine, engine_mask, engine_count in zip(
                self.aot_engines, per_engine_masks, per_engine_counts):
            recorded = engine.obj_nums
            needs_full_reference = recorded is None or recorded[0] < engine_count
            if needs_full_reference:
                engine.add_reference_frame(img,
                                           engine_mask,
                                           obj_nums=[engine_count],
                                           frame_step=frame_step,
                                           img_embs=shared_embs)
            else:
                engine.update_short_term_memory(engine_mask)

            if shared_embs is None:
                shared_embs = engine.curr_enc_embs

        self.update_size()
|
|