import pickle |
|
from os import path as osp |
|
|
|
import mmcv |
|
import mmengine |
|
import numpy as np |
|
from mmcv.ops import roi_align |
|
from mmdet.evaluation import bbox_overlaps |
|
from mmengine import print_log, track_iter_progress |
|
from pycocotools import mask as maskUtils |
|
from pycocotools.coco import COCO |
|
|
|
from mmdet3d.registry import DATASETS |
|
from mmdet3d.structures.ops import box_np_ops as box_np_ops |
|
|
|
|
|
def _poly2mask(mask_ann, img_h, img_w): |
|
if isinstance(mask_ann, list): |
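        # Polygon: a single object may consist of several parts; merge

        # all parts into one RLE mask.
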
rles = maskUtils.frPyObjects(mask_ann, img_h, img_w) |
|
rle = maskUtils.merge(rles) |
|
elif isinstance(mask_ann['counts'], list): |
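        # Uncompressed RLE.
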
rle = maskUtils.frPyObjects(mask_ann, img_h, img_w) |
|
else: |
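        # Already-encoded RLE.
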
rle = mask_ann |
|
mask = maskUtils.decode(rle) |
|
return mask |
|
|
|
|
|
def _parse_coco_ann_info(ann_info): |
|
gt_bboxes = [] |
|
gt_labels = [] |
|
gt_bboxes_ignore = [] |
|
gt_masks_ann = [] |
|
|
|
    for ann in ann_info:
|
if ann.get('ignore', False): |
|
continue |
|
x1, y1, w, h = ann['bbox'] |
|
if ann['area'] <= 0: |
|
continue |
|
bbox = [x1, y1, x1 + w, y1 + h] |
|
if ann.get('iscrowd', False): |
|
gt_bboxes_ignore.append(bbox) |
|
else: |
|
gt_bboxes.append(bbox) |
|
gt_masks_ann.append(ann['segmentation']) |
|
|
|
if gt_bboxes: |
|
gt_bboxes = np.array(gt_bboxes, dtype=np.float32) |
|
gt_labels = np.array(gt_labels, dtype=np.int64) |
|
else: |
|
gt_bboxes = np.zeros((0, 4), dtype=np.float32) |
|
gt_labels = np.array([], dtype=np.int64) |
|
|
|
if gt_bboxes_ignore: |
|
gt_bboxes_ignore = np.array(gt_bboxes_ignore, dtype=np.float32) |
|
else: |
|
gt_bboxes_ignore = np.zeros((0, 4), dtype=np.float32) |
|
|
|
ann = dict( |
|
bboxes=gt_bboxes, bboxes_ignore=gt_bboxes_ignore, masks=gt_masks_ann) |
|
|
|
return ann |
|
|
|
|
|
def crop_image_patch_v2(pos_proposals, pos_assigned_gt_inds, gt_masks): |
|
import torch |
|
from torch.nn.modules.utils import _pair |
|
device = pos_proposals.device |
|
num_pos = pos_proposals.size(0) |
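
    # Prepend a per-proposal index so each row becomes a

    # (batch_idx, x1, y1, x2, y2) RoI in the layout expected by roi_align.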
|
fake_inds = ( |
|
torch.arange(num_pos, |
|
device=device).to(dtype=pos_proposals.dtype)[:, None]) |
|
rois = torch.cat([fake_inds, pos_proposals], dim=1) |
|
mask_size = _pair(28) |
|
rois = rois.to(device=device) |
|
gt_masks_th = ( |
|
torch.from_numpy(gt_masks).to(device).index_select( |
|
0, pos_assigned_gt_inds).to(dtype=rois.dtype)) |
|
|
|
targets = ( |
|
roi_align(gt_masks_th, rois, mask_size[::-1], 1.0, 0, True).squeeze(1)) |
|
return targets |
|
|
|
|
|
def crop_image_patch(pos_proposals, gt_masks, pos_assigned_gt_inds, org_img): |
|
num_pos = pos_proposals.shape[0] |
|
masks = [] |
|
img_patches = [] |
|
for i in range(num_pos): |
|
gt_mask = gt_masks[pos_assigned_gt_inds[i]] |
|
bbox = pos_proposals[i, :].astype(np.int32) |
|
x1, y1, x2, y2 = bbox |
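
        # Clamp the patch width and height to at least one pixel.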
|
w = np.maximum(x2 - x1 + 1, 1) |
|
h = np.maximum(y2 - y1 + 1, 1) |
|
|
|
mask_patch = gt_mask[y1:y1 + h, x1:x1 + w] |
|
masked_img = gt_mask[..., None] * org_img |
|
img_patch = masked_img[y1:y1 + h, x1:x1 + w] |
|
|
|
img_patches.append(img_patch) |
|
masks.append(mask_patch) |
|
return img_patches, masks |
|
|
|
|
|
def create_groundtruth_database(dataset_class_name, |
|
data_path, |
|
info_prefix, |
|
info_path=None, |
|
mask_anno_path=None, |
|
used_classes=None, |
|
database_save_path=None, |
|
db_info_save_path=None, |
|
relative_path=True, |
|
add_rgb=False, |
|
lidar_only=False, |
|
bev_only=False, |
|
coors_range=None, |
|
with_mask=False): |
|
"""Given the raw data, generate the ground truth database. |
|
|
|
Args: |
|
dataset_class_name (str): Name of the input dataset. |
|
data_path (str): Path of the data. |
|
info_prefix (str): Prefix of the info file. |
|
info_path (str, optional): Path of the info file. |
|
Default: None. |
|
        mask_anno_path (str, optional): Path of the mask annotation file.
|
Default: None. |
|
        used_classes (list[str], optional): Classes to keep in the database.
|
Default: None. |
|
database_save_path (str, optional): Path to save database. |
|
Default: None. |
|
db_info_save_path (str, optional): Path to save db_info. |
|
Default: None. |
|
relative_path (bool, optional): Whether to use relative path. |
|
Default: True. |
|
        with_mask (bool, optional): Whether to use 2D instance masks.
|
Default: False. |
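
    Example:

        A minimal invocation for KITTI, assuming the standard mmdet3d data

        layout (the paths below are illustrative)::

            create_groundtruth_database(

                'KittiDataset',

                './data/kitti',

                'kitti',

                info_path='./data/kitti/kitti_infos_train.pkl')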
|
""" |
|
print(f'Create GT Database of {dataset_class_name}') |
|
dataset_cfg = dict( |
|
type=dataset_class_name, data_root=data_path, ann_file=info_path) |
|
if dataset_class_name == 'KittiDataset': |
|
backend_args = None |
|
dataset_cfg.update( |
|
modality=dict( |
|
use_lidar=True, |
|
use_camera=with_mask, |
|
), |
|
data_prefix=dict( |
|
pts='training/velodyne_reduced', img='training/image_2'), |
|
pipeline=[ |
|
dict( |
|
type='LoadPointsFromFile', |
|
coord_type='LIDAR', |
|
load_dim=4, |
|
use_dim=4, |
|
backend_args=backend_args), |
|
dict( |
|
type='LoadAnnotations3D', |
|
with_bbox_3d=True, |
|
with_label_3d=True, |
|
backend_args=backend_args) |
|
]) |
|
|
|
elif dataset_class_name == 'NuScenesDataset': |
|
dataset_cfg.update( |
|
use_valid_flag=True, |
|
data_prefix=dict( |
|
pts='samples/LIDAR_TOP', img='', sweeps='sweeps/LIDAR_TOP'), |
|
pipeline=[ |
|
dict( |
|
type='LoadPointsFromFile', |
|
coord_type='LIDAR', |
|
load_dim=5, |
|
use_dim=5), |
|
dict( |
|
type='LoadPointsFromMultiSweeps', |
|
sweeps_num=10, |
|
use_dim=[0, 1, 2, 3, 4], |
|
pad_empty_sweeps=True, |
|
remove_close=True), |
|
dict( |
|
type='LoadAnnotations3D', |
|
with_bbox_3d=True, |
|
with_label_3d=True) |
|
]) |
|
|
|
elif dataset_class_name == 'WaymoDataset': |
|
backend_args = None |
|
dataset_cfg.update( |
|
test_mode=False, |
|
data_prefix=dict( |
|
pts='training/velodyne', img='', sweeps='training/velodyne'), |
|
modality=dict( |
|
use_lidar=True, |
|
use_depth=False, |
|
use_lidar_intensity=True, |
|
use_camera=False, |
|
), |
|
pipeline=[ |
|
dict( |
|
type='LoadPointsFromFile', |
|
coord_type='LIDAR', |
|
load_dim=6, |
|
use_dim=6, |
|
backend_args=backend_args), |
|
dict( |
|
type='LoadAnnotations3D', |
|
with_bbox_3d=True, |
|
with_label_3d=True, |
|
backend_args=backend_args) |
|
]) |
|
|
|
dataset = DATASETS.build(dataset_cfg) |
|
|
|
if database_save_path is None: |
|
database_save_path = osp.join(data_path, f'{info_prefix}_gt_database') |
|
if db_info_save_path is None: |
|
db_info_save_path = osp.join(data_path, |
|
f'{info_prefix}_dbinfos_train.pkl') |
|
mmengine.mkdir_or_exist(database_save_path) |
|
all_db_infos = dict() |
|
if with_mask: |
|
coco = COCO(osp.join(data_path, mask_anno_path)) |
|
imgIds = coco.getImgIds() |
|
file2id = dict() |
|
for i in imgIds: |
|
info = coco.loadImgs([i])[0] |
|
file2id.update({info['file_name']: i}) |
|
|
|
group_counter = 0 |
|
for j in track_iter_progress(list(range(len(dataset)))): |
|
data_info = dataset.get_data_info(j) |
|
example = dataset.pipeline(data_info) |
|
annos = example['ann_info'] |
|
image_idx = example['sample_idx'] |
|
points = example['points'].numpy() |
|
gt_boxes_3d = annos['gt_bboxes_3d'].numpy() |
|
names = [dataset.metainfo['classes'][i] for i in annos['gt_labels_3d']] |
|
group_dict = dict() |
|
if 'group_ids' in annos: |
|
group_ids = annos['group_ids'] |
|
else: |
|
group_ids = np.arange(gt_boxes_3d.shape[0], dtype=np.int64) |
|
difficulty = np.zeros(gt_boxes_3d.shape[0], dtype=np.int32) |
|
if 'difficulty' in annos: |
|
difficulty = annos['difficulty'] |
|
|
|
num_obj = gt_boxes_3d.shape[0] |
|
point_indices = box_np_ops.points_in_rbbox(points, gt_boxes_3d) |
|
|
|
if with_mask: |
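            # Prepare 2D boxes and instance masks from the COCO-style

            # annotations.
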
gt_boxes = annos['gt_bboxes'] |
|
img_path = osp.split(example['img_info']['filename'])[-1] |
|
            if img_path not in file2id:
|
                print(f'skip image {img_path}: no mask annotation found')
|
continue |
|
img_id = file2id[img_path] |
|
kins_annIds = coco.getAnnIds(imgIds=img_id) |
|
kins_raw_info = coco.loadAnns(kins_annIds) |
|
kins_ann_info = _parse_coco_ann_info(kins_raw_info) |
|
h, w = annos['img_shape'][:2] |
|
gt_masks = [ |
|
_poly2mask(mask, h, w) for mask in kins_ann_info['masks'] |
|
] |
|
|
|
bbox_iou = bbox_overlaps(kins_ann_info['bboxes'], gt_boxes) |
|
mask_inds = bbox_iou.argmax(axis=0) |
|
valid_inds = (bbox_iou.max(axis=0) > 0.5) |
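
            # Crop image and mask patches with the matched 2D boxes; a more

            # precise RoI-align based crop (crop_image_patch_v2) exists but

            # is unused here.
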
object_img_patches, object_masks = crop_image_patch( |
|
gt_boxes, gt_masks, mask_inds, annos['img']) |
|
|
|
for i in range(num_obj): |
|
filename = f'{image_idx}_{names[i]}_{i}.bin' |
|
abs_filepath = osp.join(database_save_path, filename) |
|
rel_filepath = osp.join(f'{info_prefix}_gt_database', filename) |
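
            # Gather the points inside the i-th box and shift them to be

            # relative to the box position.
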
gt_points = points[point_indices[:, i]] |
|
gt_points[:, :3] -= gt_boxes_3d[i, :3] |
|
|
|
if with_mask: |
|
if object_masks[i].sum() == 0 or not valid_inds[i]: |
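                    # Skip objects with an empty mask or no well-matched

                    # 2D box.
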
continue |
|
img_patch_path = abs_filepath + '.png' |
|
mask_patch_path = abs_filepath + '.mask.png' |
|
mmcv.imwrite(object_img_patches[i], img_patch_path) |
|
mmcv.imwrite(object_masks[i], mask_patch_path) |
|
|
|
            with open(abs_filepath, 'wb') as f:
|
gt_points.tofile(f) |
|
|
|
if (used_classes is None) or names[i] in used_classes: |
|
db_info = { |
|
'name': names[i], |
|
'path': rel_filepath, |
|
'image_idx': image_idx, |
|
'gt_idx': i, |
|
'box3d_lidar': gt_boxes_3d[i], |
|
'num_points_in_gt': gt_points.shape[0], |
|
'difficulty': difficulty[i], |
|
} |
|
local_group_id = group_ids[i] |
|
|
|
if local_group_id not in group_dict: |
|
group_dict[local_group_id] = group_counter |
|
group_counter += 1 |
|
db_info['group_id'] = group_dict[local_group_id] |
|
if 'score' in annos: |
|
db_info['score'] = annos['score'][i] |
|
if with_mask: |
|
db_info.update({'box2d_camera': gt_boxes[i]}) |
|
if names[i] in all_db_infos: |
|
all_db_infos[names[i]].append(db_info) |
|
else: |
|
all_db_infos[names[i]] = [db_info] |
|
|
|
for k, v in all_db_infos.items(): |
|
print(f'load {len(v)} {k} database infos') |
|
|
|
with open(db_info_save_path, 'wb') as f: |
|
pickle.dump(all_db_infos, f) |
|
|
|
|
|
class GTDatabaseCreater: |
|
"""Given the raw data, generate the ground truth database. This is the |
|
    parallel version. For the serial version, please refer to

    `create_groundtruth_database`.
|
|
|
Args: |
|
dataset_class_name (str): Name of the input dataset. |
|
data_path (str): Path of the data. |
|
info_prefix (str): Prefix of the info file. |
|
info_path (str, optional): Path of the info file. |
|
Default: None. |
|
        mask_anno_path (str, optional): Path of the mask annotation file.
|
Default: None. |
|
        used_classes (list[str], optional): Classes to keep in the database.
|
Default: None. |
|
database_save_path (str, optional): Path to save database. |
|
Default: None. |
|
db_info_save_path (str, optional): Path to save db_info. |
|
Default: None. |
|
relative_path (bool, optional): Whether to use relative path. |
|
Default: True. |
|
        with_mask (bool, optional): Whether to use 2D instance masks.
|
Default: False. |
|
        num_worker (int, optional): Number of parallel workers to use.
|
Default: 8. |
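
    Example:

        A minimal sketch of the parallel workflow; the paths below are

        illustrative::

            creater = GTDatabaseCreater(

                'KittiDataset',

                './data/kitti',

                'kitti',

                info_path='./data/kitti/kitti_infos_train.pkl',

                num_worker=8)

            creater.create()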
|
""" |
|
|
|
def __init__(self, |
|
dataset_class_name, |
|
data_path, |
|
info_prefix, |
|
info_path=None, |
|
mask_anno_path=None, |
|
used_classes=None, |
|
database_save_path=None, |
|
db_info_save_path=None, |
|
relative_path=True, |
|
add_rgb=False, |
|
lidar_only=False, |
|
bev_only=False, |
|
coors_range=None, |
|
with_mask=False, |
|
num_worker=8) -> None: |
|
self.dataset_class_name = dataset_class_name |
|
self.data_path = data_path |
|
self.info_prefix = info_prefix |
|
self.info_path = info_path |
|
self.mask_anno_path = mask_anno_path |
|
self.used_classes = used_classes |
|
self.database_save_path = database_save_path |
|
self.db_info_save_path = db_info_save_path |
|
self.relative_path = relative_path |
|
self.add_rgb = add_rgb |
|
self.lidar_only = lidar_only |
|
self.bev_only = bev_only |
|
self.coors_range = coors_range |
|
self.with_mask = with_mask |
|
self.num_worker = num_worker |
|
self.pipeline = None |
|
|
|
def create_single(self, input_dict): |
|
group_counter = 0 |
|
single_db_infos = dict() |
|
example = self.pipeline(input_dict) |
|
annos = example['ann_info'] |
|
image_idx = example['sample_idx'] |
|
points = example['points'].numpy() |
|
gt_boxes_3d = annos['gt_bboxes_3d'].numpy() |
|
names = [ |
|
self.dataset.metainfo['classes'][i] for i in annos['gt_labels_3d'] |
|
] |
|
group_dict = dict() |
|
if 'group_ids' in annos: |
|
group_ids = annos['group_ids'] |
|
else: |
|
group_ids = np.arange(gt_boxes_3d.shape[0], dtype=np.int64) |
|
difficulty = np.zeros(gt_boxes_3d.shape[0], dtype=np.int32) |
|
if 'difficulty' in annos: |
|
difficulty = annos['difficulty'] |
|
|
|
num_obj = gt_boxes_3d.shape[0] |
|
point_indices = box_np_ops.points_in_rbbox(points, gt_boxes_3d) |
|
|
|
if self.with_mask: |
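            # Prepare 2D boxes and instance masks from the COCO-style

            # annotations.
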
gt_boxes = annos['gt_bboxes'] |
|
img_path = osp.split(example['img_info']['filename'])[-1] |
|
            if img_path not in self.file2id:
|
                print(f'skip image {img_path}: no mask annotation found')
|
return single_db_infos |
|
img_id = self.file2id[img_path] |
|
kins_annIds = self.coco.getAnnIds(imgIds=img_id) |
|
kins_raw_info = self.coco.loadAnns(kins_annIds) |
|
kins_ann_info = _parse_coco_ann_info(kins_raw_info) |
|
h, w = annos['img_shape'][:2] |
|
gt_masks = [ |
|
_poly2mask(mask, h, w) for mask in kins_ann_info['masks'] |
|
] |
|
|
|
bbox_iou = bbox_overlaps(kins_ann_info['bboxes'], gt_boxes) |
|
mask_inds = bbox_iou.argmax(axis=0) |
|
valid_inds = (bbox_iou.max(axis=0) > 0.5) |
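
            # Crop image and mask patches with the matched 2D boxes; a more

            # precise RoI-align based crop (crop_image_patch_v2) exists but

            # is unused here.
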
object_img_patches, object_masks = crop_image_patch( |
|
gt_boxes, gt_masks, mask_inds, annos['img']) |
|
|
|
for i in range(num_obj): |
|
filename = f'{image_idx}_{names[i]}_{i}.bin' |
|
abs_filepath = osp.join(self.database_save_path, filename) |
|
rel_filepath = osp.join(f'{self.info_prefix}_gt_database', |
|
filename) |
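
            # Gather the points inside the i-th box and shift them to be

            # relative to the box position.
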
gt_points = points[point_indices[:, i]] |
|
gt_points[:, :3] -= gt_boxes_3d[i, :3] |
|
|
|
if self.with_mask: |
|
if object_masks[i].sum() == 0 or not valid_inds[i]: |
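                    # Skip objects with an empty mask or no well-matched

                    # 2D box.
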
continue |
|
img_patch_path = abs_filepath + '.png' |
|
mask_patch_path = abs_filepath + '.mask.png' |
|
mmcv.imwrite(object_img_patches[i], img_patch_path) |
|
mmcv.imwrite(object_masks[i], mask_patch_path) |
|
|
|
            with open(abs_filepath, 'wb') as f:
|
gt_points.tofile(f) |
|
|
|
if (self.used_classes is None) or names[i] in self.used_classes: |
|
db_info = { |
|
'name': names[i], |
|
'path': rel_filepath, |
|
'image_idx': image_idx, |
|
'gt_idx': i, |
|
'box3d_lidar': gt_boxes_3d[i], |
|
'num_points_in_gt': gt_points.shape[0], |
|
'difficulty': difficulty[i], |
|
} |
|
local_group_id = group_ids[i] |
|
|
|
if local_group_id not in group_dict: |
|
group_dict[local_group_id] = group_counter |
|
group_counter += 1 |
|
db_info['group_id'] = group_dict[local_group_id] |
|
if 'score' in annos: |
|
db_info['score'] = annos['score'][i] |
|
if self.with_mask: |
|
db_info.update({'box2d_camera': gt_boxes[i]}) |
|
if names[i] in single_db_infos: |
|
single_db_infos[names[i]].append(db_info) |
|
else: |
|
single_db_infos[names[i]] = [db_info] |
|
|
|
return single_db_infos |
|
|
|
def create(self): |
|
print_log( |
|
f'Create GT Database of {self.dataset_class_name}', |
|
logger='current') |
|
dataset_cfg = dict( |
|
type=self.dataset_class_name, |
|
data_root=self.data_path, |
|
ann_file=self.info_path) |
|
if self.dataset_class_name == 'KittiDataset': |
|
backend_args = None |
|
dataset_cfg.update( |
|
test_mode=False, |
|
data_prefix=dict( |
|
pts='training/velodyne_reduced', img='training/image_2'), |
|
modality=dict( |
|
use_lidar=True, |
|
use_depth=False, |
|
use_lidar_intensity=True, |
|
use_camera=self.with_mask, |
|
), |
|
pipeline=[ |
|
dict( |
|
type='LoadPointsFromFile', |
|
coord_type='LIDAR', |
|
load_dim=4, |
|
use_dim=4, |
|
backend_args=backend_args), |
|
dict( |
|
type='LoadAnnotations3D', |
|
with_bbox_3d=True, |
|
with_label_3d=True, |
|
backend_args=backend_args) |
|
]) |
|
|
|
elif self.dataset_class_name == 'NuScenesDataset': |
|
dataset_cfg.update( |
|
use_valid_flag=True, |
|
data_prefix=dict( |
|
pts='samples/LIDAR_TOP', img='', |
|
sweeps='sweeps/LIDAR_TOP'), |
|
pipeline=[ |
|
dict( |
|
type='LoadPointsFromFile', |
|
coord_type='LIDAR', |
|
load_dim=5, |
|
use_dim=5), |
|
dict( |
|
type='LoadPointsFromMultiSweeps', |
|
sweeps_num=10, |
|
use_dim=[0, 1, 2, 3, 4], |
|
pad_empty_sweeps=True, |
|
remove_close=True), |
|
dict( |
|
type='LoadAnnotations3D', |
|
with_bbox_3d=True, |
|
with_label_3d=True) |
|
]) |
|
|
|
elif self.dataset_class_name == 'WaymoDataset': |
|
backend_args = None |
|
dataset_cfg.update( |
|
test_mode=False, |
|
data_prefix=dict( |
|
pts='training/velodyne', |
|
img='', |
|
sweeps='training/velodyne'), |
|
modality=dict( |
|
use_lidar=True, |
|
use_depth=False, |
|
use_lidar_intensity=True, |
|
use_camera=False, |
|
), |
|
pipeline=[ |
|
dict( |
|
type='LoadPointsFromFile', |
|
coord_type='LIDAR', |
|
load_dim=6, |
|
use_dim=6, |
|
backend_args=backend_args), |
|
dict( |
|
type='LoadAnnotations3D', |
|
with_bbox_3d=True, |
|
with_label_3d=True, |
|
backend_args=backend_args) |
|
]) |
|
|
|
self.dataset = DATASETS.build(dataset_cfg) |
|
self.pipeline = self.dataset.pipeline |
|
if self.database_save_path is None: |
|
self.database_save_path = osp.join( |
|
self.data_path, f'{self.info_prefix}_gt_database') |
|
if self.db_info_save_path is None: |
|
self.db_info_save_path = osp.join( |
|
self.data_path, f'{self.info_prefix}_dbinfos_train.pkl') |
|
mmengine.mkdir_or_exist(self.database_save_path) |
|
if self.with_mask: |
|
self.coco = COCO(osp.join(self.data_path, self.mask_anno_path)) |
|
imgIds = self.coco.getImgIds() |
|
self.file2id = dict() |
|
for i in imgIds: |
|
info = self.coco.loadImgs([i])[0] |
|
self.file2id.update({info['file_name']: i}) |
|
|
|
def loop_dataset(i): |
|
input_dict = self.dataset.get_data_info(i) |
|
input_dict['box_type_3d'] = self.dataset.box_type_3d |
|
input_dict['box_mode_3d'] = self.dataset.box_mode_3d |
|
return input_dict |
|
|
|
if self.num_worker == 0: |
|
multi_db_infos = mmengine.track_progress( |
|
self.create_single, |
|
((loop_dataset(i) |
|
for i in range(len(self.dataset))), len(self.dataset))) |
|
else: |
|
multi_db_infos = mmengine.track_parallel_progress( |
|
self.create_single, |
|
((loop_dataset(i) |
|
for i in range(len(self.dataset))), len(self.dataset)), |
|
self.num_worker, |
|
chunksize=1000) |
|
        print_log('Making group ids globally unique', logger='current')
|
group_counter_offset = 0 |
|
all_db_infos = dict() |
|
for single_db_infos in track_iter_progress(multi_db_infos): |
|
group_id = -1 |
|
for name, name_db_infos in single_db_infos.items(): |
|
for db_info in name_db_infos: |
|
group_id = max(group_id, db_info['group_id']) |
|
db_info['group_id'] += group_counter_offset |
|
if name not in all_db_infos: |
|
all_db_infos[name] = [] |
|
all_db_infos[name].extend(name_db_infos) |
|
group_counter_offset += (group_id + 1) |
|
|
|
for k, v in all_db_infos.items(): |
|
print_log(f'load {len(v)} {k} database infos', logger='current') |
|
|
|
        print_log(

            f'Saving GT database infos into {self.db_info_save_path}',

            logger='current')
|
with open(self.db_info_save_path, 'wb') as f: |
|
pickle.dump(all_db_infos, f) |
|
|