AiOS / mmcv /tests /test_cnn /test_model_registry.py
ttxskk
update
d7e58f0
raw
history blame
1.86 kB
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
import mmcv
from mmcv.cnn import MODELS, build_model_from_cfg
def test_build_model_from_cfg():
    """Check ``build_model_from_cfg`` when it is used as a registry build_func.

    Three things are exercised: building a module from a single config dict,
    building from a list of config dicts (asserted to yield an
    ``nn.Sequential``), and the ``build_func`` seen on a child registry with
    and without an explicit override.
    """
    backbones = mmcv.Registry('backbone', build_func=build_model_from_cfg)

    class _PassThrough(nn.Module):
        """Shared body of the toy backbones: keep the args, return the input."""

        def __init__(self, depth, stages=4):
            super().__init__()
            self.depth = depth
            self.stages = stages

        def forward(self, x):
            return x

    # The `type` values in the configs below match these class names.
    @backbones.register_module()
    class ResNet(_PassThrough):
        pass

    @backbones.register_module()
    class ResNeXt(_PassThrough):
        pass

    # Single config dict; `stages` is omitted so the default applies.
    resnet = backbones.build({'type': 'ResNet', 'depth': 50})
    assert isinstance(resnet, ResNet)
    assert (resnet.depth, resnet.stages) == (50, 4)

    # Single config dict with every argument given explicitly.
    resnext = backbones.build({'type': 'ResNeXt', 'depth': 50, 'stages': 3})
    assert isinstance(resnext, ResNeXt)
    assert (resnext.depth, resnext.stages) == (50, 3)

    # A list of config dicts is built into a sequential container.
    stacked = backbones.build([
        {'type': 'ResNet', 'depth': 50},
        {'type': 'ResNeXt', 'depth': 50, 'stages': 3},
    ])
    assert isinstance(stacked, nn.Sequential)
    first = stacked[0]
    second = stacked[1]
    assert isinstance(first, ResNet)
    assert (first.depth, first.stages) == (50, 4)
    assert isinstance(second, ResNeXt)
    assert (second.depth, second.stages) == (50, 3)

    # test inherit `build_func` from parent
    inheriting = mmcv.Registry('models', parent=MODELS, scope='new')
    assert inheriting.build_func is build_model_from_cfg

    # test specify `build_func`
    def echo_cfg(cfg):
        return cfg

    overriding = mmcv.Registry('models', parent=MODELS, build_func=echo_cfg)
    assert overriding.build_func is echo_cfg