import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np

from utils.math import generate_permute_matrix
from utils.image import one_hot_mask

from networks.layers.basic import seq_to_2d

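# AOTEngine wraps a single AOT model and keeps the per-sequence state needed
# for propagation: encoder features, identity embeddings, and the long-term /
# short-term LSTT memories. It is used for training (forward) and as the
# per-model worker inside AOTInferEngine.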
class AOTEngine(nn.Module):
    def __init__(self,
                 aot_model,
                 gpu_id=0,
                 long_term_mem_gap=9999,
                 short_term_mem_skip=1,
                 max_len_long_term=9999):
        super().__init__()

        self.cfg = aot_model.cfg
        self.align_corners = aot_model.cfg.MODEL_ALIGN_CORNERS
        self.AOT = aot_model

        self.max_obj_num = aot_model.max_obj_num
        self.gpu_id = gpu_id
        self.long_term_mem_gap = long_term_mem_gap
        self.short_term_mem_skip = short_term_mem_skip
        self.max_len_long_term = max_len_long_term
        self.losses = None

        self.restart_engine()

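    # Training forward pass over an offline-encoded clip: frame 0 is the
    # reference frame, the following frames are predicted one by one while the
    # short-term memory is updated from either the ground truth or the
    # previous prediction (use_prev_pred / use_prev_prob).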
    def forward(self,
                all_frames,
                all_masks,
                batch_size,
                obj_nums,
                step=0,
                tf_board=False,
                use_prev_pred=False,
                enable_prev_frame=False,
                use_prev_prob=False):
        if self.losses is None:
            self._init_losses()

        self.freeze_id = True if use_prev_pred else False
        aux_weight = self.aux_weight * max(self.aux_step - step,
                                           0.) / self.aux_step

        self.offline_encoder(all_frames, all_masks)

        self.add_reference_frame(frame_step=0, obj_nums=obj_nums)

        grad_state = torch.no_grad if aux_weight == 0 else torch.enable_grad
        with grad_state():
            ref_aux_loss, ref_aux_mask = self.generate_loss_mask(
                self.offline_masks[self.frame_step], step)

        aux_losses = [ref_aux_loss]
        aux_masks = [ref_aux_mask]

        curr_losses, curr_masks = [], []
        if enable_prev_frame:
            self.set_prev_frame(frame_step=1)
            with grad_state():
                prev_aux_loss, prev_aux_mask = self.generate_loss_mask(
                    self.offline_masks[self.frame_step], step)
            aux_losses.append(prev_aux_loss)
            aux_masks.append(prev_aux_mask)
        else:
            self.match_propogate_one_frame()
            curr_loss, curr_mask, curr_prob = self.generate_loss_mask(
                self.offline_masks[self.frame_step], step, return_prob=True)
            self.update_short_term_memory(
                curr_mask if not use_prev_prob else curr_prob,
                None if use_prev_pred else self.assign_identity(
                    self.offline_one_hot_masks[self.frame_step]))
            curr_losses.append(curr_loss)
            curr_masks.append(curr_mask)

        self.match_propogate_one_frame()
        curr_loss, curr_mask, curr_prob = self.generate_loss_mask(
            self.offline_masks[self.frame_step], step, return_prob=True)
        curr_losses.append(curr_loss)
        curr_masks.append(curr_mask)
        for _ in range(self.total_offline_frame_num - 3):
            self.update_short_term_memory(
                curr_mask if not use_prev_prob else curr_prob,
                None if use_prev_pred else self.assign_identity(
                    self.offline_one_hot_masks[self.frame_step]))
            self.match_propogate_one_frame()
            curr_loss, curr_mask, curr_prob = self.generate_loss_mask(
                self.offline_masks[self.frame_step], step, return_prob=True)
            curr_losses.append(curr_loss)
            curr_masks.append(curr_mask)

        aux_loss = torch.cat(aux_losses, dim=0).mean(dim=0)
        pred_loss = torch.cat(curr_losses, dim=0).mean(dim=0)

        loss = aux_weight * aux_loss + pred_loss

        all_pred_mask = aux_masks + curr_masks

        all_frame_loss = aux_losses + curr_losses

        boards = {'image': {}, 'scalar': {}}

        return loss, all_pred_mask, all_frame_loss, boards

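    # The training loss is an equally weighted sum of a top-k (bootstrapped)
    # cross-entropy loss and a soft Jaccard loss; the auxiliary loss weight
    # used in forward() decays linearly to zero over the first
    # TRAIN_AUX_LOSS_RATIO fraction of the total training steps.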
    def _init_losses(self):
        cfg = self.cfg

        from networks.layers.loss import CrossEntropyLoss, SoftJaccordLoss
        bce_loss = CrossEntropyLoss(
            cfg.TRAIN_TOP_K_PERCENT_PIXELS,
            cfg.TRAIN_HARD_MINING_RATIO * cfg.TRAIN_TOTAL_STEPS)
        iou_loss = SoftJaccordLoss()

        losses = [bce_loss, iou_loss]
        loss_weights = [0.5, 0.5]

        self.losses = nn.ModuleList(losses)
        self.loss_weights = loss_weights
        self.aux_weight = cfg.TRAIN_AUX_LOSS_WEIGHT
        self.aux_step = cfg.TRAIN_TOTAL_STEPS * cfg.TRAIN_AUX_LOSS_RATIO + 1e-5

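    # Returns the encoder embeddings and one-hot mask for a single frame,
    # reusing cached offline features when offline encoding is enabled.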
    def encode_one_img_mask(self, img=None, mask=None, frame_step=-1):
        if frame_step == -1:
            frame_step = self.frame_step

        if self.enable_offline_enc:
            curr_enc_embs = self.offline_enc_embs[frame_step]
        elif img is None:
            curr_enc_embs = None
        else:
            curr_enc_embs = self.AOT.encode_image(img)

        if mask is not None:
            curr_one_hot_mask = one_hot_mask(mask, self.max_obj_num)
        elif self.enable_offline_enc:
            curr_one_hot_mask = self.offline_one_hot_masks[frame_step]
        else:
            curr_one_hot_mask = None

        return curr_enc_embs, curr_one_hot_mask

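    # Encodes a whole training clip in one batch and splits the results into
    # per-frame chunks of size batch_size, so later steps can simply index by
    # frame_step instead of re-running the encoder.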
    def offline_encoder(self, all_frames, all_masks=None):
        self.enable_offline_enc = True
        self.offline_frames = all_frames.size(0) // self.batch_size

        self.offline_enc_embs = self.split_frames(
            self.AOT.encode_image(all_frames), self.batch_size)
        self.total_offline_frame_num = len(self.offline_enc_embs)

        if all_masks is not None:
            offline_one_hot_masks = one_hot_mask(all_masks, self.max_obj_num)
            self.offline_masks = list(
                torch.split(all_masks, self.batch_size, dim=0))
            self.offline_one_hot_masks = list(
                torch.split(offline_one_hot_masks, self.batch_size, dim=0))

        if self.input_size_2d is None:
            self.update_size(all_frames.size()[2:],
                             self.offline_enc_embs[0][-1].size()[2:])

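    # Turns a one-hot mask into flattened identity embeddings of shape
    # (HW, B, C), optionally permuting identity channels when
    # enable_id_shuffle is set, and detaching the embedding when freeze_id is
    # active during training.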
    def assign_identity(self, one_hot_mask):
        if self.enable_id_shuffle:
            one_hot_mask = torch.einsum('bohw,bot->bthw', one_hot_mask,
                                        self.id_shuffle_matrix)

        id_emb = self.AOT.get_id_emb(one_hot_mask).view(
            self.batch_size, -1, self.enc_hw).permute(2, 0, 1)

        if self.training and self.freeze_id:
            id_emb = id_emb.detach()

        return id_emb

    def split_frames(self, xs, chunk_size):
        new_xs = []
        for x in xs:
            all_x = list(torch.split(x, chunk_size, dim=0))
            new_xs.append(all_x)
        return list(zip(*new_xs))

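    # Registers a reference (annotated) frame: encodes image and mask, builds
    # the positional and identity embeddings, and runs LSTT to produce the
    # initial long-term and short-term memories.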
    def add_reference_frame(self,
                            img=None,
                            mask=None,
                            frame_step=-1,
                            obj_nums=None,
                            img_embs=None):
        if self.obj_nums is None and obj_nums is None:
            print('No objects for reference frame!')
            exit()
        elif obj_nums is not None:
            self.obj_nums = obj_nums

        if frame_step == -1:
            frame_step = self.frame_step

        if img_embs is None:
            curr_enc_embs, curr_one_hot_mask = self.encode_one_img_mask(
                img, mask, frame_step)
        else:
            _, curr_one_hot_mask = self.encode_one_img_mask(
                None, mask, frame_step)
            curr_enc_embs = img_embs

        if curr_enc_embs is None:
            print('No image for reference frame!')
            exit()

        if curr_one_hot_mask is None:
            print('No mask for reference frame!')
            exit()

        if self.input_size_2d is None:
            self.update_size(img.size()[2:], curr_enc_embs[-1].size()[2:])

        self.curr_enc_embs = curr_enc_embs
        self.curr_one_hot_mask = curr_one_hot_mask

        if self.pos_emb is None:
            self.pos_emb = self.AOT.get_pos_emb(curr_enc_embs[-1]).expand(
                self.batch_size, -1, -1,
                -1).view(self.batch_size, -1, self.enc_hw).permute(2, 0, 1)

        curr_id_emb = self.assign_identity(curr_one_hot_mask)
        self.curr_id_embs = curr_id_emb

        self.curr_lstt_output = self.AOT.LSTT_forward(curr_enc_embs,
                                                      None,
                                                      None,
                                                      curr_id_emb,
                                                      pos_emb=self.pos_emb,
                                                      size_2d=self.enc_size_2d)

        lstt_embs, lstt_curr_memories, lstt_long_memories, lstt_short_memories = self.curr_lstt_output

        if self.long_term_memories is None:
            self.long_term_memories = lstt_long_memories
        else:
            self.update_long_term_memory(lstt_long_memories)

        self.last_mem_step = self.frame_step

        self.short_term_memories_list = [lstt_short_memories]
        self.short_term_memories = lstt_short_memories

    def set_prev_frame(self, img=None, mask=None, frame_step=1):
        self.frame_step = frame_step
        curr_enc_embs, curr_one_hot_mask = self.encode_one_img_mask(
            img, mask, frame_step)

        if curr_enc_embs is None:
            print('No image for previous frame!')
            exit()

        if curr_one_hot_mask is None:
            print('No mask for previous frame!')
            exit()

        self.curr_enc_embs = curr_enc_embs
        self.curr_one_hot_mask = curr_one_hot_mask

        curr_id_emb = self.assign_identity(curr_one_hot_mask)
        self.curr_id_embs = curr_id_emb

        self.curr_lstt_output = self.AOT.LSTT_forward(curr_enc_embs,
                                                      None,
                                                      None,
                                                      curr_id_emb,
                                                      pos_emb=self.pos_emb,
                                                      size_2d=self.enc_size_2d)

        lstt_embs, lstt_curr_memories, lstt_long_memories, lstt_short_memories = self.curr_lstt_output

        if self.long_term_memories is None:
            self.long_term_memories = lstt_long_memories
        else:
            self.update_long_term_memory(lstt_long_memories)
        self.last_mem_step = frame_step

        self.short_term_memories_list = [lstt_short_memories]
        self.short_term_memories = lstt_short_memories

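    # Prepends the newest memory tokens to the per-layer long-term memory;
    # when the stored memory reaches max_len_long_term frames, the oldest
    # tokens at the tail are dropped to bound memory growth.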
    def update_long_term_memory(self, new_long_term_memories):
        TOKEN_NUM = new_long_term_memories[0][0].shape[0]
        if self.long_term_memories is None:
            self.long_term_memories = new_long_term_memories
        updated_long_term_memories = []
        for new_long_term_memory, last_long_term_memory in zip(
                new_long_term_memories, self.long_term_memories):
            updated_e = []
            for new_e, last_e in zip(new_long_term_memory,
                                     last_long_term_memory):
                if new_e is None or last_e is None:
                    updated_e.append(None)
                else:
                    if last_e.shape[0] >= self.max_len_long_term * TOKEN_NUM:
                        last_e = last_e[:(self.max_len_long_term - 1) *
                                        TOKEN_NUM]
                    updated_e.append(torch.cat([new_e, last_e], dim=0))
            updated_long_term_memories.append(updated_e)
        self.long_term_memories = updated_long_term_memories

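    # Fuses the identity embedding of the current mask into the current LSTT
    # keys/values, keeps only the most recent short_term_mem_skip short-term
    # memory entries, and refreshes the long-term memory every
    # long_term_mem_gap frames.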
    def update_short_term_memory(self,
                                 curr_mask,
                                 curr_id_emb=None,
                                 skip_long_term_update=False):
        if curr_id_emb is None:
            if len(curr_mask.size()) == 3 or curr_mask.size()[0] == 1:
                curr_one_hot_mask = one_hot_mask(curr_mask, self.max_obj_num)
            else:
                curr_one_hot_mask = curr_mask
            curr_id_emb = self.assign_identity(curr_one_hot_mask)

        lstt_curr_memories = self.curr_lstt_output[1]
        lstt_curr_memories_2d = []
        for layer_idx in range(len(lstt_curr_memories)):
            curr_k, curr_v = lstt_curr_memories[layer_idx][
                0], lstt_curr_memories[layer_idx][1]
            curr_k, curr_v = self.AOT.LSTT.layers[layer_idx].fuse_key_value_id(
                curr_k, curr_v, curr_id_emb)
            lstt_curr_memories[layer_idx][0], lstt_curr_memories[layer_idx][
                1] = curr_k, curr_v
            lstt_curr_memories_2d.append([
                seq_to_2d(lstt_curr_memories[layer_idx][0], self.enc_size_2d),
                seq_to_2d(lstt_curr_memories[layer_idx][1], self.enc_size_2d)
            ])

        self.short_term_memories_list.append(lstt_curr_memories_2d)
        self.short_term_memories_list = self.short_term_memories_list[
            -self.short_term_mem_skip:]
        self.short_term_memories = self.short_term_memories_list[0]

        if self.frame_step - self.last_mem_step >= self.long_term_mem_gap:
            if not skip_long_term_update:
                self.update_long_term_memory(lstt_curr_memories)
            self.last_mem_step = self.frame_step

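    # Advances frame_step by one and runs LSTT attention of the new frame
    # against the stored long-term and short-term memories (no mask input).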
    def match_propogate_one_frame(self, img=None, img_embs=None):
        self.frame_step += 1
        if img_embs is None:
            curr_enc_embs, _ = self.encode_one_img_mask(
                img, None, self.frame_step)
        else:
            curr_enc_embs = img_embs
        self.curr_enc_embs = curr_enc_embs

        self.curr_lstt_output = self.AOT.LSTT_forward(curr_enc_embs,
                                                      self.long_term_memories,
                                                      self.short_term_memories,
                                                      None,
                                                      pos_emb=self.pos_emb,
                                                      size_2d=self.enc_size_2d)

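    # Decodes identity logits from the current LSTT output, un-shuffles the
    # identity channels if shuffling is enabled, and suppresses logits of
    # identity slots beyond each sample's actual object count.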
    def decode_current_logits(self, output_size=None):
        curr_enc_embs = self.curr_enc_embs
        curr_lstt_embs = self.curr_lstt_output[0]

        pred_id_logits = self.AOT.decode_id_logits(curr_lstt_embs,
                                                   curr_enc_embs)

        if self.enable_id_shuffle:
            pred_id_logits = torch.einsum('bohw,bto->bthw', pred_id_logits,
                                          self.id_shuffle_matrix)

        for batch_idx, obj_num in enumerate(self.obj_nums):
            pred_id_logits[batch_idx, (obj_num + 1):] = \
                -1e+10 if pred_id_logits.dtype == torch.float32 else -1e+4

        self.pred_id_logits = pred_id_logits

        if output_size is not None:
            pred_id_logits = F.interpolate(pred_id_logits,
                                           size=output_size,
                                           mode="bilinear",
                                           align_corners=self.align_corners)

        return pred_id_logits

    def predict_current_mask(self, output_size=None, return_prob=False):
        if output_size is None:
            output_size = self.input_size_2d

        pred_id_logits = F.interpolate(self.pred_id_logits,
                                       size=output_size,
                                       mode="bilinear",
                                       align_corners=self.align_corners)
        pred_mask = torch.argmax(pred_id_logits, dim=1)

        if not return_prob:
            return pred_mask
        else:
            pred_prob = torch.softmax(pred_id_logits, dim=1)
            return pred_mask, pred_prob

    def calculate_current_loss(self, gt_mask, step):
        pred_id_logits = self.pred_id_logits

        pred_id_logits = F.interpolate(pred_id_logits,
                                       size=gt_mask.size()[-2:],
                                       mode="bilinear",
                                       align_corners=self.align_corners)

        label_list = []
        logit_list = []
        for batch_idx, obj_num in enumerate(self.obj_nums):
            now_label = gt_mask[batch_idx].long()
            now_logit = pred_id_logits[batch_idx, :(obj_num + 1)].unsqueeze(0)
            label_list.append(now_label.long())
            logit_list.append(now_logit)

        total_loss = 0
        for loss, loss_weight in zip(self.losses, self.loss_weights):
            total_loss = total_loss + loss_weight * \
                loss(logit_list, label_list, step)

        return total_loss

    def generate_loss_mask(self, gt_mask, step, return_prob=False):
        self.decode_current_logits()
        loss = self.calculate_current_loss(gt_mask, step)
        if return_prob:
            mask, prob = self.predict_current_mask(return_prob=True)
            return loss, mask, prob
        else:
            mask = self.predict_current_mask()
            return loss, mask

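    # Randomly replaces the predicted mask with the ground-truth mask per
    # sample with probability keep_prob (a scheduled-sampling style trick).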
    def keep_gt_mask(self, pred_mask, keep_prob=0.2):
        pred_mask = pred_mask.float()
        gt_mask = self.offline_masks[self.frame_step].float().squeeze(1)

        shape = [1 for _ in range(pred_mask.ndim)]
        shape[0] = self.batch_size
        random_tensor = keep_prob + torch.rand(
            shape, dtype=pred_mask.dtype, device=pred_mask.device)
        random_tensor.floor_()

        pred_mask = pred_mask * (1 - random_tensor) + gt_mask * random_tensor

        return pred_mask

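    # Clears all per-sequence state (memories, cached embeddings, sizes) so
    # the same engine instance can be reused for a new video or batch.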
    def restart_engine(self, batch_size=1, enable_id_shuffle=False):
        self.batch_size = batch_size
        self.frame_step = 0
        self.last_mem_step = -1
        self.enable_id_shuffle = enable_id_shuffle
        self.freeze_id = False

        self.obj_nums = None
        self.pos_emb = None
        self.enc_size_2d = None
        self.enc_hw = None
        self.input_size_2d = None

        self.long_term_memories = None
        self.short_term_memories_list = []
        self.short_term_memories = None

        self.enable_offline_enc = False
        self.offline_enc_embs = None
        self.offline_one_hot_masks = None
        self.offline_frames = -1
        self.total_offline_frame_num = 0

        self.curr_enc_embs = None
        self.curr_memories = None
        self.curr_id_embs = None

        if enable_id_shuffle:
            self.id_shuffle_matrix = generate_permute_matrix(
                self.max_obj_num + 1, batch_size, gpu_id=self.gpu_id)
        else:
            self.id_shuffle_matrix = None

    def update_size(self, input_size, enc_size):
        self.input_size_2d = input_size
        self.enc_size_2d = enc_size
        self.enc_hw = self.enc_size_2d[0] * self.enc_size_2d[1]


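# AOTInferEngine handles inference when the number of objects can exceed the
# identity capacity of a single AOT model: it splits the objects across
# multiple AOTEngine instances and aggregates their per-object predictions.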
class AOTInferEngine(nn.Module):
    def __init__(self,
                 aot_model,
                 gpu_id=0,
                 long_term_mem_gap=9999,
                 short_term_mem_skip=1,
                 max_aot_obj_num=None,
                 max_len_long_term=9999):
        super().__init__()

        self.cfg = aot_model.cfg
        self.AOT = aot_model

        if max_aot_obj_num is None or max_aot_obj_num > aot_model.max_obj_num:
            self.max_aot_obj_num = aot_model.max_obj_num
        else:
            self.max_aot_obj_num = max_aot_obj_num

        self.gpu_id = gpu_id
        self.long_term_mem_gap = long_term_mem_gap
        self.short_term_mem_skip = short_term_mem_skip
        self.max_len_long_term = max_len_long_term
        self.aot_engines = []

        self.restart_engine()

    def restart_engine(self):
        del self.aot_engines
        self.aot_engines = []
        self.obj_nums = None

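    # Splits a full mask (or probability map) into per-engine chunks,
    # remapping object IDs so each AOTEngine sees labels in
    # [1, max_aot_obj_num].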
    def separate_mask(self, mask, obj_nums):
        if mask is None:
            return [None] * len(self.aot_engines)
        if len(self.aot_engines) == 1:
            return [mask], [obj_nums]

        separated_obj_nums = [
            self.max_aot_obj_num for _ in range(len(self.aot_engines))
        ]
        if obj_nums % self.max_aot_obj_num > 0:
            separated_obj_nums[-1] = obj_nums % self.max_aot_obj_num

        if len(mask.size()) == 3 or mask.size()[0] == 1:
            separated_masks = []
            for idx in range(len(self.aot_engines)):
                start_id = idx * self.max_aot_obj_num + 1
                end_id = (idx + 1) * self.max_aot_obj_num
                fg_mask = ((mask >= start_id) & (mask <= end_id)).float()
                separated_mask = (fg_mask * mask - start_id + 1) * fg_mask
                separated_masks.append(separated_mask)
            return separated_masks, separated_obj_nums
        else:
            prob = mask
            separated_probs = []
            for idx in range(len(self.aot_engines)):
                start_id = idx * self.max_aot_obj_num + 1
                end_id = (idx + 1) * self.max_aot_obj_num
                fg_prob = prob[start_id:(end_id + 1)]
                bg_prob = 1. - torch.sum(fg_prob, dim=1, keepdim=True)
                separated_probs.append(torch.cat([bg_prob, fg_prob], dim=1))
            return separated_probs, separated_obj_nums

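    # Merges per-engine logits by taking the minimum background logit across
    # engines and concatenating all foreground logits.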
    def min_logit_aggregation(self, all_logits):
        if len(all_logits) == 1:
            return all_logits[0]

        fg_logits = []
        bg_logits = []

        for logit in all_logits:
            bg_logits.append(logit[:, 0:1])
            fg_logits.append(logit[:, 1:1 + self.max_aot_obj_num])

        bg_logit, _ = torch.min(torch.cat(bg_logits, dim=1),
                                dim=1,
                                keepdim=True)
        merged_logit = torch.cat([bg_logit] + fg_logits, dim=1)

        return merged_logit

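    # Merges per-engine predictions in probability space: the background
    # probability is the product of per-engine background probabilities, and
    # the clamped result is converted back to logits.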
    def soft_logit_aggregation(self, all_logits):
        if len(all_logits) == 1:
            return all_logits[0]

        fg_probs = []
        bg_probs = []

        for logit in all_logits:
            prob = torch.softmax(logit, dim=1)
            bg_probs.append(prob[:, 0:1])
            fg_probs.append(prob[:, 1:1 + self.max_aot_obj_num])

        bg_prob = torch.prod(torch.cat(bg_probs, dim=1), dim=1, keepdim=True)
        merged_prob = torch.cat([bg_prob] + fg_probs,
                                dim=1).clamp(1e-5, 1 - 1e-5)
        merged_logit = torch.logit(merged_prob)

        return merged_logit

    def add_reference_frame(self, img, mask, obj_nums, frame_step=-1):
        if isinstance(obj_nums, list):
            obj_nums = obj_nums[0]
        self.obj_nums = obj_nums
        aot_num = max(np.ceil(obj_nums / self.max_aot_obj_num), 1)
        while aot_num > len(self.aot_engines):
            new_engine = AOTEngine(self.AOT, self.gpu_id,
                                   self.long_term_mem_gap,
                                   self.short_term_mem_skip,
                                   self.max_len_long_term)
            new_engine.eval()
            self.aot_engines.append(new_engine)

        separated_masks, separated_obj_nums = self.separate_mask(
            mask, obj_nums)
        img_embs = None
        for aot_engine, separated_mask, separated_obj_num in zip(
                self.aot_engines, separated_masks, separated_obj_nums):
            aot_engine.add_reference_frame(img,
                                           separated_mask,
                                           obj_nums=[separated_obj_num],
                                           frame_step=frame_step,
                                           img_embs=img_embs)
            if img_embs is None:
                img_embs = aot_engine.curr_enc_embs

        self.update_size()

    def match_propogate_one_frame(self, img=None):
        img_embs = None
        for aot_engine in self.aot_engines:
            aot_engine.match_propogate_one_frame(img, img_embs=img_embs)
            if img_embs is None:
                img_embs = aot_engine.curr_enc_embs

    def decode_current_logits(self, output_size=None):
        all_logits = []
        for aot_engine in self.aot_engines:
            all_logits.append(aot_engine.decode_current_logits(output_size))
        pred_id_logits = self.soft_logit_aggregation(all_logits)
        return pred_id_logits

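    # Resizes the aggregated mask back to the input resolution, re-splits it
    # per engine, and lets each engine update its short-term (and optionally
    # long-term) memory.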
    def update_memory(self, curr_mask, skip_long_term_update=False):
        _curr_mask = F.interpolate(curr_mask, self.input_size_2d)
        separated_masks, _ = self.separate_mask(_curr_mask, self.obj_nums)
        for aot_engine, separated_mask in zip(self.aot_engines,
                                              separated_masks):
            aot_engine.update_short_term_memory(
                separated_mask, skip_long_term_update=skip_long_term_update)

    def update_size(self):
        self.input_size_2d = self.aot_engines[0].input_size_2d
        self.enc_size_2d = self.aot_engines[0].enc_size_2d
        self.enc_hw = self.aot_engines[0].enc_hw
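

# A minimal inference sketch (assumptions: `aot_model` is an already-built and
# loaded AOT network, `first_frame`/`first_mask` are the reference image
# tensor and its integer label map, `obj_nums` is the object count, and
# `frames` iterates over the remaining image tensors; none of these names are
# defined in this file):
#
#     engine = AOTInferEngine(aot_model, gpu_id=0)
#     engine.add_reference_frame(first_frame, first_mask,
#                                obj_nums=obj_nums, frame_step=0)
#     for frame in frames:
#         engine.match_propogate_one_frame(frame)
#         logits = engine.decode_current_logits(
#             output_size=first_mask.shape[-2:])
#         mask = torch.argmax(logits, dim=1, keepdim=True).float()
#         engine.update_memory(mask)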