# Copyright (c) OpenMMLab. All rights reserved.
import pytest

import mmcv
def test_registry():
    """Exercise ``mmcv.Registry`` registration, lookup, repr and errors.

    Covers decorator and direct-call registration, ``force`` overwriting,
    multiple aliases for one class, function registration, argument type
    errors, and the deprecated registration APIs (which must emit
    ``DeprecationWarning``).

    Note: the expected ``repr`` embeds names such as
    ``test_registry.test_registry.<locals>.BritishShorthair``, so this test
    relies on living in a module named ``test_registry`` and on the
    registered classes being defined locally inside this function.
    """
    CATS = mmcv.Registry('cat')
    assert CATS.name == 'cat'
    assert CATS.module_dict == {}
    assert len(CATS) == 0

    # Decorator form: the class is registered under its own name.
    @CATS.register_module()
    class BritishShorthair:
        pass

    assert len(CATS) == 1
    assert CATS.get('BritishShorthair') is BritishShorthair

    # Direct-call form with the class passed positionally.
    class Munchkin:
        pass

    CATS.register_module(Munchkin)
    assert len(CATS) == 2
    assert CATS.get('Munchkin') is Munchkin
    assert 'Munchkin' in CATS

    # Registering the same name twice without force is rejected ...
    with pytest.raises(KeyError):
        CATS.register_module(Munchkin)

    # ... but force=True overwrites in place without adding an entry.
    CATS.register_module(Munchkin, force=True)
    assert len(CATS) == 2

    # force=False (the default) also rejects duplicates via the decorator.
    with pytest.raises(KeyError):

        @CATS.register_module()
        class BritishShorthair:
            pass

    @CATS.register_module(force=True)
    class BritishShorthair:
        pass

    assert len(CATS) == 2

    # Unknown names are simply absent.
    assert CATS.get('PersianCat') is None
    assert 'PersianCat' not in CATS

    # A list of names registers the same class under several aliases.
    @CATS.register_module(name=['Siamese', 'Siamese2'])
    class SiameseCat:
        pass

    assert CATS.get('Siamese').__name__ == 'SiameseCat'
    assert CATS.get('Siamese2').__name__ == 'SiameseCat'

    # Keyword form: explicit name(s) plus module.
    class SphynxCat:
        pass

    CATS.register_module(name='Sphynx', module=SphynxCat)
    assert CATS.get('Sphynx') is SphynxCat

    CATS.register_module(name=['Sphynx1', 'Sphynx2'], module=SphynxCat)
    assert CATS.get('Sphynx2') is SphynxCat

    # The repr lists entries in registration order; overwriting
    # 'BritishShorthair' with force=True keeps its original position.
    repr_str = 'Registry(name=cat, items={'
    repr_str += ("'BritishShorthair': <class 'test_registry.test_registry."
                 "<locals>.BritishShorthair'>, ")
    repr_str += ("'Munchkin': <class 'test_registry.test_registry."
                 "<locals>.Munchkin'>, ")
    repr_str += ("'Siamese': <class 'test_registry.test_registry."
                 "<locals>.SiameseCat'>, ")
    repr_str += ("'Siamese2': <class 'test_registry.test_registry."
                 "<locals>.SiameseCat'>, ")
    repr_str += ("'Sphynx': <class 'test_registry.test_registry."
                 "<locals>.SphynxCat'>, ")
    repr_str += ("'Sphynx1': <class 'test_registry.test_registry."
                 "<locals>.SphynxCat'>, ")
    repr_str += ("'Sphynx2': <class 'test_registry.test_registry."
                 "<locals>.SphynxCat'>")
    repr_str += '})'
    assert repr(CATS) == repr_str

    # name type
    with pytest.raises(TypeError):
        CATS.register_module(name=7474741, module=SphynxCat)

    # the registered module should be a class
    with pytest.raises(TypeError):
        CATS.register_module(0)

    # Plain functions can be registered too.
    @CATS.register_module()
    def muchkin():
        pass

    assert CATS.get('muchkin') is muchkin
    assert 'muchkin' in CATS

    # can only decorate a class or a function (a bound method is neither)
    with pytest.raises(TypeError):

        class Demo:

            def some_method(self):
                pass

        method = Demo().some_method
        CATS.register_module(name='some_method', module=method)

    # begin: test old APIs
    # Each `pytest.warns` scope wraps only the call that should warn; the
    # follow-up lookups run outside it.
    with pytest.warns(DeprecationWarning):
        CATS.register_module(SphynxCat)
    assert CATS.get('SphynxCat').__name__ == 'SphynxCat'

    with pytest.warns(DeprecationWarning):
        CATS.register_module(SphynxCat, force=True)
    assert CATS.get('SphynxCat').__name__ == 'SphynxCat'

    # Old style: decorator used without parentheses.
    with pytest.warns(DeprecationWarning):

        @CATS.register_module
        class NewCat:
            pass

    assert CATS.get('NewCat').__name__ == 'NewCat'

    with pytest.warns(DeprecationWarning):
        CATS.deprecated_register_module(SphynxCat, force=True)
    assert CATS.get('SphynxCat').__name__ == 'SphynxCat'

    with pytest.warns(DeprecationWarning):

        @CATS.deprecated_register_module
        class CuteCat:
            pass

    assert CATS.get('CuteCat').__name__ == 'CuteCat'

    with pytest.warns(DeprecationWarning):

        @CATS.deprecated_register_module(force=True)
        class NewCat2:
            pass

    assert CATS.get('NewCat2').__name__ == 'NewCat2'
    # end: test old APIs
def test_multi_scope_registry():
    """Check scoped lookup across a parent/child ``mmcv.Registry`` tree.

    Builds DOGS (scope inferred as ``test_registry``) -> HOUNDS ('hound')
    -> {LITTLE_HOUNDS ('little_hound'), MID_HOUNDS ('mid_hound')} and
    verifies that:

    * an unprefixed name resolves in the registry itself;
    * a parent resolves ``<child_scope>.<name>`` (and deeper dotted paths);
    * a child resolves a ``<scope>.<name>`` path that goes up through its
      ancestors;
    * sibling-only modules are not visible through the wrong scope prefix.
    """
    DOGS = mmcv.Registry('dogs')
    assert DOGS.name == 'dogs'
    # No explicit scope given, so it is inferred; the assert shows it
    # matches this test module's name.
    assert DOGS.scope == 'test_registry'
    assert DOGS.module_dict == {}
    assert len(DOGS) == 0

    @DOGS.register_module()
    class GoldenRetriever:
        pass

    assert len(DOGS) == 1
    assert DOGS.get('GoldenRetriever') is GoldenRetriever

    HOUNDS = mmcv.Registry('dogs', parent=DOGS, scope='hound')

    @HOUNDS.register_module()
    class BloodHound:
        pass

    assert len(HOUNDS) == 1
    assert HOUNDS.get('BloodHound') is BloodHound
    assert DOGS.get('hound.BloodHound') is BloodHound
    assert HOUNDS.get('hound.BloodHound') is BloodHound

    LITTLE_HOUNDS = mmcv.Registry('dogs', parent=HOUNDS, scope='little_hound')

    @LITTLE_HOUNDS.register_module()
    class Dachshund:
        pass

    assert len(LITTLE_HOUNDS) == 1
    assert LITTLE_HOUNDS.get('Dachshund') is Dachshund
    assert LITTLE_HOUNDS.get('hound.BloodHound') is BloodHound
    assert HOUNDS.get('little_hound.Dachshund') is Dachshund
    assert DOGS.get('hound.little_hound.Dachshund') is Dachshund

    MID_HOUNDS = mmcv.Registry('dogs', parent=HOUNDS, scope='mid_hound')

    @MID_HOUNDS.register_module()
    class Beagle:
        pass

    assert MID_HOUNDS.get('Beagle') is Beagle
    assert HOUNDS.get('mid_hound.Beagle') is Beagle
    assert DOGS.get('hound.mid_hound.Beagle') is Beagle
    # A sibling can reach Beagle via the full path from the shared ancestor.
    assert LITTLE_HOUNDS.get('hound.mid_hound.Beagle') is Beagle
    assert MID_HOUNDS.get('hound.BloodHound') is BloodHound
    # Dachshund lives in 'little_hound', not directly in 'hound'.
    assert MID_HOUNDS.get('hound.Dachshund') is None
def test_build_from_cfg():
    """Check ``mmcv.build_from_cfg`` construction paths and error handling.

    Covers ``type`` given as a registered name or as a class, ``type``
    supplied through ``default_args``, ``default_args`` merging, and the
    ``TypeError`` / ``KeyError`` raised for bad registries, bad
    ``default_args``, bad or missing ``type`` and bad constructor args.
    """
    BACKBONES = mmcv.Registry('backbone')

    @BACKBONES.register_module()
    class ResNet:

        def __init__(self, depth, stages=4):
            self.depth = depth
            self.stages = stages

    @BACKBONES.register_module()
    class ResNeXt:

        def __init__(self, depth, stages=4):
            self.depth = depth
            self.stages = stages

    cfg = dict(type='ResNet', depth=50)
    model = mmcv.build_from_cfg(cfg, BACKBONES)
    assert isinstance(model, ResNet)
    assert model.depth == 50 and model.stages == 4

    # default_args fill in keyword arguments missing from cfg.
    cfg = dict(type='ResNet', depth=50)
    model = mmcv.build_from_cfg(cfg, BACKBONES, default_args={'stages': 3})
    assert isinstance(model, ResNet)
    assert model.depth == 50 and model.stages == 3

    cfg = dict(type='ResNeXt', depth=50, stages=3)
    model = mmcv.build_from_cfg(cfg, BACKBONES)
    assert isinstance(model, ResNeXt)
    assert model.depth == 50 and model.stages == 3

    # `type` may be the class itself rather than its registered name.
    cfg = dict(type=ResNet, depth=50)
    model = mmcv.build_from_cfg(cfg, BACKBONES)
    assert isinstance(model, ResNet)
    assert model.depth == 50 and model.stages == 4

    # type defined using default_args
    cfg = dict(depth=50)
    model = mmcv.build_from_cfg(
        cfg, BACKBONES, default_args=dict(type='ResNet'))
    assert isinstance(model, ResNet)
    assert model.depth == 50 and model.stages == 4

    cfg = dict(depth=50)
    model = mmcv.build_from_cfg(cfg, BACKBONES, default_args=dict(type=ResNet))
    assert isinstance(model, ResNet)
    assert model.depth == 50 and model.stages == 4

    # not a registry
    with pytest.raises(TypeError):
        cfg = dict(type='VGG')
        mmcv.build_from_cfg(cfg, 'BACKBONES')

    # non-registered class
    with pytest.raises(KeyError):
        cfg = dict(type='VGG')
        mmcv.build_from_cfg(cfg, BACKBONES)

    # default_args must be a dict or None
    with pytest.raises(TypeError):
        cfg = dict(type='ResNet', depth=50)
        mmcv.build_from_cfg(cfg, BACKBONES, default_args=1)

    # cfg['type'] should be a str or class
    with pytest.raises(TypeError):
        cfg = dict(type=1000)
        mmcv.build_from_cfg(cfg, BACKBONES)

    # cfg should contain the key "type"
    with pytest.raises(KeyError, match='must contain the key "type"'):
        cfg = dict(depth=50, stages=4)
        mmcv.build_from_cfg(cfg, BACKBONES)

    # cfg or default_args should contain the key "type"
    with pytest.raises(KeyError, match='must contain the key "type"'):
        cfg = dict(depth=50)
        mmcv.build_from_cfg(
            cfg, BACKBONES, default_args=dict(stages=4))

    # incorrect registry type
    with pytest.raises(TypeError):
        cfg = dict(type='ResNet', depth=50)
        mmcv.build_from_cfg(cfg, 'BACKBONES')

    # incorrect default_args type
    with pytest.raises(TypeError):
        cfg = dict(type='ResNet', depth=50)
        mmcv.build_from_cfg(cfg, BACKBONES, default_args=0)

    # incorrect arguments: the constructor itself rejects the unknown kwarg
    with pytest.raises(TypeError):
        cfg = dict(type='ResNet', non_existing_arg=50)
        mmcv.build_from_cfg(cfg, BACKBONES)