# AiOS: models/aios/utils.py
import math

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn, Tensor


class PoseProjector(nn.Module):
    """Projection heads mapping hidden states to per-keypoint outputs.

    Z_projector regresses two values per body point; V_projector predicts
    one score per body point.
    """

    def __init__(self, hidden_dim=256, num_body_points=17):
        super().__init__()
        self.num_body_points = num_body_points
        self.V_projector = nn.Linear(hidden_dim, num_body_points)
        nn.init.constant_(self.V_projector.bias.data, 0)
        self.Z_projector = MLP(hidden_dim, hidden_dim, num_body_points * 2, 3)
        nn.init.constant_(self.Z_projector.layers[-1].weight.data, 0)
        nn.init.constant_(self.Z_projector.layers[-1].bias.data, 0)
    def forward(self, hs):
        """Project hidden states to per-keypoint outputs.

        Args:
            hs (Tensor): hidden states of shape (..., bs, nq, hidden_dim).

        Returns:
            tuple: Z of shape (..., bs, nq, num_body_points * 2) and
            V of shape (..., bs, nq, num_body_points).
        """
        Z = self.Z_projector(hs)  # ..., bs, nq, num_body_points * 2
        V = self.V_projector(hs)  # ..., bs, nq, num_body_points
        return Z, V
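
# Illustrative usage sketch (not part of the original file); the shapes below
# are hypothetical and assume the default 17 body points and hidden_dim=256:
#
#     projector = PoseProjector(hidden_dim=256, num_body_points=17)
#     hs = torch.randn(2, 900, 256)   # (bs, nq, hidden_dim)
#     Z, V = projector(hs)            # Z: (2, 900, 34), V: (2, 900, 17)
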
def gen_encoder_output_proposals(memory: Tensor,
memory_padding_mask: Tensor,
spatial_shapes: Tensor,
learnedwh=None):
"""
Input:
- memory: bs, \sum{hw}, d_model
- memory_padding_mask: bs, \sum{hw}
- spatial_shapes: nlevel, 2
- learnedwh: 2
Output:
- output_memory: bs, \sum{hw}, d_model
- output_proposals: bs, \sum{hw}, 4
"""
    N_, S_, C_ = memory.shape
    proposals = []
    _cur = 0
for lvl, (H_, W_) in enumerate(spatial_shapes):
mask_flatten_ = memory_padding_mask[:, _cur:(_cur + H_ * W_)].view(
N_, H_, W_, 1)
valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1)
valid_W = torch.sum(~mask_flatten_[:, 0, :, 0], 1)
        grid_y, grid_x = torch.meshgrid(
            torch.linspace(0, H_ - 1, H_,
                           dtype=torch.float32,
                           device=memory.device),
            torch.linspace(0, W_ - 1, W_,
                           dtype=torch.float32,
                           device=memory.device),
            indexing='ij')  # matches the default; avoids the newer-PyTorch warning
grid = torch.cat(
[grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1) # H_, W_, 2
scale = torch.cat([valid_W.unsqueeze(-1),
valid_H.unsqueeze(-1)], 1).view(N_, 1, 1, 2)
grid = (grid.unsqueeze(0).expand(N_, -1, -1, -1) + 0.5) / scale
if learnedwh is not None:
wh = torch.ones_like(grid) * learnedwh.sigmoid() * (2.0**lvl)
else:
wh = torch.ones_like(grid) * 0.05 * (2.0**lvl)
proposal = torch.cat((grid, wh), -1).view(N_, -1, 4)
proposals.append(proposal)
_cur += (H_ * W_)
output_proposals = torch.cat(proposals, 1)
output_proposals_valid = ((output_proposals > 0.01) &
(output_proposals < 0.99)).all(-1, keepdim=True)
output_proposals = torch.log(output_proposals /
(1 - output_proposals)) # unsigmoid
output_proposals = output_proposals.masked_fill(
memory_padding_mask.unsqueeze(-1), float('inf'))
output_proposals = output_proposals.masked_fill(~output_proposals_valid,
float('inf'))
output_memory = memory
output_memory = output_memory.masked_fill(
memory_padding_mask.unsqueeze(-1), float(0))
output_memory = output_memory.masked_fill(~output_proposals_valid,
float(0))
return output_memory, output_proposals
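
# Illustrative usage sketch (hypothetical shapes, not from the original file):
#
#     spatial_shapes = torch.tensor([[32, 32], [16, 16]])           # two levels
#     S = int((spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum())  # 1280
#     memory = torch.randn(2, S, 256)                               # bs=2
#     mask = torch.zeros(2, S, dtype=torch.bool)                    # no padding
#     out_mem, proposals = gen_encoder_output_proposals(memory, mask,
#                                                       spatial_shapes)
#     # out_mem: (2, 1280, 256); proposals: (2, 1280, 4), inverse-sigmoid space
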
class RandomBoxPerturber():
def __init__(self,
x_noise_scale=0.2,
y_noise_scale=0.2,
w_noise_scale=0.2,
h_noise_scale=0.2) -> None:
self.noise_scale = torch.Tensor(
[x_noise_scale, y_noise_scale, w_noise_scale, h_noise_scale])
def __call__(self, refanchors: Tensor) -> Tensor:
nq, bs, query_dim = refanchors.shape
device = refanchors.device
noise_raw = torch.rand_like(refanchors)
noise_scale = self.noise_scale.to(device)[:query_dim]
new_refanchors = refanchors * (1 + (noise_raw - 0.5) * noise_scale)
return new_refanchors.clamp_(0, 1)
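
# Illustrative usage sketch (assumes cxcywh boxes normalized to [0, 1]):
#
#     perturber = RandomBoxPerturber(x_noise_scale=0.1, y_noise_scale=0.1)
#     boxes = torch.rand(300, 2, 4)   # (nq, bs, 4)
#     noisy = perturber(boxes)        # same shape, clamped back to [0, 1]
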
def sigmoid_focal_loss(inputs,
targets,
num_boxes,
alpha: float = 0.25,
gamma: float = 2):
"""
Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
Args:
inputs: A float tensor of arbitrary shape.
The predictions for each example.
targets: A float tensor with the same shape as inputs. Stores the binary
classification label for each element in inputs
(0 for the negative class and 1 for the positive class).
alpha: (optional) Weighting factor in range (0,1) to balance
positive vs negative examples. Default = -1 (no weighting).
gamma: Exponent of the modulating factor (1 - p_t) to
balance easy vs hard examples.
Returns:
Loss tensor
"""
prob = inputs.sigmoid()
ce_loss = F.binary_cross_entropy_with_logits(inputs,
targets,
reduction='none')
p_t = prob * targets + (1 - prob) * (1 - targets)
loss = ce_loss * ((1 - p_t)**gamma)
if alpha >= 0:
alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
loss = alpha_t * loss
return loss.mean(1).sum() / num_boxes
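
# Illustrative usage sketch (hypothetical class count and targets):
#
#     logits = torch.randn(2, 900, 80)   # (bs, nq, num_classes)
#     labels = torch.zeros_like(logits)
#     labels[0, 0, 3] = 1.0              # mark one positive
#     loss = sigmoid_focal_loss(logits, labels, num_boxes=1)
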
class MLP(nn.Module):
"""Very simple multi-layer perceptron (also called FFN)"""
def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
super().__init__()
self.num_layers = num_layers
h = [hidden_dim] * (num_layers - 1)
self.layers = nn.ModuleList(
nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
def forward(self, x):
for i, layer in enumerate(self.layers):
x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
return x
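
# Illustrative usage sketch, e.g. a 3-layer box-regression head:
#
#     head = MLP(input_dim=256, hidden_dim=256, output_dim=4, num_layers=3)
#     y = head(torch.randn(2, 900, 256))   # (2, 900, 4)
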
def _get_activation_fn(activation, d_model=256, batch_dim=0):
"""Return an activation function given a string."""
if activation == 'relu':
return F.relu
if activation == 'gelu':
return F.gelu
if activation == 'glu':
return F.glu
if activation == 'prelu':
return nn.PReLU()
if activation == 'selu':
return F.selu
    raise RuntimeError(
        f'activation should be relu/gelu/glu/prelu/selu, not {activation}.')
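
# Illustrative usage sketch:
#
#     act = _get_activation_fn('gelu')
#     y = act(torch.randn(4, 256))
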
def gen_sineembed_for_position(pos_tensor):
    """Sinusoidal embedding for normalized reference points.

    Args:
        pos_tensor (Tensor): (n_query, bs, 2) or (n_query, bs, 4) tensor of
            (x, y[, w, h]) values in [0, 1].

    Returns:
        Tensor: (n_query, bs, 256) for 2-d points, (n_query, bs, 512) for
        4-d boxes.
    """
    scale = 2 * math.pi
    dim_t = torch.arange(128, dtype=torch.float32, device=pos_tensor.device)
    dim_t = 10000**(2 * (dim_t // 2) / 128)
x_embed = pos_tensor[:, :, 0] * scale
y_embed = pos_tensor[:, :, 1] * scale
pos_x = x_embed[:, :, None] / dim_t
pos_y = y_embed[:, :, None] / dim_t
pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()),
dim=3).flatten(2)
pos_y = torch.stack((pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()),
dim=3).flatten(2)
if pos_tensor.size(-1) == 2:
pos = torch.cat((pos_y, pos_x), dim=2)
elif pos_tensor.size(-1) == 4:
w_embed = pos_tensor[:, :, 2] * scale
pos_w = w_embed[:, :, None] / dim_t
pos_w = torch.stack((pos_w[:, :, 0::2].sin(), pos_w[:, :, 1::2].cos()),
dim=3).flatten(2)
h_embed = pos_tensor[:, :, 3] * scale
pos_h = h_embed[:, :, None] / dim_t
pos_h = torch.stack((pos_h[:, :, 0::2].sin(), pos_h[:, :, 1::2].cos()),
dim=3).flatten(2)
pos = torch.cat((pos_y, pos_x, pos_w, pos_h), dim=2)
    else:
        raise ValueError('Unknown pos_tensor size(-1): {}'.format(
            pos_tensor.size(-1)))
return pos
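
# Illustrative usage sketch (normalized cxcywh reference points):
#
#     ref_points = torch.rand(300, 2, 4)             # (nq, bs, 4)
#     pos = gen_sineembed_for_position(ref_points)   # (300, 2, 512)
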
def oks_overlaps(kpt_preds, kpt_gts, kpt_valids, kpt_areas, sigmas):
sigmas = kpt_preds.new_tensor(sigmas)
variances = (sigmas * 2)**2
assert kpt_preds.size(0) == kpt_gts.size(0)
kpt_preds = kpt_preds.reshape(-1, kpt_preds.size(-1) // 2, 2)
kpt_gts = kpt_gts.reshape(-1, kpt_gts.size(-1) // 2, 2)
squared_distance = (kpt_preds[:, :, 0] - kpt_gts[:, :, 0]) ** 2 + \
(kpt_preds[:, :, 1] - kpt_gts[:, :, 1]) ** 2
squared_distance0 = squared_distance / (kpt_areas[:, None] *
variances[None, :] * 2)
squared_distance1 = torch.exp(-squared_distance0)
squared_distance1 = squared_distance1 * kpt_valids
oks = squared_distance1.sum(dim=1) / (kpt_valids.sum(dim=1) + 1e-6)
return oks
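
# Illustrative sanity check (hypothetical inputs; placeholder sigmas, the real
# per-keypoint values come from OKSLoss): identical poses give OKS of ~1.
#
#     kpts = torch.rand(5, 34)                          # (n, 2K), K=17
#     valids = torch.ones(5, 17)
#     areas = torch.ones(5)
#     sigmas = np.ones(17, dtype=np.float32) / 10.0
#     oks = oks_overlaps(kpts, kpts, valids, areas, sigmas)   # ~torch.ones(5)
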
def oks_loss(pred,
target,
valid=None,
area=None,
linear=False,
sigmas=None,
eps=1e-6):
"""Oks loss.
Computing the oks loss between a set of predicted poses and target poses.
The loss is calculated as negative log of oks.
Args:
pred (torch.Tensor): Predicted poses of format (x1, y1, x2, y2, ...),
shape (n, 2K).
target (torch.Tensor): Corresponding gt poses, shape (n, 2K).
linear (bool, optional): If True, use linear scale of loss instead of
log scale. Default: False.
eps (float): Eps to avoid log(0).
Return:
torch.Tensor: Loss tensor.
"""
oks = oks_overlaps(pred, target, valid, area, sigmas).clamp(min=eps)
if linear:
loss = 1 - oks
else:
loss = -oks.log()
loss = loss * valid.sum(-1) / (valid.sum(-1) + eps)
return loss
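
# Illustrative usage sketch (hypothetical COCO-style inputs, explicit sigmas;
# linear=True gives a 1 - OKS loss instead of -log(OKS)):
#
#     pred = torch.rand(5, 34)                        # (n, 2K), K=17
#     target = torch.rand(5, 34)
#     valid = torch.ones(5, 17)
#     area = torch.ones(5)
#     sigmas = np.ones(17, dtype=np.float32) / 10.0   # placeholder values
#     loss = oks_loss(pred, target, valid=valid, area=area,
#                     linear=True, sigmas=sigmas)     # (n,)
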
class OKSLoss(nn.Module):
"""IoULoss.
Computing the oks loss between a set of predicted poses and target poses.
Args:
linear (bool): If True, use linear scale of loss instead of log scale.
Default: False.
eps (float): Eps to avoid log(0).
reduction (str): Options are "none", "mean" and "sum".
loss_weight (float): Weight of loss.
"""
def __init__(self,
linear=False,
num_keypoints=17,
eps=1e-6,
reduction='mean',
loss_weight=1.0):
super(OKSLoss, self).__init__()
self.linear = linear
self.eps = eps
self.reduction = reduction
self.loss_weight = loss_weight
if num_keypoints == 17:
self.sigmas = np.array([
.26, .25, .25, .35, .35, .79, .79, .72, .72, .62, .62, 1.07,
1.07, .87, .87, .89, .89
],
dtype=np.float32) / 10.0
        elif num_keypoints == 14:
            self.sigmas = np.array([
                .79, .79, .72, .72, .62, .62, 1.07, 1.07, .87, .87, .89, .89,
                .79, .79
            ], dtype=np.float32) / 10.0
        elif num_keypoints == 6:
            self.sigmas = np.array([
                .25, .25, .25, .25, .25, .25
            ], dtype=np.float32) / 10.0
else:
raise ValueError(f'Unsupported keypoints number {num_keypoints}')
def forward(self,
pred,
target,
valid,
area,
weight=None,
avg_factor=None,
reduction_override=None):
"""Forward function.
Args:
pred (torch.Tensor): The prediction.
target (torch.Tensor): The learning target of the prediction.
valid (torch.Tensor): The visible flag of the target pose.
area (torch.Tensor): The area of the target pose.
weight (torch.Tensor, optional): The weight of loss for each
prediction. Defaults to None.
avg_factor (int, optional): Average factor that is used to average
the loss. Defaults to None.
reduction_override (str, optional): The reduction method used to
override the original reduction method of the loss.
Defaults to None. Options are "none", "mean" and "sum".
"""
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (reduction_override
                     if reduction_override else self.reduction)
        if (weight is not None) and (not torch.any(weight > 0)) and (
                reduction != 'none'):
            if pred.dim() == weight.dim() + 1:
                weight = weight.unsqueeze(1)
            # all weights are non-positive: return a zero-valued loss that
            # keeps the computation graph connected
            return (pred * weight).sum()
        if weight is not None and weight.dim() > 1:
            # TODO: remove this in the future
            # reduce the weight of shape (n, 4) to (n,) to match the
            # iou_loss of shape (n,)
            assert weight.shape == pred.shape
            weight = weight.mean(-1)
        # NOTE: the per-sample loss below is returned without applying
        # `weight` or `reduction`; callers are expected to handle both.
loss = self.loss_weight * oks_loss(pred,
target,
valid=valid,
area=area,
linear=self.linear,
sigmas=self.sigmas,
eps=self.eps)
return loss
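
# Illustrative usage sketch; OKSLoss selects the COCO sigmas when
# num_keypoints=17 (inputs below are hypothetical):
#
#     criterion = OKSLoss(linear=True, num_keypoints=17, loss_weight=1.0)
#     pred = torch.rand(5, 34)      # (n, 2K)
#     target = torch.rand(5, 34)
#     valid = torch.ones(5, 17)     # all keypoints visible
#     area = torch.ones(5)
#     loss = criterion(pred, target, valid, area)   # per-sample, shape (n,)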