# samtrack/aot/networks/engines/aot_engine.py
# Uploaded by aikenml using huggingface_hub (revision c985ba4, 25.2 kB).
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from utils.math import generate_permute_matrix
from utils.image import one_hot_mask
from networks.layers.basic import seq_to_2d
class AOTEngine(nn.Module):
    """Stateful propagation engine wrapped around a single AOT model.

    The engine keeps the per-video state (encoder embeddings, identity
    embeddings, long-/short-term LSTT memories and frame counters) needed to
    propagate up to ``aot_model.max_obj_num`` object identities frame by
    frame. ``forward`` is the training entry point; at inference time the
    engine is driven through ``add_reference_frame``,
    ``match_propogate_one_frame``, ``decode_current_logits`` and
    ``update_short_term_memory``.
    """

    def __init__(self,
                 aot_model,
                 gpu_id=0,
                 long_term_mem_gap=9999,
                 short_term_mem_skip=1,
                 max_len_long_term=9999):
        """
        Args:
            aot_model: the AOT network; its ``cfg``, ``max_obj_num`` and
                sub-networks (encoder, LSTT, decoder) are used by the engine.
            gpu_id: device index handed to ``generate_permute_matrix`` when
                identity shuffling is enabled.
            long_term_mem_gap: minimum number of frames between two
                insertions into the long-term memory.
            short_term_mem_skip: number of most recent short-term memories
                kept; the oldest kept entry is the one used for matching.
            max_len_long_term: maximum number of insertions retained in the
                long-term memory; older insertions are dropped first.
        """
        super().__init__()

        self.cfg = aot_model.cfg
        self.align_corners = aot_model.cfg.MODEL_ALIGN_CORNERS
        self.AOT = aot_model

        self.max_obj_num = aot_model.max_obj_num
        self.gpu_id = gpu_id
        self.long_term_mem_gap = long_term_mem_gap
        self.short_term_mem_skip = short_term_mem_skip
        self.max_len_long_term = max_len_long_term
        self.losses = None  # built lazily by _init_losses() on first forward()

        self.restart_engine()

    def forward(self,
                all_frames,
                all_masks,
                batch_size,
                obj_nums,
                step=0,
                tf_board=False,
                use_prev_pred=False,
                enable_prev_frame=False,
                use_prev_prob=False):  # only used for training
        """Run a whole training clip and return its losses and predictions.

        Frame 0 is the reference frame. Frame 1 is either a second reference
        (when ``enable_prev_frame``) or the first propagated frame; every
        later frame is propagated from memory.

        Args:
            all_frames: clip frames stacked along dim 0, ``self.batch_size``
                images per frame (see ``offline_encoder``).
            all_masks: ground-truth label masks stacked like ``all_frames``.
            batch_size: unused; the batch size set by ``restart_engine`` is
                used instead.
            obj_nums: number of objects for each batch element.
            step: current training step; the auxiliary loss weight decays
                with it.
            tf_board: unused.
            use_prev_pred: feed predicted masks (not ground-truth identities)
                back into memory and detach identity embeddings in training.
            enable_prev_frame: use frame 1 as an additional reference frame.
            use_prev_prob: feed back probabilities instead of hard masks.

        Returns:
            ``(loss, all_pred_mask, all_frame_loss, boards)``.
        """
        if self.losses is None:
            self._init_losses()

        self.freeze_id = bool(use_prev_pred)
        # Linearly decay the auxiliary (reference-frame) loss weight to 0
        # over the first ``aux_step`` steps.
        aux_weight = self.aux_weight * max(self.aux_step - step,
                                           0.) / self.aux_step

        self.offline_encoder(all_frames, all_masks)

        self.add_reference_frame(frame_step=0, obj_nums=obj_nums)

        # Auxiliary losses need no graph once their weight has decayed to 0.
        grad_state = torch.no_grad if aux_weight == 0 else torch.enable_grad
        with grad_state():
            ref_aux_loss, ref_aux_mask = self.generate_loss_mask(
                self.offline_masks[self.frame_step], step)

        aux_losses = [ref_aux_loss]
        aux_masks = [ref_aux_mask]

        curr_losses, curr_masks = [], []
        if enable_prev_frame:
            self.set_prev_frame(frame_step=1)
            with grad_state():
                prev_aux_loss, prev_aux_mask = self.generate_loss_mask(
                    self.offline_masks[self.frame_step], step)
            aux_losses.append(prev_aux_loss)
            aux_masks.append(prev_aux_mask)
        else:
            self.match_propogate_one_frame()
            curr_loss, curr_mask, curr_prob = self.generate_loss_mask(
                self.offline_masks[self.frame_step], step, return_prob=True)
            self.update_short_term_memory(
                curr_mask if not use_prev_prob else curr_prob,
                None if use_prev_pred else self.assign_identity(
                    self.offline_one_hot_masks[self.frame_step]))
            curr_losses.append(curr_loss)
            curr_masks.append(curr_mask)

        self.match_propogate_one_frame()
        curr_loss, curr_mask, curr_prob = self.generate_loss_mask(
            self.offline_masks[self.frame_step], step, return_prob=True)
        curr_losses.append(curr_loss)
        curr_masks.append(curr_mask)
        # Frames 0-2 are handled above; propagate through the remaining ones.
        for _ in range(self.total_offline_frame_num - 3):
            self.update_short_term_memory(
                curr_mask if not use_prev_prob else curr_prob,
                None if use_prev_pred else self.assign_identity(
                    self.offline_one_hot_masks[self.frame_step]))
            self.match_propogate_one_frame()
            curr_loss, curr_mask, curr_prob = self.generate_loss_mask(
                self.offline_masks[self.frame_step], step, return_prob=True)
            curr_losses.append(curr_loss)
            curr_masks.append(curr_mask)

        aux_loss = torch.cat(aux_losses, dim=0).mean(dim=0)
        pred_loss = torch.cat(curr_losses, dim=0).mean(dim=0)

        loss = aux_weight * aux_loss + pred_loss

        all_pred_mask = aux_masks + curr_masks
        all_frame_loss = aux_losses + curr_losses

        boards = {'image': {}, 'scalar': {}}

        return loss, all_pred_mask, all_frame_loss, boards

    def _init_losses(self):
        """Build the training losses and auxiliary-loss schedule from cfg."""
        cfg = self.cfg

        from networks.layers.loss import CrossEntropyLoss, SoftJaccordLoss
        bce_loss = CrossEntropyLoss(
            cfg.TRAIN_TOP_K_PERCENT_PIXELS,
            cfg.TRAIN_HARD_MINING_RATIO * cfg.TRAIN_TOTAL_STEPS)
        iou_loss = SoftJaccordLoss()

        losses = [bce_loss, iou_loss]
        loss_weights = [0.5, 0.5]

        self.losses = nn.ModuleList(losses)
        self.loss_weights = loss_weights
        self.aux_weight = cfg.TRAIN_AUX_LOSS_WEIGHT
        # The small epsilon keeps the decay in forward() from dividing by 0.
        self.aux_step = cfg.TRAIN_TOTAL_STEPS * cfg.TRAIN_AUX_LOSS_RATIO + 1e-5

    def encode_one_img_mask(self, img=None, mask=None, frame_step=-1):
        """Return ``(enc_embs, one_hot_mask)`` for one frame.

        Embeddings come from the offline cache when offline encoding is
        enabled, otherwise from ``self.AOT.encode_image(img)`` (``None`` when
        ``img`` is None). The one-hot mask is built from ``mask`` when given,
        else taken from the offline cache, else ``None``. ``frame_step=-1``
        means the current ``self.frame_step``.
        """
        if frame_step == -1:
            frame_step = self.frame_step

        if self.enable_offline_enc:
            curr_enc_embs = self.offline_enc_embs[frame_step]
        elif img is None:
            curr_enc_embs = None
        else:
            curr_enc_embs = self.AOT.encode_image(img)

        if mask is not None:
            curr_one_hot_mask = one_hot_mask(mask, self.max_obj_num)
        elif self.enable_offline_enc:
            curr_one_hot_mask = self.offline_one_hot_masks[frame_step]
        else:
            curr_one_hot_mask = None

        return curr_enc_embs, curr_one_hot_mask

    def offline_encoder(self, all_frames, all_masks=None):
        """Encode all frames of a clip at once and cache per-frame results.

        ``all_frames`` (and ``all_masks`` if given) are split along dim 0 into
        chunks of ``self.batch_size``, one chunk per frame. The input/encoder
        sizes are initialised here if not set yet.
        """
        self.enable_offline_enc = True
        self.offline_frames = all_frames.size(0) // self.batch_size

        # extract backbone features
        self.offline_enc_embs = self.split_frames(
            self.AOT.encode_image(all_frames), self.batch_size)
        self.total_offline_frame_num = len(self.offline_enc_embs)

        if all_masks is not None:
            # extract mask embeddings
            offline_one_hot_masks = one_hot_mask(all_masks, self.max_obj_num)
            self.offline_masks = list(
                torch.split(all_masks, self.batch_size, dim=0))
            self.offline_one_hot_masks = list(
                torch.split(offline_one_hot_masks, self.batch_size, dim=0))

        if self.input_size_2d is None:
            self.update_size(all_frames.size()[2:],
                             self.offline_enc_embs[0][-1].size()[2:])

    def assign_identity(self, one_hot_mask):
        """Map a one-hot mask to identity embeddings of shape (enc_hw, batch, C).

        When identity shuffling is enabled, the one-hot channels are first
        permuted with ``self.id_shuffle_matrix``. In training mode with
        ``freeze_id`` set, the embeddings are detached.
        """
        if self.enable_id_shuffle:
            one_hot_mask = torch.einsum('bohw,bot->bthw', one_hot_mask,
                                        self.id_shuffle_matrix)

        id_emb = self.AOT.get_id_emb(one_hot_mask).view(
            self.batch_size, -1, self.enc_hw).permute(2, 0, 1)

        if self.training and self.freeze_id:
            id_emb = id_emb.detach()

        return id_emb

    def split_frames(self, xs, chunk_size):
        """Split every tensor in ``xs`` along dim 0 and regroup per chunk.

        Returns a list whose i-th item is a tuple holding the i-th chunk of
        each tensor in ``xs``.
        """
        new_xs = []
        for x in xs:
            all_x = list(torch.split(x, chunk_size, dim=0))
            new_xs.append(all_x)
        return list(zip(*new_xs))

    def add_reference_frame(self,
                            img=None,
                            mask=None,
                            frame_step=-1,
                            obj_nums=None,
                            img_embs=None):
        """Register a frame with a known mask as the reference frame.

        Runs LSTT on the frame with its identity embeddings, then seeds the
        long-term memory (or appends to it) and resets the short-term memory.
        ``img_embs`` lets callers reuse encoder embeddings computed elsewhere.
        Exits the process when objects, image or mask are missing.
        """
        if self.obj_nums is None and obj_nums is None:
            print('No objects for reference frame!')
            exit()
        elif obj_nums is not None:
            self.obj_nums = obj_nums

        if frame_step == -1:
            frame_step = self.frame_step

        if img_embs is None:
            curr_enc_embs, curr_one_hot_mask = self.encode_one_img_mask(
                img, mask, frame_step)
        else:
            _, curr_one_hot_mask = self.encode_one_img_mask(
                None, mask, frame_step)
            curr_enc_embs = img_embs

        if curr_enc_embs is None:
            print('No image for reference frame!')
            exit()

        if curr_one_hot_mask is None:
            print('No mask for reference frame!')
            exit()

        if self.input_size_2d is None:
            self.update_size(img.size()[2:], curr_enc_embs[-1].size()[2:])

        self.curr_enc_embs = curr_enc_embs
        self.curr_one_hot_mask = curr_one_hot_mask

        if self.pos_emb is None:
            self.pos_emb = self.AOT.get_pos_emb(curr_enc_embs[-1]).expand(
                self.batch_size, -1, -1,
                -1).view(self.batch_size, -1, self.enc_hw).permute(2, 0, 1)

        curr_id_emb = self.assign_identity(curr_one_hot_mask)
        self.curr_id_embs = curr_id_emb

        # self matching and propagation
        self.curr_lstt_output = self.AOT.LSTT_forward(curr_enc_embs,
                                                      None,
                                                      None,
                                                      curr_id_emb,
                                                      pos_emb=self.pos_emb,
                                                      size_2d=self.enc_size_2d)

        lstt_embs, lstt_curr_memories, lstt_long_memories, lstt_short_memories = self.curr_lstt_output

        if self.long_term_memories is None:
            self.long_term_memories = lstt_long_memories
        else:
            self.update_long_term_memory(lstt_long_memories)
        self.last_mem_step = self.frame_step

        self.short_term_memories_list = [lstt_short_memories]
        self.short_term_memories = lstt_short_memories

    def set_prev_frame(self, img=None, mask=None, frame_step=1):
        """Register frame ``frame_step`` as an extra frame with a known mask.

        Same memory handling as ``add_reference_frame`` but also moves the
        engine's frame counter to ``frame_step``. Exits the process when the
        image or mask is missing.
        """
        self.frame_step = frame_step
        curr_enc_embs, curr_one_hot_mask = self.encode_one_img_mask(
            img, mask, frame_step)

        if curr_enc_embs is None:
            print('No image for previous frame!')
            exit()

        if curr_one_hot_mask is None:
            print('No mask for previous frame!')
            exit()

        self.curr_enc_embs = curr_enc_embs
        self.curr_one_hot_mask = curr_one_hot_mask

        curr_id_emb = self.assign_identity(curr_one_hot_mask)
        self.curr_id_embs = curr_id_emb

        # self matching and propagation
        self.curr_lstt_output = self.AOT.LSTT_forward(curr_enc_embs,
                                                      None,
                                                      None,
                                                      curr_id_emb,
                                                      pos_emb=self.pos_emb,
                                                      size_2d=self.enc_size_2d)

        lstt_embs, lstt_curr_memories, lstt_long_memories, lstt_short_memories = self.curr_lstt_output

        if self.long_term_memories is None:
            self.long_term_memories = lstt_long_memories
        else:
            self.update_long_term_memory(lstt_long_memories)
        self.last_mem_step = frame_step

        self.short_term_memories_list = [lstt_short_memories]
        self.short_term_memories = lstt_short_memories

    def update_long_term_memory(self, new_long_term_memories):
        """Prepend new memories to the long-term memory, capping its length.

        ``new_long_term_memories`` is a per-layer list of per-layer entry
        lists; ``None`` entries stay ``None``. Tensors are concatenated along
        dim 0, newest first. Once an entry already holds
        ``max_len_long_term`` insertions, the oldest insertion is dropped.
        NOTE(review): the per-insertion token count is read from the first
        entry of the first layer and assumed identical for every entry.
        """
        if self.long_term_memories is None:
            # First insertion: store it as-is. Falling through would
            # concatenate the new memories with themselves (duplicates).
            self.long_term_memories = new_long_term_memories
            return

        token_num = new_long_term_memories[0][0].shape[0]
        max_tokens = self.max_len_long_term * token_num

        updated_long_term_memories = []
        for new_long_term_memory, last_long_term_memory in zip(
                new_long_term_memories, self.long_term_memories):
            updated_e = []
            for new_e, last_e in zip(new_long_term_memory,
                                     last_long_term_memory):
                if new_e is None or last_e is None:
                    updated_e.append(None)
                    continue
                if last_e.shape[0] >= max_tokens:
                    # Keep room for max_len_long_term - 1 older insertions.
                    last_e = last_e[:max_tokens - token_num]
                updated_e.append(torch.cat([new_e, last_e], dim=0))
            updated_long_term_memories.append(updated_e)
        self.long_term_memories = updated_long_term_memories

    def update_short_term_memory(self, curr_mask, curr_id_emb=None, skip_long_term_update=False):
        """Store the current frame's prediction in memory.

        The current LSTT key/value memories are fused with identity
        embeddings built from ``curr_mask``, or with ``curr_id_emb`` when it is
        given. The fused memories are pushed into the short-term memory
        window. Every ``long_term_mem_gap`` frames they are also added to the
        long-term memory, unless ``skip_long_term_update`` is set.

        ``curr_mask`` is either a label map ((B, H, W) or (B, 1, H, W)),
        which is one-hot encoded, or a per-identity probability/one-hot map
        (B, C, H, W) with C > 1, which is used directly.
        """
        if curr_id_emb is None:
            # Distinguish label maps from probability maps by the channel
            # dimension (the batch dimension says nothing about it).
            if len(curr_mask.size()) == 3 or curr_mask.size()[1] == 1:
                curr_one_hot_mask = one_hot_mask(curr_mask, self.max_obj_num)
            else:
                curr_one_hot_mask = curr_mask
            curr_id_emb = self.assign_identity(curr_one_hot_mask)

        lstt_curr_memories = self.curr_lstt_output[1]
        lstt_curr_memories_2d = []
        for layer_idx, layer_memory in enumerate(lstt_curr_memories):
            curr_k, curr_v = self.AOT.LSTT.layers[layer_idx].fuse_key_value_id(
                layer_memory[0], layer_memory[1], curr_id_emb)
            # Write the fused key/value back in place; the same lists are
            # what may be pushed into the long-term memory below.
            layer_memory[0], layer_memory[1] = curr_k, curr_v
            lstt_curr_memories_2d.append([
                seq_to_2d(curr_k, self.enc_size_2d),
                seq_to_2d(curr_v, self.enc_size_2d)
            ])

        self.short_term_memories_list.append(lstt_curr_memories_2d)
        self.short_term_memories_list = self.short_term_memories_list[
            -self.short_term_mem_skip:]
        self.short_term_memories = self.short_term_memories_list[0]

        if self.frame_step - self.last_mem_step >= self.long_term_mem_gap:
            # skip the update of long-term memory or not
            if not skip_long_term_update:
                self.update_long_term_memory(lstt_curr_memories)
                self.last_mem_step = self.frame_step

    def match_propogate_one_frame(self, img=None, img_embs=None):
        """Advance one frame and match it against the stored memories.

        ``img_embs`` lets callers reuse encoder embeddings computed elsewhere;
        otherwise the frame is encoded (or read from the offline cache).
        """
        self.frame_step += 1
        if img_embs is None:
            curr_enc_embs, _ = self.encode_one_img_mask(
                img, None, self.frame_step)
        else:
            curr_enc_embs = img_embs
        self.curr_enc_embs = curr_enc_embs

        self.curr_lstt_output = self.AOT.LSTT_forward(curr_enc_embs,
                                                      self.long_term_memories,
                                                      self.short_term_memories,
                                                      None,
                                                      pos_emb=self.pos_emb,
                                                      size_2d=self.enc_size_2d)

    def decode_current_logits(self, output_size=None):
        """Decode identity logits for the current frame.

        Undoes identity shuffling and masks out identity channels beyond each
        sample's object count. Caches the logits (at decoder resolution) in
        ``self.pred_id_logits`` and returns them, resized bilinearly to
        ``output_size`` when given.
        """
        curr_enc_embs = self.curr_enc_embs
        curr_lstt_embs = self.curr_lstt_output[0]

        pred_id_logits = self.AOT.decode_id_logits(curr_lstt_embs,
                                                   curr_enc_embs)

        if self.enable_id_shuffle:  # reverse shuffle
            pred_id_logits = torch.einsum('bohw,bto->bthw', pred_id_logits,
                                          self.id_shuffle_matrix)

        # remove unused identities; -1e10 is not representable in float16,
        # so lower-precision logits get -1e4 instead
        fill_value = -1e+10 if pred_id_logits.dtype == torch.float32 else -1e+4
        for batch_idx, obj_num in enumerate(self.obj_nums):
            pred_id_logits[batch_idx, (obj_num + 1):] = fill_value

        self.pred_id_logits = pred_id_logits

        if output_size is not None:
            pred_id_logits = F.interpolate(pred_id_logits,
                                           size=output_size,
                                           mode="bilinear",
                                           align_corners=self.align_corners)

        return pred_id_logits

    def predict_current_mask(self, output_size=None, return_prob=False):
        """Return the argmax label mask of the cached logits.

        Logits are resized to ``output_size`` (default: the input size).
        When ``return_prob`` is set, the softmax over identities is returned
        too.
        """
        if output_size is None:
            output_size = self.input_size_2d

        pred_id_logits = F.interpolate(self.pred_id_logits,
                                       size=output_size,
                                       mode="bilinear",
                                       align_corners=self.align_corners)
        pred_mask = torch.argmax(pred_id_logits, dim=1)

        if not return_prob:
            return pred_mask
        pred_prob = torch.softmax(pred_id_logits, dim=1)
        return pred_mask, pred_prob

    def calculate_current_loss(self, gt_mask, step):
        """Weighted sum of the training losses for the cached logits.

        Each sample keeps only its background plus ``obj_num`` identity
        channels before being compared with ``gt_mask``.
        """
        pred_id_logits = F.interpolate(self.pred_id_logits,
                                       size=gt_mask.size()[-2:],
                                       mode="bilinear",
                                       align_corners=self.align_corners)

        label_list = []
        logit_list = []
        for batch_idx, obj_num in enumerate(self.obj_nums):
            label_list.append(gt_mask[batch_idx].long())
            logit_list.append(
                pred_id_logits[batch_idx, :(obj_num + 1)].unsqueeze(0))

        total_loss = 0
        for loss, loss_weight in zip(self.losses, self.loss_weights):
            total_loss = total_loss + loss_weight * \
                loss(logit_list, label_list, step)

        return total_loss

    def generate_loss_mask(self, gt_mask, step, return_prob=False):
        """Decode the current frame, then return ``(loss, mask[, prob])``."""
        self.decode_current_logits()
        loss = self.calculate_current_loss(gt_mask, step)
        if return_prob:
            mask, prob = self.predict_current_mask(return_prob=True)
            return loss, mask, prob
        mask = self.predict_current_mask()
        return loss, mask

    def keep_gt_mask(self, pred_mask, keep_prob=0.2):
        """Per batch element, replace ``pred_mask`` by the ground truth with
        probability ``keep_prob`` (for ``keep_prob`` in [0, 1])."""
        pred_mask = pred_mask.float()
        gt_mask = self.offline_masks[self.frame_step].float().squeeze(1)

        shape = [1 for _ in range(pred_mask.ndim)]
        shape[0] = self.batch_size
        random_tensor = keep_prob + torch.rand(
            shape, dtype=pred_mask.dtype, device=pred_mask.device)
        random_tensor.floor_()  # binarize: 1 with probability keep_prob

        pred_mask = pred_mask * (1 - random_tensor) + gt_mask * random_tensor

        return pred_mask

    def restart_engine(self, batch_size=1, enable_id_shuffle=False):
        """Reset all per-video state, optionally drawing a new identity
        permutation for ``batch_size`` samples."""
        self.batch_size = batch_size
        self.frame_step = 0
        self.last_mem_step = -1
        self.enable_id_shuffle = enable_id_shuffle
        self.freeze_id = False

        self.obj_nums = None
        self.pos_emb = None
        self.enc_size_2d = None
        self.enc_hw = None
        self.input_size_2d = None

        self.long_term_memories = None
        self.short_term_memories_list = []
        self.short_term_memories = None

        self.enable_offline_enc = False
        self.offline_enc_embs = None
        self.offline_one_hot_masks = None
        self.offline_frames = -1
        self.total_offline_frame_num = 0

        self.curr_enc_embs = None
        self.curr_memories = None
        self.curr_id_embs = None

        if enable_id_shuffle:
            self.id_shuffle_matrix = generate_permute_matrix(
                self.max_obj_num + 1, batch_size, gpu_id=self.gpu_id)
        else:
            self.id_shuffle_matrix = None

    def update_size(self, input_size, enc_size):
        """Record the input and encoder spatial sizes (and enc_hw = H * W)."""
        self.input_size_2d = input_size
        self.enc_size_2d = enc_size
        self.enc_hw = self.enc_size_2d[0] * self.enc_size_2d[1]
class AOTInferEngine(nn.Module):
    """Inference wrapper that spreads many objects over several AOTEngines.

    Each AOTEngine handles at most ``max_aot_obj_num`` identities. When a
    video has more objects, extra engines are created; they share the image
    embeddings computed by the first engine, and their logits are merged
    back into a single prediction.
    """

    def __init__(self,
                 aot_model,
                 gpu_id=0,
                 long_term_mem_gap=9999,
                 short_term_mem_skip=1,
                 max_aot_obj_num=None,
                 max_len_long_term=9999,):
        """
        Args:
            aot_model: the AOT network shared by all internal engines.
            gpu_id: device index forwarded to each AOTEngine.
            long_term_mem_gap: forwarded to each AOTEngine.
            short_term_mem_skip: forwarded to each AOTEngine.
            max_aot_obj_num: objects per engine; ``None`` or a value above
                ``aot_model.max_obj_num`` falls back to ``max_obj_num``.
            max_len_long_term: forwarded to each AOTEngine.
        """
        super().__init__()

        self.cfg = aot_model.cfg
        self.AOT = aot_model

        if max_aot_obj_num is None or max_aot_obj_num > aot_model.max_obj_num:
            self.max_aot_obj_num = aot_model.max_obj_num
        else:
            self.max_aot_obj_num = max_aot_obj_num

        self.gpu_id = gpu_id
        self.long_term_mem_gap = long_term_mem_gap
        self.short_term_mem_skip = short_term_mem_skip
        self.max_len_long_term = max_len_long_term

        self.aot_engines = []

        self.restart_engine()

    def restart_engine(self):
        """Drop all per-video engines and the stored object count."""
        self.aot_engines = []
        self.obj_nums = None

    def separate_mask(self, mask, obj_nums):
        """Split a global mask into one mask per internal engine.

        Engine ``i`` owns global ids ``i*max_aot_obj_num + 1`` to
        ``(i+1)*max_aot_obj_num``; these are renumbered from 1 in its mask.

        Args:
            mask: ``None``; a label map ((B, H, W) or (B, 1, H, W)); or a
                probability map (B, C, H, W) with C > 1 whose channel 0 is
                background.
            obj_nums: total number of objects (int).

        Returns:
            ``(separated_masks, separated_obj_nums)``, one item per engine.
        """
        num_engines = len(self.aot_engines)
        if num_engines == 1:
            return [mask], [obj_nums]

        separated_obj_nums = [self.max_aot_obj_num] * num_engines
        if obj_nums % self.max_aot_obj_num > 0:
            separated_obj_nums[-1] = obj_nums % self.max_aot_obj_num

        if mask is None:
            # Keep the (masks, obj_nums) contract so callers can unpack it.
            return [None] * num_engines, separated_obj_nums

        # Label maps have no channel dim or a single channel; the batch
        # dimension says nothing about whether the input is a label map.
        if len(mask.size()) == 3 or mask.size()[1] == 1:
            separated_masks = []
            for idx in range(num_engines):
                start_id = idx * self.max_aot_obj_num + 1
                end_id = (idx + 1) * self.max_aot_obj_num
                fg_mask = ((mask >= start_id) & (mask <= end_id)).float()
                # Renumber owned ids to 1..max_aot_obj_num, zero elsewhere.
                separated_mask = (fg_mask * mask - start_id + 1) * fg_mask
                separated_masks.append(separated_mask)
            return separated_masks, separated_obj_nums

        prob = mask
        separated_probs = []
        for idx in range(num_engines):
            start_id = idx * self.max_aot_obj_num + 1
            end_id = (idx + 1) * self.max_aot_obj_num
            # Select this engine's identity channels (dim 1, not the batch).
            fg_prob = prob[:, start_id:(end_id + 1)]
            # Everything not owned by this engine counts as background.
            bg_prob = 1. - torch.sum(fg_prob, dim=1, keepdim=True)
            separated_probs.append(torch.cat([bg_prob, fg_prob], dim=1))
        return separated_probs, separated_obj_nums

    def min_logit_aggregation(self, all_logits):
        """Merge per-engine logits, taking the minimum background logit."""
        if len(all_logits) == 1:
            return all_logits[0]

        fg_logits = []
        bg_logits = []

        for logit in all_logits:
            bg_logits.append(logit[:, 0:1])
            fg_logits.append(logit[:, 1:1 + self.max_aot_obj_num])

        bg_logit, _ = torch.min(torch.cat(bg_logits, dim=1),
                                dim=1,
                                keepdim=True)
        merged_logit = torch.cat([bg_logit] + fg_logits, dim=1)

        return merged_logit

    def soft_logit_aggregation(self, all_logits):
        """Merge per-engine logits in probability space.

        Background probability is the product of the engines' background
        probabilities; foreground probabilities are concatenated. The merged
        probabilities are clamped to [1e-5, 1 - 1e-5] and mapped back through
        ``torch.logit``.
        """
        if len(all_logits) == 1:
            return all_logits[0]

        fg_probs = []
        bg_probs = []

        for logit in all_logits:
            prob = torch.softmax(logit, dim=1)
            bg_probs.append(prob[:, 0:1])
            fg_probs.append(prob[:, 1:1 + self.max_aot_obj_num])

        bg_prob = torch.prod(torch.cat(bg_probs, dim=1), dim=1, keepdim=True)

        merged_prob = torch.cat([bg_prob] + fg_probs,
                                dim=1).clamp(1e-5, 1 - 1e-5)
        merged_logit = torch.logit(merged_prob)

        return merged_logit

    def add_reference_frame(self, img, mask, obj_nums, frame_step=-1):
        """Register the reference frame, creating engines as needed.

        ``obj_nums`` may be an int or a list (only its first item is used).
        The image is encoded once by the first engine and its embeddings are
        reused by the others.
        """
        if isinstance(obj_nums, list):
            obj_nums = obj_nums[0]
        self.obj_nums = obj_nums
        aot_num = max(int(np.ceil(obj_nums / self.max_aot_obj_num)), 1)
        while aot_num > len(self.aot_engines):
            new_engine = AOTEngine(self.AOT, self.gpu_id,
                                   self.long_term_mem_gap,
                                   self.short_term_mem_skip,
                                   self.max_len_long_term,)
            new_engine.eval()
            self.aot_engines.append(new_engine)

        separated_masks, separated_obj_nums = self.separate_mask(
            mask, obj_nums)
        img_embs = None
        for aot_engine, separated_mask, separated_obj_num in zip(
                self.aot_engines, separated_masks, separated_obj_nums):
            aot_engine.add_reference_frame(img,
                                           separated_mask,
                                           obj_nums=[separated_obj_num],
                                           frame_step=frame_step,
                                           img_embs=img_embs)
            if img_embs is None:  # reuse image embeddings
                img_embs = aot_engine.curr_enc_embs

        self.update_size()

    def match_propogate_one_frame(self, img=None):
        """Propagate every engine by one frame, encoding ``img`` only once."""
        img_embs = None
        for aot_engine in self.aot_engines:
            aot_engine.match_propogate_one_frame(img, img_embs=img_embs)
            if img_embs is None:  # reuse image embeddings
                img_embs = aot_engine.curr_enc_embs

    def decode_current_logits(self, output_size=None):
        """Decode every engine and merge the logits via soft aggregation."""
        all_logits = []
        for aot_engine in self.aot_engines:
            all_logits.append(aot_engine.decode_current_logits(output_size))
        pred_id_logits = self.soft_logit_aggregation(all_logits)
        return pred_id_logits

    def update_memory(self, curr_mask, skip_long_term_update=False):
        """Write the current frame's mask into every engine's memory.

        ``curr_mask`` is resized to the input size with F.interpolate's
        default (nearest) mode. Presumably it is a 4-D (B, 1, H, W) label
        map; verify against callers. It is then split per engine.
        """
        _curr_mask = F.interpolate(curr_mask, self.input_size_2d)
        separated_masks, _ = self.separate_mask(_curr_mask, self.obj_nums)
        for aot_engine, separated_mask in zip(self.aot_engines,
                                              separated_masks):
            aot_engine.update_short_term_memory(separated_mask,
                                                skip_long_term_update=skip_long_term_update)

    def update_size(self):
        """Copy input/encoder sizes from the first engine."""
        self.input_size_2d = self.aot_engines[0].input_size_2d
        self.enc_size_2d = self.aot_engines[0].enc_size_2d
        self.enc_hw = self.aot_engines[0].enc_hw