# models/aios/ops/functions/ms_deform_attn_func.py
# ------------------------------------------------------------------------------------------------
# Deformable DETR
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------------------
# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
# ------------------------------------------------------------------------------------------------
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import torch
import torch.nn.functional as F
from torch.autograd import Function
from torch.autograd.function import once_differentiable
import MultiScaleDeformableAttention as MSDA


class MSDeformAttnFunction(Function):
@staticmethod
def forward(ctx, value, value_spatial_shapes, value_level_start_index,
sampling_locations, attention_weights, im2col_step):
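        """Forward pass of the compiled multi-scale deformable attention kernel.

        Thin autograd wrapper around MSDA.ms_deform_attn_forward: im2col_step
        is cached on ctx for the backward kernel, and the inputs are saved for
        gradient computation.
        """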
ctx.im2col_step = im2col_step
output = MSDA.ms_deform_attn_forward(value, value_spatial_shapes,
value_level_start_index,
sampling_locations,
attention_weights,
ctx.im2col_step)
ctx.save_for_backward(value, value_spatial_shapes,
value_level_start_index, sampling_locations,
attention_weights)
return output

    @staticmethod
@once_differentiable
def backward(ctx, grad_output):
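        """Backward pass dispatching to the CUDA backward kernel.

        Gradients are returned only for value, sampling_locations and
        attention_weights; the remaining forward arguments are not
        differentiable, so None is returned in their positions.
        """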
        value, value_spatial_shapes, value_level_start_index, \
            sampling_locations, attention_weights = ctx.saved_tensors
        grad_value, grad_sampling_loc, grad_attn_weight = \
            MSDA.ms_deform_attn_backward(value, value_spatial_shapes,
                                         value_level_start_index,
                                         sampling_locations,
                                         attention_weights, grad_output,
                                         ctx.im2col_step)
return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None


def ms_deform_attn_core_pytorch(value, value_spatial_shapes,
sampling_locations, attention_weights):
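    """Pure-PyTorch reference implementation of multi-scale deformable attention.

    Samples every feature level at the given normalized locations with bilinear
    F.grid_sample and aggregates the samples with the attention weights. It
    mirrors MSDeformAttnFunction without requiring the compiled CUDA extension,
    which makes it useful for debugging, at the cost of speed and memory.
    """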
N_, S_, M_, D_ = value.shape
_, Lq_, M_, L_, P_, _ = sampling_locations.shape
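    # N_: batch, S_: total spatial tokens over all levels, M_: heads,
    # D_: per-head channels; Lq_: queries, L_: levels, P_: points per level.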
value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes],
dim=1)
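    # F.grid_sample expects coordinates in [-1, 1]; sampling_locations are
    # normalized to [0, 1], so rescale them.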
sampling_grids = 2 * sampling_locations - 1
sampling_value_list = []
for lid_, (H_, W_) in enumerate(value_spatial_shapes):
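        # (N_, H_*W_, M_, D_) -> (N_, M_*D_, H_*W_) -> (N_*M_, D_, H_, W_)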
value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(
N_ * M_, D_, H_, W_)
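        # (N_, Lq_, M_, P_, 2) -> (N_, M_, Lq_, P_, 2) -> (N_*M_, Lq_, P_, 2)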
sampling_grid_l_ = sampling_grids[:, :, :,
lid_].transpose(1, 2).flatten(0, 1)
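        # bilinear sampling of one level -> (N_*M_, D_, Lq_, P_)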
sampling_value_l_ = F.grid_sample(value_l_,
sampling_grid_l_,
mode='bilinear',
padding_mode='zeros',
align_corners=False)
sampling_value_list.append(sampling_value_l_)
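    # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_*M_, 1, Lq_, L_*P_)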
attention_weights = attention_weights.transpose(1, 2).reshape(
N_ * M_, 1, Lq_, L_ * P_)
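    # stack levels and flatten: (N_*M_, D_, Lq_, L_*P_); weight, sum over the
    # L_*P_ samples and fold heads back into channels -> (N_, Lq_, M_*D_)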
output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) *
attention_weights).sum(-1).view(N_, M_ * D_, Lq_)
return output.transpose(1, 2).contiguous()
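

if __name__ == '__main__':
    # Minimal sanity-check sketch, not part of the original module: the shapes
    # below are illustrative assumptions chosen only to exercise the
    # pure-PyTorch reference path (ms_deform_attn_core_pytorch) on CPU.
    N, M, D = 2, 4, 8                   # batch size, heads, per-head channels
    Lq, P = 10, 4                       # number of queries, points per level
    spatial_shapes = [(8, 8), (4, 4)]   # (H, W) of each feature level
    S = sum(h * w for h, w in spatial_shapes)

    value = torch.rand(N, S, M, D)
    sampling_locations = torch.rand(N, Lq, M, len(spatial_shapes), P, 2)
    attention_weights = torch.rand(N, Lq, M, len(spatial_shapes), P)
    attention_weights = attention_weights / attention_weights.sum(
        dim=(-2, -1), keepdim=True)

    out = ms_deform_attn_core_pytorch(value, spatial_shapes,
                                      sampling_locations, attention_weights)
    print(out.shape)  # expected: torch.Size([2, 10, 32]), i.e. (N, Lq, M * D)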