Spaces:
Sleeping
Sleeping
File size: 8,820 Bytes
d7e58f0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 |
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import mmcv
def test_registry():
    """Exercise ``mmcv.Registry`` registration, lookup, repr and old APIs.

    NOTE: the expected ``repr`` below embeds qualified class names such as
    ``test_registry.test_registry.<locals>.Munchkin``, so the registered
    classes must stay defined directly inside this function and the module
    must be named ``test_registry``.
    """
    CATS = mmcv.Registry('cat')
    assert CATS.name == 'cat'
    assert CATS.module_dict == {}
    assert len(CATS) == 0

    @CATS.register_module()
    class BritishShorthair:
        pass

    assert len(CATS) == 1
    assert CATS.get('BritishShorthair') is BritishShorthair

    class Munchkin:
        pass

    # Passing the class positionally also registers it under its own name.
    CATS.register_module(Munchkin)
    assert len(CATS) == 2
    assert CATS.get('Munchkin') is Munchkin
    assert 'Munchkin' in CATS

    # Re-registering an existing name fails unless force=True is given.
    with pytest.raises(KeyError):
        CATS.register_module(Munchkin)

    CATS.register_module(Munchkin, force=True)
    assert len(CATS) == 2

    # Same rule through the decorator form: without force=True a duplicate
    # name is rejected ...
    with pytest.raises(KeyError):

        @CATS.register_module()
        class BritishShorthair:
            pass

    # ... and with force=True the existing entry is overwritten.
    @CATS.register_module(force=True)
    class BritishShorthair:
        pass

    # Overwriting does not add a new entry.
    assert len(CATS) == 2

    # Unknown names resolve to None and are not members.
    assert CATS.get('PersianCat') is None
    assert 'PersianCat' not in CATS

    # One class can be registered under several names at once.
    @CATS.register_module(name=['Siamese', 'Siamese2'])
    class SiameseCat:
        pass

    assert CATS.get('Siamese').__name__ == 'SiameseCat'
    assert CATS.get('Siamese2').__name__ == 'SiameseCat'

    class SphynxCat:
        pass

    # Non-decorator form with an explicit name, or a list of names.
    CATS.register_module(name='Sphynx', module=SphynxCat)
    assert CATS.get('Sphynx') is SphynxCat

    CATS.register_module(name=['Sphynx1', 'Sphynx2'], module=SphynxCat)
    assert CATS.get('Sphynx2') is SphynxCat

    # Expected repr entries, in this exact order: (registered key, class
    # __name__). 'BritishShorthair' stays first even though it was
    # force-overwritten after 'Munchkin' was added.
    expected_items = [
        ('BritishShorthair', 'BritishShorthair'),
        ('Munchkin', 'Munchkin'),
        ('Siamese', 'SiameseCat'),
        ('Siamese2', 'SiameseCat'),
        ('Sphynx', 'SphynxCat'),
        ('Sphynx1', 'SphynxCat'),
        ('Sphynx2', 'SphynxCat'),
    ]
    repr_str = 'Registry(name=cat, items={' + ', '.join(
        "'{}': <class 'test_registry.test_registry.<locals>.{}'>".format(
            key, cls_name) for key, cls_name in expected_items) + '})'
    assert repr(CATS) == repr_str

    # A non-string name (here an int) is rejected.
    with pytest.raises(TypeError):
        CATS.register_module(name=7474741, module=SphynxCat)

    # A positional argument that is not a class is rejected.
    with pytest.raises(TypeError):
        CATS.register_module(0)

    # Plain functions can be registered as well.
    @CATS.register_module()
    def muchkin():
        pass

    assert CATS.get('muchkin') is muchkin
    assert 'muchkin' in CATS

    # Only classes and functions can be registered; a bound method is not
    # accepted. The method is built outside ``pytest.raises`` so that an
    # error during setup cannot be mistaken for the expected TypeError.
    class Demo:

        def some_method(self):
            pass

    method = Demo().some_method
    with pytest.raises(TypeError):
        CATS.register_module(name='some_method', module=method)

    # begin: test old APIs
    # Each deprecated call form must emit a DeprecationWarning and must still
    # register the class under its own name.
    with pytest.warns(DeprecationWarning):
        CATS.register_module(SphynxCat)
    assert CATS.get('SphynxCat').__name__ == 'SphynxCat'

    with pytest.warns(DeprecationWarning):
        CATS.register_module(SphynxCat, force=True)
    assert CATS.get('SphynxCat').__name__ == 'SphynxCat'

    # Decorator used without parentheses.
    with pytest.warns(DeprecationWarning):

        @CATS.register_module
        class NewCat:
            pass

    assert CATS.get('NewCat').__name__ == 'NewCat'

    with pytest.warns(DeprecationWarning):
        CATS.deprecated_register_module(SphynxCat, force=True)
    assert CATS.get('SphynxCat').__name__ == 'SphynxCat'

    with pytest.warns(DeprecationWarning):

        @CATS.deprecated_register_module
        class CuteCat:
            pass

    assert CATS.get('CuteCat').__name__ == 'CuteCat'

    with pytest.warns(DeprecationWarning):

        @CATS.deprecated_register_module(force=True)
        class NewCat2:
            pass

    assert CATS.get('NewCat2').__name__ == 'NewCat2'
    # end: test old APIs
def test_multi_scope_registry():
    """Exercise scoped lookups across a parent/child registry hierarchy.

    Registries built here (scope in parentheses)::

        dogs (not passed explicitly; expected to be 'test_registry')
        `-- hounds ('hound')
            |-- little_hounds ('little_hound')
            `-- mid_hounds ('mid_hound')

    Keys of the form ``'<scope>.<Name>'`` are resolved through this tree.
    A key that names the wrong scope must resolve to ``None``.
    """
    # Root registry: no scope argument is given.
    dogs = mmcv.Registry('dogs')
    assert dogs.name == 'dogs'
    assert dogs.scope == 'test_registry'
    assert dogs.module_dict == {}
    assert len(dogs) == 0

    @dogs.register_module()
    class GoldenRetriever:
        pass

    assert len(dogs) == 1
    assert dogs.get('GoldenRetriever') is GoldenRetriever

    # First-level child of the root.
    hounds = mmcv.Registry('dogs', parent=dogs, scope='hound')

    @hounds.register_module()
    class BloodHound:
        pass

    assert len(hounds) == 1
    for registry, key in [(hounds, 'BloodHound'),
                          (dogs, 'hound.BloodHound'),
                          (hounds, 'hound.BloodHound')]:
        assert registry.get(key) is BloodHound

    # Second-level child under 'hound'.
    little_hounds = mmcv.Registry('dogs', parent=hounds, scope='little_hound')

    @little_hounds.register_module()
    class Dachshund:
        pass

    assert len(little_hounds) == 1
    assert little_hounds.get('Dachshund') is Dachshund
    # A grandchild can still reach a class registered in its parent.
    assert little_hounds.get('hound.BloodHound') is BloodHound
    for registry, key in [(hounds, 'little_hound.Dachshund'),
                          (dogs, 'hound.little_hound.Dachshund')]:
        assert registry.get(key) is Dachshund

    # Sibling of little_hounds under the same 'hound' parent.
    mid_hounds = mmcv.Registry('dogs', parent=hounds, scope='mid_hound')

    @mid_hounds.register_module()
    class Beagle:
        pass

    for registry, key in [(mid_hounds, 'Beagle'),
                          (hounds, 'mid_hound.Beagle'),
                          (dogs, 'hound.mid_hound.Beagle'),
                          (little_hounds, 'hound.mid_hound.Beagle')]:
        assert registry.get(key) is Beagle
    assert mid_hounds.get('hound.BloodHound') is BloodHound
    # Dachshund lives under 'hound.little_hound', not directly under 'hound'.
    assert mid_hounds.get('hound.Dachshund') is None
def test_build_from_cfg():
    """Check ``mmcv.build_from_cfg`` success paths and argument validation."""
    BACKBONES = mmcv.Registry('backbone')

    @BACKBONES.register_module()
    class ResNet:

        def __init__(self, depth, stages=4):
            self.depth = depth
            self.stages = stages

    @BACKBONES.register_module()
    class ResNeXt:

        def __init__(self, depth, stages=4):
            self.depth = depth
            self.stages = stages

    # Successful builds: (cfg, default_args, expected class, depth, stages).
    # ``type`` may be a registered name or the class itself, and it may come
    # from either ``cfg`` or ``default_args``.
    success_cases = [
        (dict(type='ResNet', depth=50), None, ResNet, 50, 4),
        (dict(type='ResNet', depth=50), {'stages': 3}, ResNet, 50, 3),
        (dict(type='ResNeXt', depth=50, stages=3), None, ResNeXt, 50, 3),
        (dict(type=ResNet, depth=50), None, ResNet, 50, 4),
        (dict(depth=50), dict(type='ResNet'), ResNet, 50, 4),
        (dict(depth=50), dict(type=ResNet), ResNet, 50, 4),
    ]
    for cfg, default_args, expected_cls, depth, stages in success_cases:
        model = mmcv.build_from_cfg(cfg, BACKBONES, default_args=default_args)
        assert isinstance(model, expected_cls)
        assert model.depth == depth and model.stages == stages

    # The registry argument must be a Registry, not e.g. a str. This is
    # checked for both an unregistered and a registered type name.
    for cfg in (dict(type='VGG'), dict(type='ResNet', depth=50)):
        with pytest.raises(TypeError):
            mmcv.build_from_cfg(cfg, 'BACKBONES')

    # ``type`` names a class that was never registered.
    cfg = dict(type='VGG')
    with pytest.raises(KeyError):
        mmcv.build_from_cfg(cfg, BACKBONES)

    # ``default_args`` must be a dict or None.
    cfg = dict(type='ResNet', depth=50)
    for bad_default_args in (1, 0):
        with pytest.raises(TypeError):
            mmcv.build_from_cfg(
                cfg, BACKBONES, default_args=bad_default_args)

    # ``cfg['type']`` must be a str or a class.
    cfg = dict(type=1000)
    with pytest.raises(TypeError):
        mmcv.build_from_cfg(cfg, BACKBONES)

    # ``type`` must be present in ``cfg`` or in ``default_args``.
    for cfg, default_args in ((dict(depth=50, stages=4), None),
                              (dict(depth=50), dict(stages=4))):
        with pytest.raises(KeyError, match='must contain the key "type"'):
            mmcv.build_from_cfg(cfg, BACKBONES, default_args=default_args)

    # Constructor arguments the registered class does not accept must
    # surface as TypeError.
    cfg = dict(type='ResNet', non_existing_arg=50)
    with pytest.raises(TypeError):
        mmcv.build_from_cfg(cfg, BACKBONES)
|