Spaces:
Sleeping
Sleeping
File size: 6,009 Bytes
d7e58f0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 |
# Copyright (c) OpenMMLab. All rights reserved.
import os
import os.path as osp
from unittest.mock import patch
import pytest
import torchvision
import mmcv
from mmcv.runner.checkpoint import (DEFAULT_CACHE_DIR, ENV_MMCV_HOME,
ENV_XDG_CACHE_HOME, _get_mmcv_home,
_load_checkpoint,
get_deprecated_model_names,
get_external_models)
from mmcv.utils import digit_version
@patch('mmcv.__path__', [osp.join(osp.dirname(__file__), 'data/')])
@patch.dict(os.environ)
def test_set_mmcv_home():
    """``_get_mmcv_home()`` returns the path given in ``ENV_MMCV_HOME``.

    ``patch.dict(os.environ)`` snapshots the environment on entry and
    restores it on exit, so the override made here does not leak into
    tests that run afterwards.
    """
    mmcv_home = osp.join(osp.dirname(__file__), 'data/model_zoo/mmcv_home/')
    # Plain assignment replaces any value inherited from the caller's
    # environment, so no preliminary pop is required.
    os.environ[ENV_MMCV_HOME] = mmcv_home
    assert _get_mmcv_home() == mmcv_home
@patch('mmcv.__path__', [osp.join(osp.dirname(__file__), 'data/')])
@patch.dict(os.environ)
def test_default_mmcv_home():
    """Check the fallback home directory and the bundled model-zoo index.

    With both ``ENV_MMCV_HOME`` and ``ENV_XDG_CACHE_HOME`` removed from the
    environment, ``_get_mmcv_home()`` must equal the user-expanded
    ``DEFAULT_CACHE_DIR/mmcv`` and ``get_external_models()`` must equal the
    contents of ``model_zoo/open_mmlab.json`` under the patched
    ``mmcv.__path__``.

    ``patch.dict(os.environ)`` restores the environment on exit, so the
    variables removed here are put back for the rest of the test session.
    """
    os.environ.pop(ENV_MMCV_HOME, None)
    os.environ.pop(ENV_XDG_CACHE_HOME, None)

    expected_home = os.path.expanduser(
        os.path.join(DEFAULT_CACHE_DIR, 'mmcv'))
    assert _get_mmcv_home() == expected_home

    model_urls = get_external_models()
    assert model_urls == mmcv.load(
        osp.join(mmcv.__path__[0], 'model_zoo/open_mmlab.json'))
@patch('mmcv.__path__', [osp.join(osp.dirname(__file__), 'data/')])
@patch.dict(os.environ)
def test_get_external_models():
    """``get_external_models()`` returns the expected name-to-location map.

    ``ENV_MMCV_HOME`` is pointed at the ``data/model_zoo/mmcv_home/``
    fixture directory next to this file before the call.
    ``patch.dict(os.environ)`` undoes that override when the test returns.
    """
    mmcv_home = osp.join(osp.dirname(__file__), 'data/model_zoo/mmcv_home/')
    os.environ[ENV_MMCV_HOME] = mmcv_home

    ext_urls = get_external_models()
    assert ext_urls == {
        'train': 'https://localhost/train.pth',
        'test': 'test.pth',
        'val': 'val.pth',
        'train_empty': 'train.pth'
    }
@patch('mmcv.__path__', [osp.join(osp.dirname(__file__), 'data/')])
@patch.dict(os.environ)
def test_get_deprecated_models():
    """``get_deprecated_model_names()`` maps old names to their successors.

    ``ENV_MMCV_HOME`` is pointed at the ``data/model_zoo/mmcv_home/``
    fixture directory next to this file before the call.
    ``patch.dict(os.environ)`` undoes that override when the test returns.
    """
    mmcv_home = osp.join(osp.dirname(__file__), 'data/model_zoo/mmcv_home/')
    os.environ[ENV_MMCV_HOME] = mmcv_home

    dep_urls = get_deprecated_model_names()
    assert dep_urls == {
        'train_old': 'train',
        'test_old': 'test',
    }
def load_from_http(url, map_location=None):
    """Fake HTTP loader: return *url* tagged with a ``'url:'`` prefix.

    ``map_location`` is accepted but ignored.
    """
    tagged = 'url:' + url
    return tagged
def load_url(url, map_location=None, model_dir=None):
    """Fake URL loader: hand *url* to ``load_from_http`` and return its result.

    ``map_location`` and ``model_dir`` are accepted but ignored.
    """
    fetched = load_from_http(url)
    return fetched
def load(filepath, map_location=None):
    """Fake local loader: return *filepath* tagged with a ``'local:'`` prefix.

    ``map_location`` is accepted but ignored.
    """
    tagged = 'local:' + filepath
    return tagged
@patch('mmcv.__path__', [osp.join(osp.dirname(__file__), 'data/')])
@patch('mmcv.runner.checkpoint.load_from_http', load_from_http)
@patch('mmcv.runner.checkpoint.load_url', load_url)
@patch('torch.load', load)
@patch.dict(os.environ)
def test_load_external_url():
    """Exercise ``_load_checkpoint`` across the supported URI schemes.

    The loaders are replaced by the module-level fakes, so every call
    returns a string tagged ``'url:'`` or ``'local:'`` instead of touching
    the network or the filesystem; the assertions check which loader was
    chosen and with which location.

    The test removes and sets ``ENV_MMCV_HOME`` / ``ENV_XDG_CACHE_HOME``
    several times. ``patch.dict(os.environ)`` snapshots the environment on
    entry and restores it on exit so none of that leaks into other tests.
    """
    # test modelzoo://
    torchvision_version = torchvision.__version__
    # The expected resnet50 checkpoint URL depends on the torchvision version.
    if digit_version(torchvision_version) < digit_version('0.10.0a0'):
        expected = ('url:https://download.pytorch.org/models/resnet50-19c8e'
                    '357.pth')
    else:
        expected = ('url:https://download.pytorch.org/models/resnet50-0676b'
                    'a61.pth')
    assert _load_checkpoint('modelzoo://resnet50') == expected
    assert _load_checkpoint('torchvision://resnet50') == expected

    if digit_version(torchvision_version) >= digit_version('0.13.0a0'):
        # Test load new format torchvision models.
        assert (
            _load_checkpoint('torchvision://resnet50.imagenet1k_v1') ==
            'url:https://download.pytorch.org/models/resnet50-0676ba61.pth')

        assert (
            _load_checkpoint('torchvision://ResNet50_Weights.IMAGENET1K_V1') ==
            'url:https://download.pytorch.org/models/resnet50-0676ba61.pth')

        # No expected value is asserted here: the call only has to succeed.
        _load_checkpoint('torchvision://resnet50.default')

    # test open-mmlab:// with default MMCV_HOME
    os.environ.pop(ENV_MMCV_HOME, None)
    os.environ.pop(ENV_XDG_CACHE_HOME, None)
    url = _load_checkpoint('open-mmlab://train')
    assert url == 'url:https://localhost/train.pth'

    # test open-mmlab:// with deprecated model name
    os.environ.pop(ENV_MMCV_HOME, None)
    os.environ.pop(ENV_XDG_CACHE_HOME, None)
    with pytest.warns(
            Warning,
            match='open-mmlab://train_old is deprecated in favor of '
            'open-mmlab://train'):
        url = _load_checkpoint('open-mmlab://train_old')
        assert url == 'url:https://localhost/train.pth'

    # test openmmlab:// with deprecated model name
    os.environ.pop(ENV_MMCV_HOME, None)
    os.environ.pop(ENV_XDG_CACHE_HOME, None)
    with pytest.warns(
            Warning,
            match='openmmlab://train_old is deprecated in favor of '
            'openmmlab://train'):
        url = _load_checkpoint('openmmlab://train_old')
        assert url == 'url:https://localhost/train.pth'

    # test open-mmlab:// with user-defined MMCV_HOME
    mmcv_home = osp.join(osp.dirname(__file__), 'data/model_zoo/mmcv_home')
    os.environ[ENV_MMCV_HOME] = mmcv_home
    url = _load_checkpoint('open-mmlab://train')
    assert url == 'url:https://localhost/train.pth'
    with pytest.raises(FileNotFoundError, match='train.pth can not be found.'):
        _load_checkpoint('open-mmlab://train_empty')
    url = _load_checkpoint('open-mmlab://test')
    assert url == f'local:{osp.join(_get_mmcv_home(), "test.pth")}'
    url = _load_checkpoint('open-mmlab://val')
    assert url == f'local:{osp.join(_get_mmcv_home(), "val.pth")}'

    # test http:// https://
    url = _load_checkpoint('http://localhost/train.pth')
    assert url == 'url:http://localhost/train.pth'

    # test local file
    with pytest.raises(FileNotFoundError, match='train.pth can not be found.'):
        _load_checkpoint('train.pth')
    url = _load_checkpoint(osp.join(_get_mmcv_home(), 'test.pth'))
    assert url == f'local:{osp.join(_get_mmcv_home(), "test.pth")}'
|