|
|
|
import torch |
|
from torch import nn |
|
from torch.autograd import Function |
|
|
|
from ..utils import ext_loader |
|
|
|
# Load the '_ext' extension through ext_loader, naming the two dynamic
# point-to-voxel ops (forward and backward) that this module calls.
ext_module = ext_loader.load_ext(
    '_ext',
    ['dynamic_point_to_voxel_forward', 'dynamic_point_to_voxel_backward'])
|
|
|
|
|
class _DynamicScatter(Function):
    """Autograd wrapper around the dynamic point-to-voxel extension ops."""

    @staticmethod
    def forward(ctx, feats, coors, reduce_type='max'):
        """Reduce per-point features into per-voxel features.

        Args:
            feats (torch.Tensor): [N, C]. Points features to be reduced
                into voxels.
            coors (torch.Tensor): [N, ndim]. Corresponding voxel coordinates
                (specifically multi-dim voxel index) of each points.
            reduce_type (str, optional): Reduce op. support 'max', 'sum' and
                'mean'. Default: 'max'.

        Returns:
            voxel_feats (torch.Tensor): [M, C]. Reduced features, input
                features that shares the same voxel coordinates are reduced to
                one row.
            voxel_coors (torch.Tensor): [M, ndim]. Voxel coordinates.
        """
        # The extension returns four values; the last two are only needed
        # again by the backward op.
        reduced_feats, reduced_coors, pt2voxel, pts_per_voxel = (
            ext_module.dynamic_point_to_voxel_forward(
                feats, coors, reduce_type))

        # Stash everything the backward op consumes.
        ctx.reduce_type = reduce_type
        ctx.save_for_backward(feats, reduced_feats, pt2voxel, pts_per_voxel)

        # Coordinates are indices, so no gradient flows through them.
        ctx.mark_non_differentiable(reduced_coors)
        return reduced_feats, reduced_coors

    @staticmethod
    def backward(ctx, grad_voxel_feats, grad_voxel_coors=None):
        """Propagate voxel-feature gradients back to the point features.

        Args:
            grad_voxel_feats (torch.Tensor): Gradient w.r.t. the reduced
                voxel features returned by ``forward``.
            grad_voxel_coors (torch.Tensor, optional): Unused; the voxel
                coordinates are marked non-differentiable in ``forward``.

        Returns:
            tuple: ``(grad_feats, None, None)`` - a gradient for ``feats``
                only; ``coors`` and ``reduce_type`` receive ``None``.
        """
        src_feats, reduced_feats, pt2voxel, pts_per_voxel = ctx.saved_tensors

        # Zero-initialised buffer handed to the extension op; the op's return
        # value is ignored, so it is expected to write into this buffer.
        grad_feats = torch.zeros_like(src_feats)
        ext_module.dynamic_point_to_voxel_backward(
            grad_feats,
            grad_voxel_feats.contiguous(),
            src_feats,
            reduced_feats,
            pt2voxel,
            pts_per_voxel,
            ctx.reduce_type)
        return grad_feats, None, None
|
|
|
|
|
# Functional entry point: dynamic_scatter(feats, coors, reduce_type) runs
# _DynamicScatter through autograd via Function.apply.
dynamic_scatter = _DynamicScatter.apply
|
|
|
|
|
class DynamicScatter(nn.Module):
    """Scatters points into voxels, used in the voxel encoder with dynamic
    voxelization.

    Note:
        The CPU and GPU implementation get the same output, but have numerical
        difference after summation and division (e.g., 5e-7).

    Args:
        voxel_size (list): list [x, y, z] size of three dimension.
        point_cloud_range (list): The coordinate range of points, [x_min,
            y_min, z_min, x_max, y_max, z_max].
        average_points (bool): whether to use avg pooling to scatter points
            into voxel.
    """

    def __init__(self, voxel_size, point_cloud_range, average_points: bool):
        super().__init__()

        # Within this class, voxel_size and point_cloud_range are only stored
        # and reported by __repr__; average_points selects the reduce op.
        self.voxel_size = voxel_size
        self.point_cloud_range = point_cloud_range
        self.average_points = average_points

    def forward_single(self, points, coors):
        """Scatters points into voxels.

        Args:
            points (torch.Tensor): Points to be reduced into voxels.
            coors (torch.Tensor): Corresponding voxel coordinates (specifically
                multi-dim voxel index) of each points.

        Returns:
            voxel_feats (torch.Tensor): Reduced features, input features that
                shares the same voxel coordinates are reduced to one row.
            voxel_coors (torch.Tensor): Voxel coordinates.
        """
        reduce_type = 'mean' if self.average_points else 'max'
        return dynamic_scatter(points.contiguous(), coors.contiguous(),
                               reduce_type)

    def forward(self, points, coors):
        """Scatters points/features into voxels.

        When ``coors`` has exactly 3 columns it is treated as a single sample.
        Otherwise column 0 is taken as the batch index: each sample is
        scattered separately and the batch index is prepended again to the
        resulting voxel coordinates.

        Args:
            points (torch.Tensor): Points to be reduced into voxels.
            coors (torch.Tensor): Corresponding voxel coordinates (specifically
                multi-dim voxel index) of each points.

        Returns:
            voxel_feats (torch.Tensor): Reduced features, input features that
                shares the same voxel coordinates are reduced to one row.
            voxel_coors (torch.Tensor): Voxel coordinates.
        """
        if coors.size(-1) == 3:
            return self.forward_single(points, coors)

        # Batched input without any point: there is no last row to read a
        # batch index from and nothing to concatenate, so return empty
        # outputs with the same trailing dimensions as the inputs.
        if coors.size(0) == 0:
            return points[:0], coors[:0]

        batch_ids = coors[:, 0]
        # Use the largest batch index rather than the one in the last row, so
        # that rows which are not sorted by batch index are not dropped.
        batch_size = int(batch_ids.max()) + 1

        voxels, voxel_coors = [], []
        for i in range(batch_size):
            mask = batch_ids == i
            voxel, voxel_coor = self.forward_single(points[mask],
                                                    coors[mask][:, 1:])
            # Put the batch index back as the leading coordinate column.
            voxel_coors.append(
                nn.functional.pad(
                    voxel_coor, (1, 0), mode='constant', value=i))
            voxels.append(voxel)

        features = torch.cat(voxels, dim=0)
        feature_coors = torch.cat(voxel_coors, dim=0)
        return features, feature_coors

    def __repr__(self):
        return (f'{self.__class__.__name__}('
                f'voxel_size={self.voxel_size}'
                f', point_cloud_range={self.point_cloud_range}'
                f', average_points={self.average_points}'
                ')')
|
|