import torch
import numpy as np
import torch.nn.functional as F
import math
from utils.box_utils import bbox_iou, xywh2xyxy, xyxy2xywh, generalized_box_iou
from utils.misc import get_world_size
from torch.autograd import Variable
from opt_einsum import contract
def _scaled_anchors(args, scale, grid):
    """Return the 3 anchors of ``scale`` rescaled from ``args.anchor_imsize`` pixels to grid units."""
    ratio = args.anchor_imsize / grid
    return [(args.anchors_full[3 * scale + k][0] / ratio,
             args.anchors_full[3 * scale + k][1] / ratio) for k in range(3)]


def build_target(args, gt_bbox, pred, device):
    """Build YOLO-style regression targets for every prediction scale.

    For each sample, the ground-truth box shape is compared (IoU of
    origin-anchored boxes) with the 3 anchors of every scale. The best
    matching anchor becomes responsible for the sample, and its grid cell
    receives the encoded target ``[tx, ty, tw, th, 1]``.

    Args:
        args: config object; reads ``size`` (input image side),
            ``anchors_full`` (list of ``(w, h)`` anchors, 3 per scale) and
            ``anchor_imsize`` (image size the anchors are expressed in).
        gt_bbox: (batch, 4) tensor whose columns are used as
            ``[x1, y1, x2, y2]`` in input-image pixels.
        pred: sequence of per-scale predictions; only ``len(pred)`` is used.
        device: device the returned target tensors are moved to.

    Returns:
        ``(bbox_list, best_gi, best_gj, best_n_list)``:
        ``bbox_list[s]`` has shape (batch, 3, 5, grid_s, grid_s);
        ``best_gi`` / ``best_gj`` hold the responsible cell column / row per
        sample; ``best_n_list`` holds the global anchor index
        (``scale * 3 + anchor``) per sample.
    """
    batch_size = gt_bbox.size(0)
    num_scales = len(pred)

    # [x1, y1, x2, y2] -> [cx, cy, w, h], normalised by the image size. This is
    # identical for every scale; only the multiplication by the grid differs.
    center_x = (gt_bbox[:, 0] + gt_bbox[:, 2]) / 2
    center_y = (gt_bbox[:, 1] + gt_bbox[:, 3]) / 2
    box_w = gt_bbox[:, 2] - gt_bbox[:, 0]
    box_h = gt_bbox[:, 3] - gt_bbox[:, 1]
    coord_norm = torch.stack((center_x, center_y, box_w, box_h), dim=1) / args.size

    # Stride halves at every scale: 32, 16, 8, ...
    grids = [args.size // (32 // (2 ** scale_ii)) for scale_ii in range(num_scales)]
    coord_list, bbox_list = [], []
    for grid in grids:
        coord = coord_norm * grid  # box in grid-cell units for this scale
        coord_list.append(coord)
        bbox_list.append(torch.zeros(coord.size(0), 3, 5, grid, grid))

    best_n_list, best_gi, best_gj = [], [], []
    for ii in range(batch_size):
        anch_ious = []
        for scale_ii, grid in enumerate(grids):
            gw = coord_list[scale_ii][ii, 2]
            gh = coord_list[scale_ii][ii, 3]
            scaled_anchors = _scaled_anchors(args, scale_ii, grid)
            # Both boxes are placed at the origin so the IoU compares shapes only.
            gt_box = torch.from_numpy(
                np.array([0, 0, gw.cpu().numpy(), gh.cpu().numpy()])).float().unsqueeze(0)
            anchor_shapes = torch.FloatTensor(np.concatenate(
                (np.zeros((len(scaled_anchors), 2)), np.array(scaled_anchors)), 1))
            anch_ious += list(bbox_iou(gt_box, anchor_shapes))

        best_n = int(np.argmax(np.array(anch_ious)))
        best_scale, best_anchor = divmod(best_n, 3)
        best_grid = grids[best_scale]
        scaled_anchors = _scaled_anchors(args, best_scale, best_grid)

        coord = coord_list[best_scale][ii]
        # Clamp so a centre lying exactly on the far image edge still maps to a
        # valid cell instead of raising an IndexError.
        gi = coord[0].long().clamp(0, best_grid - 1)
        gj = coord[1].long().clamp(0, best_grid - 1)
        tx = coord[0] - gi.float()
        ty = coord[1] - gj.float()
        tw = torch.log(coord[2] / scaled_anchors[best_anchor][0] + 1e-16)
        th = torch.log(coord[3] / scaled_anchors[best_anchor][1] + 1e-16)

        target_row = torch.stack([tx, ty, tw, th, torch.ones_like(tx)])
        target_map = bbox_list[best_scale]
        target_map[ii, best_anchor, :, int(gj), int(gi)] = target_row.to(target_map.device)

        best_n_list.append(best_n)
        best_gi.append(gi)
        best_gj.append(gj)

    bbox_list = [target_map.to(device) for target_map in bbox_list]
    return bbox_list, best_gi, best_gj, best_n_list
def yolo_loss(pred_list, target, gi, gj, best_n_list, device, w_coord=5., w_neg=1./5, size_average=True):
    """YOLO-style loss: weighted box-regression MSE plus a confidence cross-entropy.

    Args:
        pred_list: per-scale predictions indexed as
            ``[batch, anchor, channel, gy, gx]``; channels 0:2 are raw x/y
            offsets (passed through a sigmoid here), 2:4 are w/h, and 4 is the
            confidence logit.
        target: per-scale targets with the same layout (as produced by
            ``build_target``); channel 4 marks the responsible cell.
        gi, gj: per-sample responsible cell column / row.
        best_n_list: per-sample global anchor index (``scale * 3 + anchor``).
        device: device for the intermediate box tensors.
        w_coord: weight of the coordinate terms.
        w_neg, size_average: accepted for interface compatibility; unused.

    Returns:
        Scalar tensor ``(loss_x + loss_y + loss_w + loss_h) * w_coord + loss_conf``.
    """
    mseloss = torch.nn.MSELoss(reduction='mean')
    celoss = torch.nn.CrossEntropyLoss(reduction='mean')
    num_scale = len(pred_list)
    batch_size = pred_list[0].size(0)

    pred_bbox = torch.zeros(batch_size, 4, device=device)
    gt_bbox = torch.zeros(batch_size, 4, device=device)
    for ii in range(batch_size):
        scale, anchor = divmod(best_n_list[ii], 3)
        cell = pred_list[scale][ii, anchor, :, gj[ii], gi[ii]]
        pred_bbox[ii, 0:2] = torch.sigmoid(cell[0:2])
        pred_bbox[ii, 2:4] = cell[2:4]
        gt_bbox[ii, :] = target[scale][ii, anchor, :4, gj[ii], gi[ii]]

    loss_x = mseloss(pred_bbox[:, 0], gt_bbox[:, 0])
    loss_y = mseloss(pred_bbox[:, 1], gt_bbox[:, 1])
    loss_w = mseloss(pred_bbox[:, 2], gt_bbox[:, 2])
    loss_h = mseloss(pred_bbox[:, 3], gt_bbox[:, 3])

    # Flatten the confidence channel of every anchor/cell of every scale into one
    # row per sample; the responsible cell (target conf == 1) is the class label.
    pred_conf = torch.cat(
        [pred_list[s][:, :, 4, :, :].contiguous().view(batch_size, -1) for s in range(num_scale)],
        dim=1)
    gt_conf = torch.cat(
        [target[s][:, :, 4, :, :].contiguous().view(batch_size, -1) for s in range(num_scale)],
        dim=1)
    loss_conf = celoss(pred_conf, gt_conf.max(1)[1])
    return (loss_x + loss_y + loss_w + loss_h) * w_coord + loss_conf
def trans_vg_loss(batch_pred, batch_target):
    """Compute the bounding-box losses: L1 regression and GIoU.

    Args:
        batch_pred: (batch, 4) predicted boxes.
        batch_target: (batch, 4) target boxes, same format as ``batch_pred``.
            NOTE(review): boxes are assumed to be [cx, cy, w, h] (they are
            converted with ``xywh2xyxy`` before the GIoU) — confirm.

    Returns:
        dict with ``'loss_bbox'`` (summed L1 / number of boxes) and
        ``'loss_giou'`` (summed ``1 - GIoU`` / number of boxes).
    """
    # Normalise by the local batch size; max(.., 1) avoids 0/0 on an empty batch.
    num_boxes = max(batch_pred.shape[0], 1)

    loss_bbox = F.l1_loss(batch_pred, batch_target, reduction='none')
    # generalized_box_iou returns the pairwise matrix; the diagonal pairs each
    # prediction with its own target.
    loss_giou = 1 - torch.diag(generalized_box_iou(
        xywh2xyxy(batch_pred),
        xywh2xyxy(batch_target),
    ))

    losses = {}
    losses['loss_bbox'] = loss_bbox.sum() / num_boxes
    losses['loss_giou'] = loss_giou.sum() / num_boxes
    return losses
def trans_vg_cls_loss(batch_pred, batch_target):
    """Compute the disease-classification loss (mean cross-entropy).

    Args:
        batch_pred: (batch, num_classes) logits.
        batch_target: (batch,) integer class labels.

    Returns:
        Scalar mean cross-entropy.
    """
    # 'elementwise_mean' is a deprecated alias of 'mean'.
    return F.cross_entropy(batch_pred, batch_target, reduction='mean')
def visuPooling(visu_src, target, att_weights=None):
    """Average-pool visual features (and optionally attention) inside each target box.

    Args:
        visu_src: visual tokens of shape (H*W, batch, C). The grid is taken to
            be square with side ``floor(sqrt(H*W))``.
        target: (batch, 4) boxes. They are converted with ``xywh2xyxy`` and
            multiplied by the grid side, so they are expected to be normalised
            [cx, cy, w, h].
        att_weights: optional per-sample attention. Rows from index 21 onward
            are reshaped onto the visual grid. NOTE(review): the offset 21
            presumably skips 1 leading token plus 20 text tokens — confirm
            against the model.

    Returns:
        (batch, C) pooled features, or ``(pooled_features, pooled_attention)``
        when ``att_weights`` is given.
    """
    bs = target.shape[0]
    width = height = math.floor(math.sqrt(visu_src.shape[0]))
    visu_src = visu_src.transpose(0, 1).view(bs, height, width, -1).contiguous()

    visu_bboxs, att_weights_batch = [], []
    for i in range(bs):
        bbox = xywh2xyxy(target[i])
        x1 = max(math.floor(bbox[0] * width), 0)
        y1 = max(math.floor(bbox[1] * height), 0)
        x2 = math.floor(bbox[2] * width)
        y2 = math.floor(bbox[3] * height)
        # A box thinner than one cell would pool an empty slice; widen it by one
        # cell on each side, clamped to the actual grid size.
        if x1 == x2:
            x1, x2 = max(0, x1 - 1), min(width, x2 + 1)
        if y1 == y2:
            y1, y2 = max(0, y1 - 1), min(height, y2 + 1)

        region = visu_src[i, y1:y2, x1:x2, :]
        visu_bboxs.append(region.mean(dim=0).mean(dim=0).unsqueeze(0))

        if att_weights is not None:
            att_map = att_weights[i][21:].view(height, width, -1).contiguous()
            att_region = att_map[y1:y2, x1:x2, :]
            att_weights_batch.append(att_region.mean(dim=0).mean(dim=0).unsqueeze(0))

    visu_pool = torch.cat(visu_bboxs, dim=0)
    if att_weights is not None:
        return visu_pool, torch.cat(att_weights_batch, dim=0)
    return visu_pool
def textPooling(text_src, text_mask, type='mask', att_weights=None, text_data=None, lcpTriple=None):
    """Pool text features per sample according to ``type``.

    Args:
        text_src: text tokens of shape (L, batch, C).
        text_mask: (batch, L) padding mask; falsy positions count as words
            (used by ``'mask'``).
        type: pooling rule, one of:
            ``'mask'`` — mean of the first ``word_count`` tokens;
            ``'all'`` — mean of every token;
            ``'cls'`` — the first token;
            ``'marker'`` — mean of the tokens strictly between the first two
            occurrences of token id 1008 in ``text_data.tensors``.
            NOTE(review): 1008 is presumably a tokenizer-specific marker id — confirm.
        att_weights: optional per-sample attention rows, pooled with the same
            rule ('mask' uses rows 1:21, 'cls' uses row 1). Not defined for 'all'.
        text_data: object exposing ``.tensors`` (token ids); needed for 'marker'.
        lcpTriple: when ``'lcpTriple'`` (with 'cls' and ``att_weights``),
            attention row 0 of each sample is returned as well.

    Returns:
        ``text_pools`` (batch, C); or ``(text_pools, att_text)`` when
        ``att_weights`` is given; or ``(text_pools, att_text, att_reg)`` when
        ``lcpTriple == 'lcpTriple'`` as well.

    Raises:
        ValueError: unsupported ``type`` / option combination, or fewer than one
            token between the two markers.
    """
    if type not in ('mask', 'all', 'cls', 'marker'):
        raise ValueError(f"unsupported text pooling type: {type!r}")
    with_att = att_weights is not None
    if with_att and type == 'all':
        raise ValueError("attention pooling is not defined for type='all'")
    with_reg = with_att and lcpTriple == 'lcpTriple'
    if with_reg and type != 'cls':
        raise ValueError("lcpTriple='lcpTriple' requires type='cls'")

    bs = text_src.shape[1]
    text_src = text_src.transpose(0, 1).contiguous()  # -> (batch, L, C)
    text_ids_batch = text_data.tensors if type == 'marker' else None

    text_pools, att_text_batch, att_reg_batch = [], [], []
    for i in range(bs):
        text = text_src[i]
        att_text = None
        if type == 'mask':
            word_count = text_mask[i].logical_not().int().sum()
            text_pool = text[:word_count, :].mean(dim=0).unsqueeze(0)
            if with_att:
                att_text = att_weights[i][1:21][:word_count, :].mean(dim=0).unsqueeze(0)
        elif type == 'all':
            text_pool = text.mean(dim=0).unsqueeze(0)
        elif type == 'cls':
            text_pool = text[0].unsqueeze(0)
            if with_att:
                if with_reg:
                    att_reg_batch.append(att_weights[i][0:1])
                att_text = att_weights[i][1:2]
        else:  # 'marker'
            marker_idx = (text_ids_batch[i] == 1008).nonzero().squeeze()
            id1, id2 = marker_idx[0], marker_idx[1]
            if id2 - id1 <= 1:
                raise ValueError("expected at least one token between the two markers")
            text_pool = text[id1 + 1:id2].mean(dim=0).unsqueeze(0)
            if with_att:
                att_text = att_weights[i][id1 + 1:id2].mean(dim=0).unsqueeze(0)

        text_pools.append(text_pool)
        if with_att:
            att_text_batch.append(att_text)

    text_pools = torch.cat(text_pools, dim=0)
    if not with_att:
        return text_pools
    att_text_batch = torch.cat(att_text_batch, dim=0)
    if with_reg:
        return text_pools, att_text_batch, torch.cat(att_reg_batch, dim=0)
    return text_pools, att_text_batch
def trans_vg_btloss(visu_pool, text_pool, type='l1'):
    """Distance loss between pooled visual and text features.

    Args:
        visu_pool: pooled visual features.
        text_pool: pooled text features, same shape as ``visu_pool``.
        type: ``'l1'`` for mean absolute error, ``'l2'`` for mean squared error.

    Returns:
        Scalar mean loss.

    Raises:
        ValueError: for any other ``type``.
    """
    if type == 'l1':
        # 'elementwise_mean' is a deprecated alias of 'mean'.
        return F.l1_loss(visu_pool, text_pool, reduction='mean')
    if type == 'l2':
        return F.mse_loss(visu_pool, text_pool)
    raise ValueError(f"loss type not supported: {type!r}")
def trans_vg_caloss(pos_pool, neg_pools, text_pool, temperature=0.07, mode='max'):
    """InfoNCE contrastive loss between a text embedding and candidate regions.

    Each sample has one positive region and K negative regions. Features are
    L2-normalised, cosine similarities are divided by ``temperature``, and the
    loss is the cross-entropy with the positive as class 0:
    ``log(sum_j exp(logit_j)) - logit_pos``, averaged over the batch.

    Args:
        pos_pool: (batch, C) positive region features.
        neg_pools: (batch, K, C) negative region features.
        text_pool: (batch, C) text features.
        temperature: softmax temperature.
        mode: if it contains ``'max'``, the per-row maximum logit (detached) is
            subtracted before exponentiation. This prevents overflow and leaves
            the loss value unchanged. ``'projection'`` is accepted, but no
            projection is applied here. NOTE(review): the body of that branch
            appears to be missing — confirm.

    Returns:
        Scalar loss.
    """
    text_pool = F.normalize(text_pool.unsqueeze(1), p=2, dim=2)            # (B, 1, C)
    visu_pools = torch.cat([pos_pool.unsqueeze(1), neg_pools], dim=1)       # (B, 1+K, C)
    visu_pools = F.normalize(visu_pools, p=2, dim=2)
    logit = torch.div(torch.matmul(text_pool, visu_pools.transpose(1, 2)), temperature)  # (B, 1, 1+K)

    if 'max' in mode:
        logit_max, _ = torch.max(logit, dim=2, keepdim=True)
        logit = logit - logit_max.detach()

    # -log softmax(logit)[pos]. The log is what makes the 'max' shift above a
    # pure numerical-stability trick rather than a change of objective.
    log_sum_exp = torch.log(torch.exp(logit).sum(dim=2)).squeeze()
    logit_pos = logit[:, :, 0].squeeze()
    return torch.mean(log_sum_exp - logit_pos)
def trans_vg_caloss_crossbatch(cnn_code, _, rnn_code, eps=1e-8, temp=0.1):
    """Symmetric cross-batch contrastive loss between visual and text codes.

    Sample ``i``'s visual code should match sample ``i``'s text code against
    every other text code in the batch, and vice versa. The scores are cosine
    similarities divided by ``temp``.

    Args:
        cnn_code: (batch, C) visual codes.
        _: unused (kept for interface compatibility).
        rnn_code: (batch, C) text codes.
        eps: lower clamp on the norm product, avoiding division by zero.
        temp: temperature dividing the cosine similarities.

    Returns:
        Scalar: image->text cross-entropy plus text->image cross-entropy.
    """
    batch_size = cnn_code.shape[0]
    labels = torch.arange(batch_size, dtype=torch.long, device=cnn_code.device)
    if cnn_code.dim() == 2:
        cnn_code = cnn_code.unsqueeze(0)
        rnn_code = rnn_code.unsqueeze(0)

    cnn_code_norm = torch.norm(cnn_code, 2, dim=2, keepdim=True)
    rnn_code_norm = torch.norm(rnn_code, 2, dim=2, keepdim=True)
    scores0 = torch.bmm(cnn_code, rnn_code.transpose(1, 2))
    norm0 = torch.bmm(cnn_code_norm, rnn_code_norm.transpose(1, 2))
    # Use the `temp` parameter (previously an undefined `temp3` -> NameError).
    scores0 = scores0 / norm0.clamp(min=eps) / temp
    # Drop only the leading singleton dim so a batch of 1 keeps its (1, 1) shape.
    scores0 = scores0.squeeze(0)
    scores1 = scores0.transpose(0, 1)

    loss0 = torch.nn.CrossEntropyLoss()(scores0, labels)
    loss1 = torch.nn.CrossEntropyLoss()(scores1, labels)
    return loss0 + loss1
def trans_vg_caloss_inimage(pos_pool, neg_pools, rnn_code, eps=1e-8, temp3=0.1):
    """In-image contrastive loss: pick the positive region among the candidates.

    Each text code is scored against its positive region (candidate 0) and
    its K negative regions by cosine similarity divided by ``temp3``. A
    cross-entropy with label 0 is then applied.

    Args:
        pos_pool: (batch, C) positive region features.
        neg_pools: (batch, K, C) negative region features.
        rnn_code: (batch, C) text features.
        eps: lower clamp on the norm product, avoiding division by zero.
        temp3: temperature dividing the cosine similarities.

    Returns:
        Scalar mean cross-entropy.
    """
    rnn_code = rnn_code.unsqueeze(1)                                      # (B, 1, C)
    cnn_code = torch.cat([pos_pool.unsqueeze(1), neg_pools], dim=1)       # (B, 1+K, C)
    batch_size = cnn_code.shape[0]
    labels = torch.zeros(batch_size, dtype=torch.long, device=cnn_code.device)

    cnn_code_norm = torch.norm(cnn_code, 2, dim=2, keepdim=True)
    rnn_code_norm = torch.norm(rnn_code, 2, dim=2, keepdim=True)
    scores0 = torch.bmm(cnn_code, rnn_code.transpose(1, 2))              # (B, 1+K, 1)
    norm0 = torch.bmm(cnn_code_norm, rnn_code_norm.transpose(1, 2))
    scores0 = scores0 / norm0.clamp(min=eps) / temp3
    # Squeeze only the trailing singleton; a bare squeeze() would also drop the
    # batch dim when batch_size == 1 and break CrossEntropyLoss.
    scores0 = scores0.squeeze(2)
    return torch.nn.CrossEntropyLoss()(scores0, labels)
def cal_lcp_triple(h_att, t_att, g_att, emb):
    """Attention-weighted context vector from three combined attention maps.

    The three (batch, L) attention maps are multiplied elementwise and
    normalised to sum to ~1 per sample. They are then used to average the
    token embeddings of that sample.

    Args:
        h_att, t_att, g_att: (batch, L) attention weights.
        emb: token embeddings indexed as ``emb[l, b, :]`` (shape (L, >=batch, D)).

    Returns:
        (batch, D) context vectors.
    """
    bs = h_att.shape[0]
    ht_att = h_att * t_att * g_att
    ht_att = ht_att / (ht_att.sum(1, keepdim=True) + 1e-5)  # 1e-5 guards an all-zero row
    # out[b] = sum_l ht_att[b, l] * emb[l, b, :]; one batched contraction
    # instead of a per-sample loop.
    return torch.einsum('lbd,bl->bd', emb[:, :bs, :], ht_att)
def trans_vg_caloss_inimage_lcp_triple(pos_pool, neg_pools, rnn_code, att_pos, att_negs, att_text, att_reg, emb, eps=1e-8, temp3=0.1):
    """In-image contrastive loss with an attention-derived context per candidate.

    For the positive region and each negative region, a context vector is
    computed with ``cal_lcp_triple(att_text, region_att, att_reg, emb)``. It is
    added to both the region feature and the (repeated) text code. Cosine
    similarity divided by ``temp3`` gives one score per candidate, and a
    cross-entropy with label 0 (the positive) is applied.

    Args:
        pos_pool: (batch, C) positive region features.
        neg_pools: (batch, K, C) negative region features.
        rnn_code: (batch, C) text features.
        att_pos: (batch, L) attention of the positive region.
        att_negs: (batch, K, L) attention of each negative region.
        att_text, att_reg: (batch, L) attention maps passed to ``cal_lcp_triple``.
        emb: token embeddings consumed by ``cal_lcp_triple``.
        eps: lower clamp on the norm product.
        temp3: temperature dividing the cosine similarities.

    Returns:
        Scalar mean cross-entropy.
    """
    rnn_code = rnn_code.unsqueeze(1)                                      # (B, 1, C)
    cnn_code = torch.cat([pos_pool.unsqueeze(1), neg_pools], dim=1)       # (B, 1+K, C)
    batch_size, num_candidates = cnn_code.shape[:2]
    labels = torch.zeros(batch_size, dtype=torch.long, device=cnn_code.device)

    # One context vector per candidate, positive first: (B, 1+K, C).
    contexts = [cal_lcp_triple(att_text, att_pos, att_reg, emb)]
    contexts += [cal_lcp_triple(att_text, att_negs[:, j, :], att_reg, emb)
                 for j in range(neg_pools.shape[1])]
    c = torch.stack(contexts, dim=1)

    rnn_code = rnn_code.repeat(1, num_candidates, 1) + c
    cnn_code = cnn_code + c

    cnn_code_norm = torch.norm(cnn_code, 2, dim=2, keepdim=True)
    rnn_code_norm = torch.norm(rnn_code, 2, dim=2, keepdim=True)
    scores0 = (cnn_code * rnn_code) / (cnn_code_norm * rnn_code_norm).clamp(min=eps) / temp3
    scores0 = scores0.sum(2)  # cosine similarity / temp3, shape (B, 1+K)
    return torch.nn.CrossEntropyLoss()(scores0, labels)
def cal_lcp(h_att, t_att, emb):
    """Attention-weighted context vector from two combined attention maps.

    The two (batch, L) attention maps are multiplied elementwise and
    normalised to sum to ~1 per sample. They are then used to average the
    token embeddings of that sample.

    Args:
        h_att, t_att: (batch, L) attention weights.
        emb: token embeddings indexed as ``emb[l, b, :]`` (shape (L, >=batch, D)).

    Returns:
        (batch, D) context vectors.
    """
    bs = h_att.shape[0]
    ht_att = h_att * t_att
    ht_att = ht_att / (ht_att.sum(1, keepdim=True) + 1e-5)  # 1e-5 guards an all-zero row
    # out[b] = sum_l ht_att[b, l] * emb[l, b, :]; one batched contraction
    # instead of a per-sample loop.
    return torch.einsum('lbd,bl->bd', emb[:, :bs, :], ht_att)
def trans_vg_caloss_inimage_lcp(pos_pool, neg_pools, rnn_code, att_pos, att_negs, att_text, emb, ws=None, wo=None, wc1=None, wc2=None, eps=1e-8, temp3=0.1):
    """In-image contrastive loss with an attention-derived context per candidate.

    For the positive region and each negative region, a context vector is
    computed with ``cal_lcp(att_text, region_att, emb)``. It is optionally
    projected by ``wc1`` and added to both the region feature and the
    (repeated) text code. Cosine similarity divided by ``temp3`` gives one
    score per candidate, and a cross-entropy with label 0 (the positive) is
    applied.

    Args:
        pos_pool: (batch, C) positive region features.
        neg_pools: (batch, K, C) negative region features.
        rnn_code: (batch, C) text features.
        att_pos: (batch, L) attention of the positive region.
        att_negs: (batch, K, L) attention of each negative region.
        att_text: (batch, L) text attention passed to ``cal_lcp``.
        emb: token embeddings consumed by ``cal_lcp``.
        ws, wo, wc2: accepted for interface compatibility; unused.
        wc1: optional callable applied to the context tensor before it is added.
        eps: lower clamp on the norm product.
        temp3: temperature dividing the cosine similarities.

    Returns:
        Scalar mean cross-entropy.
    """
    rnn_code = rnn_code.unsqueeze(1)                                      # (B, 1, C)
    cnn_code = torch.cat([pos_pool.unsqueeze(1), neg_pools], dim=1)       # (B, 1+K, C)
    batch_size, num_candidates = cnn_code.shape[:2]
    labels = torch.zeros(batch_size, dtype=torch.long, device=cnn_code.device)

    # One context vector per candidate, positive first: (B, 1+K, C).
    contexts = [cal_lcp(att_text, att_pos, emb)]
    contexts += [cal_lcp(att_text, att_negs[:, j, :], emb) for j in range(neg_pools.shape[1])]
    c = torch.stack(contexts, dim=1)

    # Exactly one of the two branches applies: raw context, or wc1-projected
    # context. The projection is computed once and shared by both sides.
    context = c if wc1 is None else wc1(c)
    rnn_code = rnn_code.repeat(1, num_candidates, 1) + context
    cnn_code = cnn_code + context

    cnn_code_norm = torch.norm(cnn_code, 2, dim=2, keepdim=True)
    rnn_code_norm = torch.norm(rnn_code, 2, dim=2, keepdim=True)
    scores0 = (cnn_code * rnn_code) / (cnn_code_norm * rnn_code_norm).clamp(min=eps) / temp3
    scores0 = scores0.sum(2)  # cosine similarity / temp3, shape (B, 1+K)
    return torch.nn.CrossEntropyLoss()(scores0, labels)
def trans_vg_conBox(batch_pred, batch_target):
    """Compute the bounding-box losses: L1 regression and GIoU.

    Args:
        batch_pred: (batch, 4) predicted boxes.
        batch_target: (batch, 4) target boxes, same format as ``batch_pred``.
            NOTE(review): boxes are assumed to be [cx, cy, w, h] (they are
            converted with ``xywh2xyxy`` before the GIoU) — confirm.

    Returns:
        ``(loss_bbox, loss_giou)``: summed L1 and summed ``1 - GIoU``, each
        divided by the number of boxes.
    """
    # Normalise by the local batch size; max(.., 1) avoids 0/0 on an empty batch.
    num_boxes = max(batch_pred.shape[0], 1)

    loss_bbox = F.l1_loss(batch_pred, batch_target, reduction='none')
    # The diagonal of the pairwise GIoU matrix pairs each prediction with its target.
    loss_giou = 1 - torch.diag(generalized_box_iou(
        xywh2xyxy(batch_pred),
        xywh2xyxy(batch_target),
    ))
    return loss_bbox.sum() / num_boxes, loss_giou.sum() / num_boxes
def CAlossFunc(epoch, max_epoch, type='poly', power=0.9):
    """Epoch-dependent weight for the contrastive-alignment loss.

    Args:
        epoch: current epoch.
        max_epoch: total number of epochs (must be non-zero).
        type: schedule name; only ``'poly'`` is supported:
            ``(epoch / max_epoch) ** power``.
        power: exponent of the polynomial schedule (default 0.9).

    Returns:
        The weight as a float.

    Raises:
        ValueError: for an unsupported ``type`` (previously returned None silently).
    """
    if type == 'poly':
        return (epoch / max_epoch) ** power
    raise ValueError(f"unsupported CA loss schedule: {type!r}")
def trans_vg_gn_loss(batch_pred, batch_target):
    """Compute the multi-label loss (mean binary cross-entropy on logits).

    Args:
        batch_pred: logits, same shape as ``batch_target``.
        batch_target: float targets in [0, 1].

    Returns:
        Scalar mean BCE-with-logits loss.
    """
    return F.binary_cross_entropy_with_logits(batch_pred, batch_target)