Spaces:
Starting
on
L40S
Starting
on
L40S
# Copyright (c) OpenMMLab. All rights reserved. | |
import torch.nn as nn | |
import mmcv | |
from mmcv.cnn import MODELS, build_model_from_cfg | |
def test_build_model_from_cfg():
    """Check that a registry using ``build_model_from_cfg`` builds models.

    Covers:

    * a single config dict is built into an instance of the registered class,
      with constructor arguments omitted from the config taking their
      defaults;
    * a list of config dicts is built into an ``nn.Sequential`` whose
      children appear in the same order as the configs;
    * a child registry inherits ``build_func`` from its parent unless an
      explicit ``build_func`` is passed.
    """
    BACKBONES = mmcv.Registry('backbone', build_func=build_model_from_cfg)

    # Both classes must be registered under their class names; without the
    # decorator, ``BACKBONES.build`` has no entry to resolve
    # ``type='ResNet'`` / ``type='ResNeXt'`` against.
    @BACKBONES.register_module()
    class ResNet(nn.Module):

        def __init__(self, depth, stages=4):
            super().__init__()
            self.depth = depth
            self.stages = stages

        def forward(self, x):
            return x

    @BACKBONES.register_module()
    class ResNeXt(nn.Module):

        def __init__(self, depth, stages=4):
            super().__init__()
            self.depth = depth
            self.stages = stages

        def forward(self, x):
            return x

    # Single dict, ``stages`` omitted: the constructor default (4) applies.
    cfg = dict(type='ResNet', depth=50)
    model = BACKBONES.build(cfg)
    assert isinstance(model, ResNet)
    assert model.depth == 50 and model.stages == 4

    # Single dict, ``stages`` given explicitly: the config value wins.
    cfg = dict(type='ResNeXt', depth=50, stages=3)
    model = BACKBONES.build(cfg)
    assert isinstance(model, ResNeXt)
    assert model.depth == 50 and model.stages == 3

    # A list of dicts is built into an ``nn.Sequential`` in list order.
    cfg = [
        dict(type='ResNet', depth=50),
        dict(type='ResNeXt', depth=50, stages=3)
    ]
    model = BACKBONES.build(cfg)
    assert isinstance(model, nn.Sequential)
    assert isinstance(model[0], ResNet)
    assert model[0].depth == 50 and model[0].stages == 4
    assert isinstance(model[1], ResNeXt)
    assert model[1].depth == 50 and model[1].stages == 3

    # test inherit `build_func` from parent
    NEW_MODELS = mmcv.Registry('models', parent=MODELS, scope='new')
    assert NEW_MODELS.build_func is build_model_from_cfg

    # test specify `build_func`
    def pseudo_build(cfg):
        # Identity builder; only its identity is checked below.
        return cfg

    NEW_MODELS = mmcv.Registry(
        'models', parent=MODELS, build_func=pseudo_build)
    assert NEW_MODELS.build_func is pseudo_build