Spaces:
Starting
on
L40S
Starting
on
L40S
# Copyright (c) OpenMMLab. All rights reserved. | |
import pytest | |
import torch | |
from torch.utils import model_zoo | |
from mmcv.utils import TORCH_VERSION, digit_version, load_url | |
def test_load_url():
    """Check loading checkpoints from a URL across PyTorch versions.

    Two checkpoints are fetched over the network. Judging by their file
    names, ``url1`` was saved with torch 1.5 and ``url2`` with torch 1.6.

    Both ``torch.utils.model_zoo.load_url`` and mmcv's ``load_url`` are
    exercised. A ``RuntimeError`` is expected only in the version ranges
    that the comments below describe as unsupported.

    Note: this test needs network access to download.openmmlab.com.
    """
    # NOTE(review): on the parrots backend the version string is not a
    # numeric release, so the digit_version() comparisons below are
    # presumably meaningless there. Upstream skips this test on parrots;
    # `torch` is imported by this module for exactly this check.
    if torch.__version__ == 'parrots':
        pytest.skip('not necessary in parrots test')

    url1 = 'https://download.openmmlab.com/mmcv/test_data/saved_in_pt1.5.pth'
    url2 = 'https://download.openmmlab.com/mmcv/test_data/saved_in_pt1.6.pth'

    # The 1.6 release of PyTorch switched torch.save to use a new zipfile-based
    # file format. It will cause RuntimeError when a checkpoint was saved in
    # torch >= 1.6.0 but loaded in torch < 1.7.0.
    # More details at https://github.com/open-mmlab/mmpose/issues/904
    #
    # The checkpoint saved in the older format must load on every version.
    model_zoo.load_url(url1)
    if digit_version(TORCH_VERSION) < digit_version('1.7.0'):
        with pytest.raises(RuntimeError):
            model_zoo.load_url(url2)
    else:
        # High versions of PyTorch can load checkpoints from a URL regardless
        # of which version they were saved in.
        model_zoo.load_url(url2)

    # mmcv's load_url is expected by this test to handle the zipfile format
    # from torch 1.5.0 onwards. Only older releases should raise RuntimeError
    # when loading a checkpoint saved in torch >= 1.6.0.
    load_url(url1)
    if digit_version(TORCH_VERSION) < digit_version('1.5.0'):
        with pytest.raises(RuntimeError):
            load_url(url2)
    else:
        load_url(url2)