# Copyright (c) OpenMMLab. All rights reserved.
import json
import os.path as osp
import sys
import tempfile
import unittest.mock as mock
from collections import OrderedDict
from unittest.mock import MagicMock, patch

import pytest
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

from mmcv.fileio.file_client import PetrelBackend
from mmcv.runner import DistEvalHook as BaseDistEvalHook
from mmcv.runner import EpochBasedRunner
from mmcv.runner import EvalHook as BaseEvalHook
from mmcv.runner import IterBasedRunner
from mmcv.utils import get_logger, scandir

sys.modules['petrel_client'] = MagicMock()
sys.modules['petrel_client.client'] = MagicMock()


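# NOTE: mocking `petrel_client` above lets `PetrelBackend` be constructed in
# the `out_dir` test below without the real SDK installed.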
class ExampleDataset(Dataset):

    def __init__(self):
        self.index = 0
        self.eval_result = [1, 4, 3, 7, 2, -3, 4, 6]

    def __getitem__(self, idx):
        results = dict(x=torch.tensor([1]))
        return results

    def __len__(self):
        return 1

    # NOTE: the decorator is an assumed restoration. `create_autospec` keeps
    # the original signature (including `self`) while recording calls, which
    # is what the `evaluate.assert_called_with(...)` check in `test_eval_hook`
    # relies on; a plain method would not support `assert_called_with`.
    @mock.create_autospec
    def evaluate(self, results, logger=None):
        pass


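# EvalDataset returns the next value of `eval_result` each time `evaluate` is
# called, so with `interval=1` the best 'acc'/'score' (7) is reached at
# epoch 4 and the best 'loss_top' (-3, rule `less`) at epoch 6, matching the
# checkpoint names asserted below.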
class EvalDataset(ExampleDataset):

    def evaluate(self, results, logger=None):
        acc = self.eval_result[self.index]
        output = OrderedDict(
            acc=acc, index=self.index, score=acc, loss_top=acc)
        self.index += 1
        return output


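# A minimal model exposing the `train_step`/`val_step` interface expected by
# EpochBasedRunner and IterBasedRunner.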
class Model(nn.Module):

    def __init__(self):
        super().__init__()
        self.param = nn.Parameter(torch.tensor([1.0]))

    def forward(self, x, **kwargs):
        return self.param * x

    def train_step(self, data_batch, optimizer, **kwargs):
        return {'loss': torch.sum(self(data_batch['x']))}

    def val_step(self, data_batch, optimizer, **kwargs):
        return {'loss': torch.sum(self(data_batch['x']))}


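# Helper factories that build a throw-away runner in a fresh temporary
# directory; used directly in `test_eval_hook` and as the `_build_demo_runner`
# argument of the parametrized `test_start_param`.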
def _build_epoch_runner():
    model = Model()
    tmp_dir = tempfile.mkdtemp()

    runner = EpochBasedRunner(
        model=model, work_dir=tmp_dir, logger=get_logger('demo'))
    return runner


def _build_iter_runner():
    model = Model()
    tmp_dir = tempfile.mkdtemp()

    runner = IterBasedRunner(
        model=model, work_dir=tmp_dir, logger=get_logger('demo'))
    return runner


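# Local subclasses that override the default key lists: metrics matching
# 'acc'/'top' are maximized and those matching 'loss'/'loss_top' are minimized
# when `save_best` has to pick the comparison rule automatically.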
class EvalHook(BaseEvalHook):

    _default_greater_keys = ['acc', 'top']
    _default_less_keys = ['loss', 'loss_top']

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)


class DistEvalHook(BaseDistEvalHook):

    greater_keys = ['acc', 'top']
    less_keys = ['loss', 'loss_top']

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)


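# End-to-end checks for EvalHook: argument validation, best-checkpoint
# bookkeeping under different `save_best` keys and rules, resuming from a
# checkpoint, and saving best checkpoints to a remote `out_dir`.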
def test_eval_hook():
    with pytest.raises(AssertionError):
        # `save_best` should be a str
        test_dataset = Model()
        data_loader = DataLoader(test_dataset)
        EvalHook(data_loader, save_best=True)

    with pytest.raises(TypeError):
        # dataloader must be a pytorch DataLoader
        test_dataset = Model()
        data_loader = [DataLoader(test_dataset)]
        EvalHook(data_loader)

    with pytest.raises(ValueError):
        # key_indicator must be valid when rule_map is None
        test_dataset = ExampleDataset()
        data_loader = DataLoader(test_dataset)
        EvalHook(data_loader, save_best='unsupport')

    with pytest.raises(KeyError):
        # rule must be in keys of rule_map
        test_dataset = ExampleDataset()
        data_loader = DataLoader(test_dataset)
        EvalHook(data_loader, save_best='auto', rule='unsupport')

    # if eval_res is an empty dict, a warning should be printed
    with pytest.warns(UserWarning) as record_warnings:

        class _EvalDataset(ExampleDataset):

            def evaluate(self, results, logger=None):
                return {}

        test_dataset = _EvalDataset()
        data_loader = DataLoader(test_dataset)
        eval_hook = EvalHook(data_loader, save_best='auto')
        runner = _build_epoch_runner()
        runner.register_hook(eval_hook)
        runner.run([data_loader], [('train', 1)], 1)

    # Since many warnings may be thrown, we only need to check that the
    # expected one is among them
    expected_message = ('Since `eval_res` is an empty dict, the behavior to '
                        'save the best checkpoint will be skipped in this '
                        'evaluation.')
    for warning in record_warnings:
        if str(warning.message) == expected_message:
            break
    else:
        assert False

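    # `save_best=None`: evaluation still runs, but no best-score or
    # best-checkpoint entries should be written to `runner.meta`.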
    test_dataset = ExampleDataset()
    loader = DataLoader(test_dataset)
    model = Model()
    data_loader = DataLoader(test_dataset)
    eval_hook = EvalHook(data_loader, save_best=None)

    with tempfile.TemporaryDirectory() as tmpdir:
        # total_epochs = 1
        logger = get_logger('test_eval')
        runner = EpochBasedRunner(model=model, work_dir=tmpdir, logger=logger)
        runner.register_hook(eval_hook)
        runner.run([loader], [('train', 1)], 1)
        test_dataset.evaluate.assert_called_with(
            test_dataset, [torch.tensor([1])], logger=runner.logger)
        assert runner.meta is None or 'best_score' not in runner.meta[
            'hook_msgs']
        assert runner.meta is None or 'best_ckpt' not in runner.meta[
            'hook_msgs']

    # when `save_best` is set to 'auto', first metric will be used.
    loader = DataLoader(EvalDataset())
    model = Model()
    data_loader = DataLoader(EvalDataset())
    eval_hook = EvalHook(data_loader, interval=1, save_best='auto')

    with tempfile.TemporaryDirectory() as tmpdir:
        logger = get_logger('test_eval')
        runner = EpochBasedRunner(model=model, work_dir=tmpdir, logger=logger)
        runner.register_checkpoint_hook(dict(interval=1))
        runner.register_hook(eval_hook)
        runner.run([loader], [('train', 1)], 8)

        ckpt_path = osp.join(tmpdir, 'best_acc_epoch_4.pth')
        assert runner.meta['hook_msgs']['best_ckpt'] == ckpt_path
        assert osp.exists(ckpt_path)
        assert runner.meta['hook_msgs']['best_score'] == 7

    # total_epochs = 8, return the best acc and corresponding epoch
    loader = DataLoader(EvalDataset())
    model = Model()
    data_loader = DataLoader(EvalDataset())
    eval_hook = EvalHook(data_loader, interval=1, save_best='acc')

    with tempfile.TemporaryDirectory() as tmpdir:
        logger = get_logger('test_eval')
        runner = EpochBasedRunner(model=model, work_dir=tmpdir, logger=logger)
        runner.register_checkpoint_hook(dict(interval=1))
        runner.register_hook(eval_hook)
        runner.run([loader], [('train', 1)], 8)

        ckpt_path = osp.join(tmpdir, 'best_acc_epoch_4.pth')
        assert runner.meta['hook_msgs']['best_ckpt'] == ckpt_path
        assert osp.exists(ckpt_path)
        assert runner.meta['hook_msgs']['best_score'] == 7

    # total_epochs = 8, return the best loss_top and corresponding epoch
    loader = DataLoader(EvalDataset())
    model = Model()
    data_loader = DataLoader(EvalDataset())
    eval_hook = EvalHook(data_loader, interval=1, save_best='loss_top')

    with tempfile.TemporaryDirectory() as tmpdir:
        logger = get_logger('test_eval')
        runner = EpochBasedRunner(model=model, work_dir=tmpdir, logger=logger)
        runner.register_checkpoint_hook(dict(interval=1))
        runner.register_hook(eval_hook)
        runner.run([loader], [('train', 1)], 8)

        ckpt_path = osp.join(tmpdir, 'best_loss_top_epoch_6.pth')
        assert runner.meta['hook_msgs']['best_ckpt'] == ckpt_path
        assert osp.exists(ckpt_path)
        assert runner.meta['hook_msgs']['best_score'] == -3

    # total_epochs = 8, return the best score and corresponding epoch
    data_loader = DataLoader(EvalDataset())
    eval_hook = EvalHook(
        data_loader, interval=1, save_best='score', rule='greater')

    with tempfile.TemporaryDirectory() as tmpdir:
        logger = get_logger('test_eval')
        runner = EpochBasedRunner(model=model, work_dir=tmpdir, logger=logger)
        runner.register_checkpoint_hook(dict(interval=1))
        runner.register_hook(eval_hook)
        runner.run([loader], [('train', 1)], 8)

        ckpt_path = osp.join(tmpdir, 'best_score_epoch_4.pth')
        assert runner.meta['hook_msgs']['best_ckpt'] == ckpt_path
        assert osp.exists(ckpt_path)
        assert runner.meta['hook_msgs']['best_score'] == 7

    # total_epochs = 8, return the best score using the 'less' comparison
    # rule and the corresponding epoch
    data_loader = DataLoader(EvalDataset())
    eval_hook = EvalHook(data_loader, save_best='acc', rule='less')

    with tempfile.TemporaryDirectory() as tmpdir:
        logger = get_logger('test_eval')
        runner = EpochBasedRunner(model=model, work_dir=tmpdir, logger=logger)
        runner.register_checkpoint_hook(dict(interval=1))
        runner.register_hook(eval_hook)
        runner.run([loader], [('train', 1)], 8)

        ckpt_path = osp.join(tmpdir, 'best_acc_epoch_6.pth')
        assert runner.meta['hook_msgs']['best_ckpt'] == ckpt_path
        assert osp.exists(ckpt_path)
        assert runner.meta['hook_msgs']['best_score'] == -3

    # Test the EvalHook when training is resumed from a checkpoint
    data_loader = DataLoader(EvalDataset())
    eval_hook = EvalHook(data_loader, save_best='acc')

    with tempfile.TemporaryDirectory() as tmpdir:
        logger = get_logger('test_eval')
        runner = EpochBasedRunner(model=model, work_dir=tmpdir, logger=logger)
        runner.register_checkpoint_hook(dict(interval=1))
        runner.register_hook(eval_hook)
        runner.run([loader], [('train', 1)], 2)

        old_ckpt_path = osp.join(tmpdir, 'best_acc_epoch_2.pth')
        assert runner.meta['hook_msgs']['best_ckpt'] == old_ckpt_path
        assert osp.exists(old_ckpt_path)
        assert runner.meta['hook_msgs']['best_score'] == 4

        resume_from = old_ckpt_path
        loader = DataLoader(ExampleDataset())
        eval_hook = EvalHook(data_loader, save_best='acc')
        runner = EpochBasedRunner(model=model, work_dir=tmpdir, logger=logger)
        runner.register_checkpoint_hook(dict(interval=1))
        runner.register_hook(eval_hook)

        runner.resume(resume_from)
        # the best score/ckpt recorded before resuming is restored from the
        # checkpoint meta
        assert runner.meta['hook_msgs']['best_ckpt'] == old_ckpt_path
        assert osp.exists(old_ckpt_path)
        assert runner.meta['hook_msgs']['best_score'] == 4

        runner.run([loader], [('train', 1)], 8)

        ckpt_path = osp.join(tmpdir, 'best_acc_epoch_4.pth')
        assert runner.meta['hook_msgs']['best_ckpt'] == ckpt_path
        assert osp.exists(ckpt_path)
        assert runner.meta['hook_msgs']['best_score'] == 7
        # the superseded best checkpoint is removed once a better one is saved
        assert not osp.exists(old_ckpt_path)

    # test EvalHook with custom test_fn and greater/less keys
    loader = DataLoader(EvalDataset())
    model = Model()
    data_loader = DataLoader(EvalDataset())
    eval_hook = EvalHook(
        data_loader,
        save_best='acc',
        test_fn=mock.MagicMock(return_value={}),
        greater_keys=[],
        less_keys=['acc'])

    with tempfile.TemporaryDirectory() as tmpdir:
        logger = get_logger('test_eval')
        runner = EpochBasedRunner(model=model, work_dir=tmpdir, logger=logger)
        runner.register_checkpoint_hook(dict(interval=1))
        runner.register_hook(eval_hook)
        runner.run([loader], [('train', 1)], 8)

        ckpt_path = osp.join(tmpdir, 'best_acc_epoch_6.pth')
        assert runner.meta['hook_msgs']['best_ckpt'] == ckpt_path
        assert osp.exists(ckpt_path)
        assert runner.meta['hook_msgs']['best_score'] == -3

    # test EvalHook with specified `out_dir`
    loader = DataLoader(EvalDataset())
    model = Model()
    data_loader = DataLoader(EvalDataset())
    out_dir = 's3://user/data'
    eval_hook = EvalHook(
        data_loader, interval=1, save_best='auto', out_dir=out_dir)

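    # An 's3://' out_dir is handled by PetrelBackend (mocked at the top of
    # this file), so its put/remove/isfile methods are patched below. With
    # `save_best='auto'` the 'acc' metric improves at epochs 1, 2 and 4, so
    # the best checkpoint should be uploaded three times and the two
    # superseded ones checked for and removed, matching the asserted call
    # counts.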
    with patch.object(PetrelBackend, 'put') as mock_put, \
            patch.object(PetrelBackend, 'remove') as mock_remove, \
            patch.object(PetrelBackend, 'isfile') as mock_isfile, \
            tempfile.TemporaryDirectory() as tmpdir:
        logger = get_logger('test_eval')
        runner = EpochBasedRunner(model=model, work_dir=tmpdir, logger=logger)
        runner.register_checkpoint_hook(dict(interval=1))
        runner.register_hook(eval_hook)
        runner.run([loader], [('train', 1)], 8)

        basename = osp.basename(runner.work_dir.rstrip(osp.sep))
        ckpt_path = f'{out_dir}/{basename}/best_acc_epoch_4.pth'
        assert runner.meta['hook_msgs']['best_ckpt'] == ckpt_path
        assert runner.meta['hook_msgs']['best_score'] == 7

        assert mock_put.call_count == 3
        assert mock_remove.call_count == 2
        assert mock_isfile.call_count == 2


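# `test_start_param` checks how `start` and `interval` control when evaluation
# is triggered, for both hook classes and both runner types defined above
# (the parametrization below is inferred from those definitions).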
@pytest.mark.parametrize('EvalHookParam', [EvalHook, DistEvalHook])
@pytest.mark.parametrize('_build_demo_runner,by_epoch',
                         [(_build_epoch_runner, True),
                          (_build_iter_runner, False)])
def test_start_param(EvalHookParam, _build_demo_runner, by_epoch):
    # create dummy data
    dataloader = DataLoader(EvalDataset())

    # 0.1. dataloader is not a DataLoader object
    with pytest.raises(TypeError):
        EvalHookParam(dataloader=MagicMock(), interval=-1)

    # 0.2. negative interval
    with pytest.raises(ValueError):
        EvalHookParam(dataloader, interval=-1)

    # 0.3. negative start
    with pytest.raises(ValueError):
        EvalHookParam(dataloader, start=-1)

    # 1. start=None, interval=1: perform evaluation after each epoch.
    runner = _build_demo_runner()
    evalhook = EvalHookParam(dataloader, interval=1, by_epoch=by_epoch)
    evalhook.evaluate = MagicMock()
    runner.register_hook(evalhook)
    runner.run([dataloader], [('train', 1)], 2)
    assert evalhook.evaluate.call_count == 2  # after epoch 1 & 2

    # 2. start=1, interval=1: perform evaluation after each epoch.
    runner = _build_demo_runner()
    evalhook = EvalHookParam(
        dataloader, start=1, interval=1, by_epoch=by_epoch)
    evalhook.evaluate = MagicMock()
    runner.register_hook(evalhook)
    runner.run([dataloader], [('train', 1)], 2)
    assert evalhook.evaluate.call_count == 2  # after epoch 1 & 2

    # 3. start=None, interval=2: perform evaluation after epoch 2, 4, 6, etc
    runner = _build_demo_runner()
    evalhook = EvalHookParam(dataloader, interval=2, by_epoch=by_epoch)
    evalhook.evaluate = MagicMock()
    runner.register_hook(evalhook)
    runner.run([dataloader], [('train', 1)], 2)
    assert evalhook.evaluate.call_count == 1  # after epoch 2

    # 4. start=1, interval=2: perform evaluation after epoch 1, 3, 5, etc
    runner = _build_demo_runner()
    evalhook = EvalHookParam(
        dataloader, start=1, interval=2, by_epoch=by_epoch)
    evalhook.evaluate = MagicMock()
    runner.register_hook(evalhook)
    runner.run([dataloader], [('train', 1)], 3)
    assert evalhook.evaluate.call_count == 2  # after epoch 1 & 3

    # 5. start=0, interval=1: perform evaluation after each epoch and
    #    before epoch 1.
    runner = _build_demo_runner()
    evalhook = EvalHookParam(dataloader, start=0, by_epoch=by_epoch)
    evalhook.evaluate = MagicMock()
    runner.register_hook(evalhook)
    runner.run([dataloader], [('train', 1)], 2)
    assert evalhook.evaluate.call_count == 3  # before epoch1 and after e1 & e2

    # 6. resuming from epoch i, start = x (x<=i), interval = 1: perform
    #    evaluation after each epoch and before the first epoch.
    runner = _build_demo_runner()
    evalhook = EvalHookParam(dataloader, start=1, by_epoch=by_epoch)
    evalhook.evaluate = MagicMock()
    runner.register_hook(evalhook)
    if by_epoch:
        runner._epoch = 2
    else:
        runner._iter = 2
    runner.run([dataloader], [('train', 1)], 3)
    assert evalhook.evaluate.call_count == 2  # before & after epoch 3

    # 7. resuming from epoch i, start = i+1/None, interval = 1: perform
    #    evaluation after each epoch.
    runner = _build_demo_runner()
    evalhook = EvalHookParam(dataloader, start=2, by_epoch=by_epoch)
    evalhook.evaluate = MagicMock()
    runner.register_hook(evalhook)
    if by_epoch:
        runner._epoch = 1
    else:
        runner._iter = 1
    runner.run([dataloader], [('train', 1)], 3)
    assert evalhook.evaluate.call_count == 2  # after epoch 2 & 3


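# `test_logger` checks that, whatever priority the EvalHook has relative to
# the logger hooks, the JSON log written by TextLoggerHook contains a 'train'
# record (with 'time') followed by a 'val' record (without 'time').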
# NOTE: the concrete parametrize cases below are an assumed reconstruction;
# only the parameter names are fixed by the function signature.
@pytest.mark.parametrize('runner,by_epoch,eval_hook_priority',
                         [(EpochBasedRunner, True, 'NORMAL'),
                          (EpochBasedRunner, True, 'LOW'),
                          (IterBasedRunner, False, 'LOW')])
def test_logger(runner, by_epoch, eval_hook_priority):
    loader = DataLoader(EvalDataset())
    model = Model()
    data_loader = DataLoader(EvalDataset())
    eval_hook = EvalHook(
        data_loader, interval=1, by_epoch=by_epoch, save_best='acc')

    with tempfile.TemporaryDirectory() as tmpdir:
        logger = get_logger('test_logger')
        optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
        runner = EpochBasedRunner(
            model=model, optimizer=optimizer, work_dir=tmpdir, logger=logger)
        runner.register_logger_hooks(
            dict(
                interval=1,
                hooks=[dict(type='TextLoggerHook', by_epoch=by_epoch)]))
        runner.register_timer_hook(dict(type='IterTimerHook'))
        runner.register_hook(eval_hook, priority=eval_hook_priority)
        runner.run([loader], [('train', 1)], 1)

        path = osp.join(tmpdir, next(scandir(tmpdir, '.json')))
        with open(path) as fr:
            fr.readline()  # skip the first line which is `hook_msg`
            train_log = json.loads(fr.readline())
            assert train_log['mode'] == 'train' and 'time' in train_log
            val_log = json.loads(fr.readline())
            assert val_log['mode'] == 'val' and 'time' not in val_log