Spaces:
Sleeping
Sleeping
File size: 5,311 Bytes
d7e58f0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 |
# Copyright (c) OpenMMLab. All rights reserved.
import tempfile
import warnings
from mmcv.runner import DistEvalHook as BaseDistEvalHook
from mmcv.runner import EvalHook as BaseEvalHook
# Metric names used as the default ``greater_keys`` of the evaluation hooks
# defined in this module.
# NOTE(review): presumably these are the metrics for which a larger value is
# better (that is how mmcv interprets ``greater_keys``) - confirm against the
# mmcv EvalHook documentation.
MMHUMAN3D_GREATER_KEYS = [
    '3dpck',
    'pa-3dpck',
    '3dauc',
    'pa-3dauc',
]

# Metric names used as the default ``less_keys`` of the evaluation hooks
# defined in this module (presumably: a smaller value is better).
MMHUMAN3D_LESS_KEYS = [
    'mpjpe',
    'pa-mpjpe',
    'pve',
]
class EvalHook(BaseEvalHook):
    """Evaluation hook that extends ``mmcv.runner.EvalHook``.

    Compared with the base hook, this class:

    * falls back to ``detrsmpl.apis.single_gpu_test`` when no ``test_fn``
      is given;
    * uses the module-level metric lists as default ``greater_keys`` /
      ``less_keys``;
    * accepts the deprecated ``gpu_collect`` and ``key_indicator`` keyword
      arguments and a boolean ``save_best`` (each with a
      ``DeprecationWarning``);
    * passes a temporary ``res_folder`` to ``dataset.evaluate``.

    Args:
        dataloader: Forwarded unchanged to the base hook; ``evaluate`` reads
            ``self.dataloader.dataset``.
        start: Forwarded unchanged to the base hook.
        interval: Forwarded unchanged to the base hook.
        by_epoch: Forwarded unchanged to the base hook.
        save_best (str | bool | None): Metric key forwarded to the base hook.
            A boolean is deprecated; when a boolean is given (or
            ``key_indicator`` is present in ``eval_kwargs``) the value is
            replaced by ``key_indicator``.
        rule: Forwarded unchanged to the base hook.
        test_fn (callable | None): Test function forwarded to the base hook.
            Defaults to ``detrsmpl.apis.single_gpu_test``.
        greater_keys (list[str]): Forwarded unchanged to the base hook.
        less_keys (list[str]): Forwarded unchanged to the base hook.
        **eval_kwargs: Remaining keyword arguments, forwarded to the base
            hook after ``gpu_collect`` and ``key_indicator`` are removed.

    Raises:
        ValueError: If ``save_best`` is ``True`` and no non-``None``
            ``key_indicator`` is supplied.
    """

    def __init__(self,
                 dataloader,
                 start=None,
                 interval=1,
                 by_epoch=True,
                 save_best=None,
                 rule=None,
                 test_fn=None,
                 greater_keys=MMHUMAN3D_GREATER_KEYS,
                 less_keys=MMHUMAN3D_LESS_KEYS,
                 **eval_kwargs):
        if test_fn is None:
            # Imported only when needed, inside the function.
            # NOTE(review): presumably this avoids a circular import at
            # module load time - confirm.
            from detrsmpl.apis import single_gpu_test
            test_fn = single_gpu_test

        # "gpu_collect" is dropped here so it is never forwarded to the base
        # hook through eval_kwargs.
        if 'gpu_collect' in eval_kwargs:
            warnings.warn(
                '"gpu_collect" will be deprecated in EvalHook. '
                'Please remove it from the config.', DeprecationWarning)
            eval_kwargs.pop('gpu_collect')

        # Legacy configuration style: derive "save_best" from "key_indicator"
        # and remove the latter from eval_kwargs.
        if 'key_indicator' in eval_kwargs or isinstance(save_best, bool):
            warnings.warn(
                '"key_indicator" will be deprecated in EvalHook. '
                'Please use "save_best" to specify the metric key, '
                'e.g., save_best="pa-mpjpe".', DeprecationWarning)

            key_indicator = eval_kwargs.pop('key_indicator', None)
            if save_best is True and key_indicator is None:
                raise ValueError('key_indicator should not be None, when '
                                 'save_best is set to True.')
            # Whenever this branch runs, key_indicator replaces save_best
            # (so save_best=False without a key_indicator becomes None).
            save_best = key_indicator

        super().__init__(dataloader, start, interval, by_epoch, save_best,
                         rule, test_fn, greater_keys, less_keys, **eval_kwargs)

    def evaluate(self, runner, results):
        """Evaluate the results and publish the metrics to the runner.

        Args:
            runner: The training runner; its ``logger`` is passed to the
                dataset and its ``log_buffer`` receives the metrics.
            results: Output results handed to ``dataset.evaluate``.

        Returns:
            The value of the metric named by ``self.key_indicator`` when
            ``self.save_best`` is set and the dataset returned at least one
            metric; otherwise ``None``.
        """
        # The temporary folder is removed as soon as the dataset has finished
        # evaluating.
        with tempfile.TemporaryDirectory() as tmp_dir:
            eval_res = self.dataloader.dataset.evaluate(
                results,
                res_folder=tmp_dir,
                logger=runner.logger,
                **self.eval_kwargs)

        for name, val in eval_res.items():
            runner.log_buffer.output[name] = val
        runner.log_buffer.ready = True

        if self.save_best is None:
            return None

        # An empty result dict would make the lookups below raise
        # IndexError/KeyError; skip best-checkpoint selection instead, as the
        # upstream mmcv hook does.
        if not eval_res:
            warnings.warn(
                'Since `eval_res` is an empty dict, the behavior to save '
                'the best checkpoint will be skipped in this evaluation.')
            return None

        if self.key_indicator == 'auto':
            # Use the first reported metric as the key. This relies on
            # `_init_rule` (from the base hook) updating
            # `self.key_indicator` for the lookup below.
            self._init_rule(self.rule, next(iter(eval_res)))

        return eval_res[self.key_indicator]
class DistEvalHook(BaseDistEvalHook):
    """Distributed evaluation hook that extends ``mmcv.runner.DistEvalHook``.

    Compared with the base hook, this class:

    * falls back to ``detrsmpl.apis.multi_gpu_test`` when no ``test_fn`` is
      given;
    * uses the module-level metric lists as default ``greater_keys`` /
      ``less_keys``;
    * accepts the deprecated ``key_indicator`` keyword argument and a
      boolean ``save_best`` (with a ``DeprecationWarning``);
    * passes a temporary ``res_folder`` to ``dataset.evaluate``.

    Args:
        dataloader: Forwarded unchanged to the base hook; ``evaluate`` reads
            ``self.dataloader.dataset``.
        start: Forwarded unchanged to the base hook.
        interval: Forwarded unchanged to the base hook.
        by_epoch: Forwarded unchanged to the base hook.
        save_best (str | bool | None): Metric key forwarded to the base hook.
            A boolean is deprecated; when a boolean is given (or
            ``key_indicator`` is present in ``eval_kwargs``) the value is
            replaced by ``key_indicator``.
        rule: Forwarded unchanged to the base hook.
        test_fn (callable | None): Test function forwarded to the base hook.
            Defaults to ``detrsmpl.apis.multi_gpu_test``.
        greater_keys (list[str]): Forwarded unchanged to the base hook.
        less_keys (list[str]): Forwarded unchanged to the base hook.
        broadcast_bn_buffer: Forwarded unchanged to the base hook.
        tmpdir: Forwarded unchanged to the base hook.
        gpu_collect: Forwarded unchanged to the base hook.
        **eval_kwargs: Remaining keyword arguments, forwarded to the base
            hook after ``key_indicator`` is removed.

    Raises:
        ValueError: If ``save_best`` is ``True`` and no non-``None``
            ``key_indicator`` is supplied.
    """

    def __init__(self,
                 dataloader,
                 start=None,
                 interval=1,
                 by_epoch=True,
                 save_best=None,
                 rule=None,
                 test_fn=None,
                 greater_keys=MMHUMAN3D_GREATER_KEYS,
                 less_keys=MMHUMAN3D_LESS_KEYS,
                 broadcast_bn_buffer=True,
                 tmpdir=None,
                 gpu_collect=False,
                 **eval_kwargs):
        if test_fn is None:
            # Imported only when needed, inside the function.
            # NOTE(review): presumably this avoids a circular import at
            # module load time - confirm.
            from detrsmpl.apis import multi_gpu_test
            test_fn = multi_gpu_test

        # Legacy configuration style: derive "save_best" from "key_indicator"
        # and remove the latter from eval_kwargs.
        if 'key_indicator' in eval_kwargs or isinstance(save_best, bool):
            warnings.warn(
                '"key_indicator" will be deprecated in EvalHook. '
                'Please use "save_best" to specify the metric key, '
                'e.g., save_best="pa-mpjpe".', DeprecationWarning)

            key_indicator = eval_kwargs.pop('key_indicator', None)
            if save_best is True and key_indicator is None:
                raise ValueError('key_indicator should not be None, when '
                                 'save_best is set to True.')
            # Whenever this branch runs, key_indicator replaces save_best
            # (so save_best=False without a key_indicator becomes None).
            save_best = key_indicator

        super().__init__(dataloader, start, interval, by_epoch, save_best,
                         rule, test_fn, greater_keys, less_keys,
                         broadcast_bn_buffer, tmpdir, gpu_collect,
                         **eval_kwargs)

    def evaluate(self, runner, results):
        """Evaluate the results and publish the metrics to the runner.

        Args:
            runner (:obj:`mmcv.Runner`): The underlying training runner; its
                ``logger`` is passed to the dataset and its ``log_buffer``
                receives the metrics.
            results (list): Output results.

        Returns:
            The value of the metric named by ``self.key_indicator`` when
            ``self.save_best`` is set and the dataset returned at least one
            metric; otherwise ``None``.
        """
        # The temporary folder is removed as soon as the dataset has finished
        # evaluating.
        with tempfile.TemporaryDirectory() as tmp_dir:
            eval_res = self.dataloader.dataset.evaluate(
                results,
                res_folder=tmp_dir,
                logger=runner.logger,
                **self.eval_kwargs)

        for name, val in eval_res.items():
            runner.log_buffer.output[name] = val
        runner.log_buffer.ready = True

        if self.save_best is None:
            return None

        # An empty result dict would make the lookups below raise
        # IndexError/KeyError; skip best-checkpoint selection instead, as the
        # upstream mmcv hook does.
        if not eval_res:
            warnings.warn(
                'Since `eval_res` is an empty dict, the behavior to save '
                'the best checkpoint will be skipped in this evaluation.')
            return None

        if self.key_indicator == 'auto':
            # Infer the key from the evaluation results: use the first
            # reported metric. This relies on `_init_rule` (from the base
            # hook) updating `self.key_indicator` for the lookup below.
            self._init_rule(self.rule, next(iter(eval_res)))

        return eval_res[self.key_indicator]
|