import argparse
import tempfile

import torch
from mmengine import Config
from mmengine.runner import load_state_dict

from mmdet3d.registry import MODELS
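
# Example usage (illustrative file names; the script and checkpoint paths
# depend on your setup):
#   python convert_h3dnet_checkpoints.py old_h3dnet.pth --out new_h3dnet.pth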


def parse_args():
    parser = argparse.ArgumentParser(
        description='MMDet3D upgrade model version (before v0.6.0) of H3DNet')
    parser.add_argument('checkpoint', help='checkpoint file')
    parser.add_argument('--out', help='path of the output checkpoint file')
    args = parser.parse_args()
    return args


def parse_config(config_strings):
    """Parse config from strings.

    Args:
        config_strings (str): Strings of the model config.

    Returns:
        Config: model config.
    """
    temp_file = tempfile.NamedTemporaryFile()
    config_path = f'{temp_file.name}.py'
    with open(config_path, 'w') as f:
        f.write(config_strings)

    config = Config.fromfile(config_path)
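
    # Update backbone config: drop the deprecated 'pool_mod' key and add a
    # default set abstraction config ('sa_cfg') when it is missing.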
    if 'pool_mod' in config.model.backbone.backbones:
        config.model.backbone.backbones.pop('pool_mod')

    if 'sa_cfg' not in config.model.backbone:
        config.model.backbone['sa_cfg'] = dict(
            type='PointSAModule',
            pool_mod='max',
            use_xyz=True,
            normalize_xyz=True)
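
    # Update rpn_head config: make sure the vote aggregation module has an
    # explicit type, add 'pred_layer_cfg', drop the deprecated
    # 'feat_channels' key and rename the misspelled 'vote_moudule_cfg' key
    # used by old configs.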
    if 'type' not in config.model.rpn_head.vote_aggregation_cfg:
        config.model.rpn_head.vote_aggregation_cfg['type'] = 'PointSAModule'

    if 'pred_layer_cfg' not in config.model.rpn_head:
        config.model.rpn_head['pred_layer_cfg'] = dict(
            in_channels=128, shared_conv_channels=(128, 128), bias=True)

    if 'feat_channels' in config.model.rpn_head:
        config.model.rpn_head.pop('feat_channels')

    if 'vote_moudule_cfg' in config.model.rpn_head:
        config.model.rpn_head['vote_module_cfg'] = config.model.rpn_head.pop(
            'vote_moudule_cfg')
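
    # When 'use_xyz' is True, the first mlp channel from old configs is
    # reduced by 3, presumably because the new SA modules account for the
    # xyz channels internally.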
    if config.model.rpn_head.vote_aggregation_cfg.use_xyz:
        config.model.rpn_head.vote_aggregation_cfg.mlp_channels[0] -= 3
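
    # Apply the same key rename and channel adjustment to each primitive
    # head in the roi_head.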
    for cfg in config.model.roi_head.primitive_list:
        cfg['vote_module_cfg'] = cfg.pop('vote_moudule_cfg')
        cfg.vote_aggregation_cfg.mlp_channels[0] -= 3
        if 'type' not in cfg.vote_aggregation_cfg:
            cfg.vote_aggregation_cfg['type'] = 'PointSAModule'
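
    # Update the bbox_head surface/line matching configs in the same way
    # ('suface_matching_cfg' is the key name actually used by the configs).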
    if 'type' not in config.model.roi_head.bbox_head.suface_matching_cfg:
        config.model.roi_head.bbox_head.suface_matching_cfg[
            'type'] = 'PointSAModule'

    if config.model.roi_head.bbox_head.suface_matching_cfg.use_xyz:
        config.model.roi_head.bbox_head.suface_matching_cfg.mlp_channels[
            0] -= 3

    if 'type' not in config.model.roi_head.bbox_head.line_matching_cfg:
        config.model.roi_head.bbox_head.line_matching_cfg[
            'type'] = 'PointSAModule'

    if config.model.roi_head.bbox_head.line_matching_cfg.use_xyz:
        config.model.roi_head.bbox_head.line_matching_cfg.mlp_channels[0] -= 3

    if 'proposal_module_cfg' in config.model.roi_head.bbox_head:
        config.model.roi_head.bbox_head.pop('proposal_module_cfg')

    temp_file.close()

    return config


def main():
    """Convert keys in checkpoints for H3DNet.

    There can be some breaking changes during the development of
    mmdetection3d, and this tool is used for upgrading checkpoints trained
    with old versions (before v0.6.0) to the latest one.
    """
    args = parse_args()
    checkpoint = torch.load(args.checkpoint)
    cfg = parse_config(checkpoint['meta']['config'])
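
    # Rebuild the model from the upgraded config so that the converted
    # state dict can be validated with load_state_dict below.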
    model = MODELS.build(
        cfg.model,
        train_cfg=cfg.get('train_cfg'),
        test_cfg=cfg.get('test_cfg'))
    orig_ckpt = checkpoint['state_dict']
    converted_ckpt = orig_ckpt.copy()
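
    # The number of semantic classes decides how the fused prediction layer
    # is split below (18 for ScanNet, 10 for SUN RGB-D).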
    if cfg['dataset_type'] == 'ScanNetDataset':
        NUM_CLASSES = 18
    elif cfg['dataset_type'] == 'SUNRGBDDataset':
        NUM_CLASSES = 10
    else:
        raise NotImplementedError
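
    # Key conversion tables: RENAME_PREFIX maps old key prefixes to new
    # ones, DEL_KEYS lists keys that no longer exist, and EXTRACT_KEYS maps
    # each new key to the old fused 'conv_out' key plus the channel slices
    # (start, end) to take from it, where end == -1 means 'to the end'.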
    RENAME_PREFIX = {
        'rpn_head.conv_pred.0': 'rpn_head.conv_pred.shared_convs.layer0',
        'rpn_head.conv_pred.1': 'rpn_head.conv_pred.shared_convs.layer1'
    }

    DEL_KEYS = [
        'rpn_head.conv_pred.0.bn.num_batches_tracked',
        'rpn_head.conv_pred.1.bn.num_batches_tracked'
    ]

    EXTRACT_KEYS = {
        'rpn_head.conv_pred.conv_cls.weight':
        ('rpn_head.conv_pred.conv_out.weight', [(0, 2), (-NUM_CLASSES, -1)]),
        'rpn_head.conv_pred.conv_cls.bias':
        ('rpn_head.conv_pred.conv_out.bias', [(0, 2), (-NUM_CLASSES, -1)]),
        'rpn_head.conv_pred.conv_reg.weight':
        ('rpn_head.conv_pred.conv_out.weight', [(2, -NUM_CLASSES)]),
        'rpn_head.conv_pred.conv_reg.bias':
        ('rpn_head.conv_pred.conv_out.bias', [(2, -NUM_CLASSES)])
    }
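
    # Delete keys that no longer exist in the new model.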
    for key in DEL_KEYS:
        converted_ckpt.pop(key)
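
    # Rename keys with the prefixes listed in RENAME_PREFIX.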
    RENAME_KEYS = dict()
    for old_key in converted_ckpt.keys():
        for rename_prefix in RENAME_PREFIX.keys():
            if rename_prefix in old_key:
                new_key = old_key.replace(rename_prefix,
                                          RENAME_PREFIX[rename_prefix])
                RENAME_KEYS[new_key] = old_key
    for new_key, old_key in RENAME_KEYS.items():
        converted_ckpt[new_key] = converted_ckpt.pop(old_key)
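
    # Split the fused prediction weights: concatenate the listed channel
    # slices of each old tensor to build the new one, then drop the old key.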
    for new_key, (old_key, indices) in EXTRACT_KEYS.items():
        cur_layers = orig_ckpt[old_key]
        converted_layers = []
        for (start, end) in indices:
            if end != -1:
                converted_layers.append(cur_layers[start:end])
            else:
                converted_layers.append(cur_layers[start:])
        converted_layers = torch.cat(converted_layers, 0)
        converted_ckpt[new_key] = converted_layers
        if old_key in converted_ckpt.keys():
            converted_ckpt.pop(old_key)
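
    # Check the converted state dict against the rebuilt model, then save
    # the upgraded checkpoint.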
    load_state_dict(model, converted_ckpt, strict=True)
    checkpoint['state_dict'] = converted_ckpt
    torch.save(checkpoint, args.out)


if __name__ == '__main__':
    main()