File size: 1,924 Bytes
18dd6ad |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 |
from .builder import RUNNER_BUILDERS, RUNNERS
@RUNNER_BUILDERS.register_module()
class DefaultRunnerConstructor:
"""Default constructor for runners.
Custom existing `Runner` like `EpocBasedRunner` though `RunnerConstructor`.
For example, We can inject some new properties and functions for `Runner`.
Example:
>>> from annotator.mmpkg.mmcv.runner import RUNNER_BUILDERS, build_runner
>>> # Define a new RunnerReconstructor
>>> @RUNNER_BUILDERS.register_module()
>>> class MyRunnerConstructor:
... def __init__(self, runner_cfg, default_args=None):
... if not isinstance(runner_cfg, dict):
... raise TypeError('runner_cfg should be a dict',
... f'but got {type(runner_cfg)}')
... self.runner_cfg = runner_cfg
... self.default_args = default_args
...
... def __call__(self):
... runner = RUNNERS.build(self.runner_cfg,
... default_args=self.default_args)
... # Add new properties for existing runner
... runner.my_name = 'my_runner'
... runner.my_function = lambda self: print(self.my_name)
... ...
>>> # build your runner
>>> runner_cfg = dict(type='EpochBasedRunner', max_epochs=40,
... constructor='MyRunnerConstructor')
>>> runner = build_runner(runner_cfg)
"""
def __init__(self, runner_cfg, default_args=None):
if not isinstance(runner_cfg, dict):
raise TypeError('runner_cfg should be a dict',
f'but got {type(runner_cfg)}')
self.runner_cfg = runner_cfg
self.default_args = default_args
def __call__(self):
return RUNNERS.build(self.runner_cfg, default_args=self.default_args)
|