import math

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn, Tensor


class PoseProjector(nn.Module):

    def __init__(self, hidden_dim=256, num_body_points=17):
        super().__init__()
        self.num_body_points = num_body_points
        self.V_projector = nn.Linear(hidden_dim, num_body_points)
        nn.init.constant_(self.V_projector.bias.data, 0)
        self.Z_projector = MLP(hidden_dim, hidden_dim, num_body_points * 2, 3)
        nn.init.constant_(self.Z_projector.layers[-1].weight.data, 0)
        nn.init.constant_(self.Z_projector.layers[-1].bias.data, 0)

    def forward(self, hs):
        """Project decoder hidden states to keypoint offsets and visibilities.

        Args:
            hs (Tensor): decoder hidden states, shape (..., bs, nq, hidden_dim).
        Returns:
            Z (Tensor): keypoint coordinates, shape (..., bs, nq, num_body_points * 2).
            V (Tensor): keypoint visibility logits, shape (..., bs, nq, num_body_points).
        """
        Z = self.Z_projector(hs)  # ..., bs, nq, 34
        V = self.V_projector(hs)  # ..., bs, nq, 17
        return Z, V


def gen_encoder_output_proposals(memory: Tensor,
                                 memory_padding_mask: Tensor,
                                 spatial_shapes: Tensor,
                                 learnedwh=None):
    """Generate anchor proposals from the flattened encoder memory.

    Input:
        - memory: bs, sum(hw), d_model
        - memory_padding_mask: bs, sum(hw)
        - spatial_shapes: nlevel, 2
        - learnedwh: 2
    Output:
        - output_memory: bs, sum(hw), d_model
        - output_proposals: bs, sum(hw), 4
    """
    N_, S_, C_ = memory.shape
    proposals = []
    _cur = 0
    for lvl, (H_, W_) in enumerate(spatial_shapes):
        mask_flatten_ = memory_padding_mask[:, _cur:(_cur + H_ * W_)].view(
            N_, H_, W_, 1)
        valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1)
        valid_W = torch.sum(~mask_flatten_[:, 0, :, 0], 1)

        grid_y, grid_x = torch.meshgrid(
            torch.linspace(0, H_ - 1, H_, dtype=torch.float32, device=memory.device),
            torch.linspace(0, W_ - 1, W_, dtype=torch.float32, device=memory.device))
        grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1)  # H_, W_, 2

        scale = torch.cat([valid_W.unsqueeze(-1), valid_H.unsqueeze(-1)],
                          1).view(N_, 1, 1, 2)
        grid = (grid.unsqueeze(0).expand(N_, -1, -1, -1) + 0.5) / scale

        if learnedwh is not None:
            wh = torch.ones_like(grid) * learnedwh.sigmoid() * (2.0 ** lvl)
        else:
            wh = torch.ones_like(grid) * 0.05 * (2.0 ** lvl)

        proposal = torch.cat((grid, wh), -1).view(N_, -1, 4)
        proposals.append(proposal)
        _cur += (H_ * W_)

    output_proposals = torch.cat(proposals, 1)
    output_proposals_valid = ((output_proposals > 0.01) &
                              (output_proposals < 0.99)).all(-1, keepdim=True)
    output_proposals = torch.log(output_proposals /
                                 (1 - output_proposals))  # unsigmoid
    output_proposals = output_proposals.masked_fill(
        memory_padding_mask.unsqueeze(-1), float('inf'))
    output_proposals = output_proposals.masked_fill(~output_proposals_valid,
                                                    float('inf'))

    output_memory = memory
    output_memory = output_memory.masked_fill(
        memory_padding_mask.unsqueeze(-1), float(0))
    output_memory = output_memory.masked_fill(~output_proposals_valid, float(0))
    return output_memory, output_proposals
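
# Illustrative usage sketch (not part of the original code): run dummy
# two-level encoder memory through gen_encoder_output_proposals. All shapes
# and values below are assumptions chosen for the demo.
def _example_gen_encoder_output_proposals():
    bs, d_model = 2, 256
    spatial_shapes = torch.as_tensor([[32, 32], [16, 16]])  # (nlevel, 2)
    num_tokens = int((spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum())
    memory = torch.rand(bs, num_tokens, d_model)
    memory_padding_mask = torch.zeros(bs, num_tokens, dtype=torch.bool)  # no padding
    out_memory, out_proposals = gen_encoder_output_proposals(
        memory, memory_padding_mask, spatial_shapes)
    # out_proposals holds inverse-sigmoid (cx, cy, w, h) anchors, one per
    # encoder token; padded or out-of-range positions are filled with +inf.
    assert out_proposals.shape == (bs, num_tokens, 4)
    return out_memory, out_proposals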

class RandomBoxPerturber:

    def __init__(self,
                 x_noise_scale=0.2,
                 y_noise_scale=0.2,
                 w_noise_scale=0.2,
                 h_noise_scale=0.2) -> None:
        self.noise_scale = torch.Tensor(
            [x_noise_scale, y_noise_scale, w_noise_scale, h_noise_scale])

    def __call__(self, refanchors: Tensor) -> Tensor:
        nq, bs, query_dim = refanchors.shape
        device = refanchors.device

        noise_raw = torch.rand_like(refanchors)
        noise_scale = self.noise_scale.to(device)[:query_dim]

        new_refanchors = refanchors * (1 + (noise_raw - 0.5) * noise_scale)
        return new_refanchors.clamp_(0, 1)


def sigmoid_focal_loss(inputs,
                       targets,
                       num_boxes,
                       alpha: float = 0.25,
                       gamma: float = 2):
    """
    Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.

    Args:
        inputs: A float tensor of arbitrary shape.
            The predictions for each example.
        targets: A float tensor with the same shape as inputs. Stores the
            binary classification label for each element in inputs
            (0 for the negative class and 1 for the positive class).
        alpha: (optional) Weighting factor in range (0, 1) to balance
            positive vs negative examples. Default: 0.25. A negative value
            disables the weighting.
        gamma: Exponent of the modulating factor (1 - p_t) to balance
            easy vs hard examples.
    Returns:
        Loss tensor
    """
    prob = inputs.sigmoid()
    ce_loss = F.binary_cross_entropy_with_logits(inputs, targets,
                                                 reduction='none')
    p_t = prob * targets + (1 - prob) * (1 - targets)
    loss = ce_loss * ((1 - p_t) ** gamma)

    if alpha >= 0:
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        loss = alpha_t * loss

    return loss.mean(1).sum() / num_boxes


class MLP(nn.Module):
    """Very simple multi-layer perceptron (also called FFN)."""

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(
            nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x


def _get_activation_fn(activation, d_model=256, batch_dim=0):
    """Return an activation function given a string."""
    if activation == 'relu':
        return F.relu
    if activation == 'gelu':
        return F.gelu
    if activation == 'glu':
        return F.glu
    if activation == 'prelu':
        return nn.PReLU()
    if activation == 'selu':
        return F.selu

    raise RuntimeError(
        f'activation should be relu/gelu/glu/prelu/selu, not {activation}.')


def gen_sineembed_for_position(pos_tensor):
    # pos_tensor: (n_query, bs, 2 or 4) normalized coordinates in [0, 1]
    scale = 2 * math.pi
    dim_t = torch.arange(128, dtype=torch.float32, device=pos_tensor.device)
    dim_t = 10000 ** (2 * (dim_t // 2) / 128)
    x_embed = pos_tensor[:, :, 0] * scale
    y_embed = pos_tensor[:, :, 1] * scale
    pos_x = x_embed[:, :, None] / dim_t
    pos_y = y_embed[:, :, None] / dim_t
    pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()),
                        dim=3).flatten(2)
    pos_y = torch.stack((pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()),
                        dim=3).flatten(2)
    if pos_tensor.size(-1) == 2:
        pos = torch.cat((pos_y, pos_x), dim=2)
    elif pos_tensor.size(-1) == 4:
        w_embed = pos_tensor[:, :, 2] * scale
        pos_w = w_embed[:, :, None] / dim_t
        pos_w = torch.stack((pos_w[:, :, 0::2].sin(), pos_w[:, :, 1::2].cos()),
                            dim=3).flatten(2)

        h_embed = pos_tensor[:, :, 3] * scale
        pos_h = h_embed[:, :, None] / dim_t
        pos_h = torch.stack((pos_h[:, :, 0::2].sin(), pos_h[:, :, 1::2].cos()),
                            dim=3).flatten(2)

        pos = torch.cat((pos_y, pos_x, pos_w, pos_h), dim=2)
    else:
        raise ValueError('Unknown pos_tensor shape(-1):{}'.format(
            pos_tensor.size(-1)))
    return pos
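
# Illustrative usage sketch (not part of the original code): sine embeddings
# for a batch of normalized (cx, cy, w, h) boxes. Shapes are assumptions.
def _example_gen_sineembed_for_position():
    n_query, bs = 100, 2
    boxes = torch.rand(n_query, bs, 4)  # normalized box coordinates in [0, 1]
    embed = gen_sineembed_for_position(boxes)
    # each of the 4 coordinates contributes 128 sine/cosine channels
    assert embed.shape == (n_query, bs, 512)
    return embed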

def oks_overlaps(kpt_preds, kpt_gts, kpt_valids, kpt_areas, sigmas):
    sigmas = kpt_preds.new_tensor(sigmas)
    variances = (sigmas * 2) ** 2

    assert kpt_preds.size(0) == kpt_gts.size(0)
    kpt_preds = kpt_preds.reshape(-1, kpt_preds.size(-1) // 2, 2)
    kpt_gts = kpt_gts.reshape(-1, kpt_gts.size(-1) // 2, 2)

    squared_distance = (kpt_preds[:, :, 0] - kpt_gts[:, :, 0]) ** 2 + \
                       (kpt_preds[:, :, 1] - kpt_gts[:, :, 1]) ** 2
    squared_distance0 = squared_distance / (kpt_areas[:, None] *
                                            variances[None, :] * 2)
    squared_distance1 = torch.exp(-squared_distance0)
    squared_distance1 = squared_distance1 * kpt_valids
    oks = squared_distance1.sum(dim=1) / (kpt_valids.sum(dim=1) + 1e-6)

    return oks


def oks_loss(pred,
             target,
             valid=None,
             area=None,
             linear=False,
             sigmas=None,
             eps=1e-6):
    """OKS loss.

    Computing the oks loss between a set of predicted poses and target poses.
    The loss is calculated as the negative log of oks.

    Args:
        pred (torch.Tensor): Predicted poses of format (x1, y1, x2, y2, ...),
            shape (n, 2K).
        target (torch.Tensor): Corresponding gt poses, shape (n, 2K).
        valid (torch.Tensor): Visibility flags of the target keypoints,
            shape (n, K).
        area (torch.Tensor): Normalization areas of the target poses, shape (n,).
        linear (bool, optional): If True, use linear scale of loss instead of
            log scale. Default: False.
        sigmas (array-like): Per-keypoint OKS falloff constants.
        eps (float): Eps to avoid log(0).

    Return:
        torch.Tensor: Loss tensor.
    """
    oks = oks_overlaps(pred, target, valid, area, sigmas).clamp(min=eps)
    if linear:
        loss = 1 - oks
    else:
        loss = -oks.log()
    loss = loss * valid.sum(-1) / (valid.sum(-1) + eps)
    return loss


class OKSLoss(nn.Module):
    """OKSLoss.

    Computing the oks loss between a set of predicted poses and target poses.

    Args:
        linear (bool): If True, use linear scale of loss instead of log scale.
            Default: False.
        num_keypoints (int): Number of keypoints per pose (17, 14 or 6).
        eps (float): Eps to avoid log(0).
        reduction (str): Options are "none", "mean" and "sum".
        loss_weight (float): Weight of loss.
    """

    def __init__(self,
                 linear=False,
                 num_keypoints=17,
                 eps=1e-6,
                 reduction='mean',
                 loss_weight=1.0):
        super(OKSLoss, self).__init__()
        self.linear = linear
        self.eps = eps
        self.reduction = reduction
        self.loss_weight = loss_weight
        if num_keypoints == 17:
            self.sigmas = np.array([
                .26, .25, .25, .35, .35, .79, .79, .72, .72, .62, .62, 1.07,
                1.07, .87, .87, .89, .89
            ], dtype=np.float32) / 10.0
        elif num_keypoints == 14:
            self.sigmas = np.array([
                .79, .79, .72, .72, .62, .62, 1.07, 1.07, .87, .87, .89, .89,
                .79, .79
            ], dtype=np.float32) / 10.0
        elif num_keypoints == 6:
            self.sigmas = np.array([
                .25, .25, .25, .25, .25, .25
            ], dtype=np.float32) / 10.0
        else:
            raise ValueError(f'Unsupported keypoints number {num_keypoints}')

    def forward(self,
                pred,
                target,
                valid,
                area,
                weight=None,
                avg_factor=None,
                reduction_override=None):
        """Forward function.

        Args:
            pred (torch.Tensor): The prediction.
            target (torch.Tensor): The learning target of the prediction.
            valid (torch.Tensor): The visible flag of the target pose.
            area (torch.Tensor): The area of the target pose.
            weight (torch.Tensor, optional): The weight of loss for each
                prediction. Defaults to None.
            avg_factor (int, optional): Average factor that is used to average
                the loss. Defaults to None.
            reduction_override (str, optional): The reduction method used to
                override the original reduction method of the loss.
                Defaults to None. Options are "none", "mean" and "sum".
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        if (weight is not None) and (not torch.any(weight > 0)) and (
                reduction != 'none'):
            if pred.dim() == weight.dim() + 1:
                weight = weight.unsqueeze(1)
            return (pred * weight).sum()  # 0
        if weight is not None and weight.dim() > 1:
            # TODO: remove this in the future
            # reduce the weight of shape (n, 4) to (n,) to match the
            # iou_loss of shape (n,)
            assert weight.shape == pred.shape
            weight = weight.mean(-1)
        loss = self.loss_weight * oks_loss(pred,
                                           target,
                                           valid=valid,
                                           area=area,
                                           linear=self.linear,
                                           sigmas=self.sigmas,
                                           eps=self.eps)
        return loss
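
# Illustrative usage sketch (not part of the original code): OKS loss for a
# handful of matched queries with 17 COCO keypoints. Values are assumptions.
def _example_oks_loss():
    criterion = OKSLoss(linear=True, num_keypoints=17)
    n, k = 8, 17
    pred = torch.rand(n, 2 * k)    # predicted (x1, y1, ..., xK, yK), normalized
    target = torch.rand(n, 2 * k)  # ground-truth keypoints, same layout
    valid = (torch.rand(n, k) > 0.3).float()  # per-keypoint visibility flags
    area = torch.rand(n) + 0.1     # per-instance normalization areas
    loss = criterion(pred, target, valid, area)
    # forward returns one loss value per instance; note that it applies no
    # reduction on this path, even though a reduction mode is configured
    assert loss.shape == (n,)
    return loss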