|
|
|
import copy |
|
import warnings |
|
from abc import ABCMeta |
|
from collections import defaultdict |
|
from logging import FileHandler |
|
|
|
import torch.nn as nn |
|
|
|
from annotator.mmpkg.mmcv.runner.dist_utils import master_only |
|
from annotator.mmpkg.mmcv.utils.logging import get_logger, logger_initialized, print_log |
|
|
|
|
|
class BaseModule(nn.Module, metaclass=ABCMeta):
    """Base module for all modules in openmmlab.

    ``BaseModule`` is a wrapper of ``torch.nn.Module`` with additional
    functionality of parameter initialization. Compared with
    ``torch.nn.Module``, ``BaseModule`` mainly adds three attributes.

    - ``init_cfg``: the config to control the initialization.
    - ``init_weights``: The function of parameter
      initialization and recording initialization
      information.
    - ``_params_init_info``: Used to track the parameter
      initialization information. This attribute only
      exists during executing the ``init_weights``.

    Args:
        init_cfg (dict, optional): Initialization config dict.
    """

    def __init__(self, init_cfg=None):
        """Initialize BaseModule, inherited from `torch.nn.Module`"""
        super(BaseModule, self).__init__()

        # Flipped to True once `init_weights` has finished on this module.
        self._is_init = False

        # Deep-copied so later mutation of the caller's config object does
        # not affect this module.
        self.init_cfg = copy.deepcopy(init_cfg)

    @property
    def is_init(self):
        """bool: Whether ``init_weights`` has completed on this module."""
        return self._is_init

    def init_weights(self):
        """Initialize the weights.

        If ``self.init_cfg`` is set, it is applied through ``initialize``.
        Afterwards, ``init_weights`` is called on every direct child that
        defines it. Children are skipped when ``init_cfg`` is a dict whose
        ``type`` is ``'Pretrained'``, so they cannot overwrite what that
        initializer set. The module is marked initialized in both cases.

        The outermost call is the one that creates ``_params_init_info``.
        It records how each parameter was initialized and dumps that record
        through ``_dump_init_info``. It then removes the shared
        ``_params_init_info`` attribute from every submodule. The removal
        also happens when initialization raises, so a retry starts clean.

        Calling this again on an already-initialized module only emits a
        warning.
        """
        # Only the outermost call finds no `_params_init_info`. It owns the
        # shared record and is responsible for dumping and deleting it.
        is_top_level_module = not hasattr(self, '_params_init_info')
        if is_top_level_module:
            # Keys are the parameter objects; values are dicts holding
            # 'init_info' (a description) and 'tmp_mean_value'.
            self._params_init_info = defaultdict(dict)

        try:
            if is_top_level_module:
                # Default record for every parameter. The mean taken here is
                # presumably what `update_init_info` compares against to
                # detect modified parameters (helper not visible here).
                for param in self.parameters():
                    self._params_init_info[param][
                        'init_info'] = f'The value is the same before and ' \
                                       f'after calling `init_weights` ' \
                                       f'of {self.__class__.__name__} '
                    self._params_init_info[param][
                        'tmp_mean_value'] = param.data.mean()

                # Share the same dict object with every submodule, so updates
                # made at any depth are visible from the top-level module.
                for sub_module in self.modules():
                    sub_module._params_init_info = self._params_init_info

            # Log through the first logger registered in `logger_initialized`,
            # falling back to the name 'mmcv'.
            logger_names = list(logger_initialized.keys())
            logger_name = logger_names[0] if logger_names else 'mmcv'

            from ..cnn import initialize
            from ..cnn.utils.weight_init import update_init_info
            module_name = self.__class__.__name__
            if not self._is_init:
                init_children = True
                if self.init_cfg:
                    print_log(
                        f'initialize {module_name} '
                        f'with init_cfg {self.init_cfg}',
                        logger=logger_name)
                    initialize(self, self.init_cfg)
                    # Keep the children's own `init_weights` from overwriting
                    # weights set by a 'Pretrained' init_cfg. Only the
                    # recursion is skipped: the module is still marked as
                    # initialized and the top-level bookkeeping still runs.
                    if isinstance(self.init_cfg, dict):
                        init_children = \
                            self.init_cfg['type'] != 'Pretrained'

                if init_children:
                    for m in self.children():
                        if hasattr(m, 'init_weights'):
                            m.init_weights()
                            # `init_weights` may be user-overridden, so
                            # record that it was responsible.
                            update_init_info(
                                m,
                                init_info=f'Initialized by '
                                f'user-defined `init_weights`'
                                f' in {m.__class__.__name__} ')

                self._is_init = True
            else:
                warnings.warn(f'init_weights of {self.__class__.__name__} has '
                              f'been called more than once.')

            if is_top_level_module:
                self._dump_init_info(logger_name)
        finally:
            # Always detach the shared record, even on failure. Otherwise a
            # later call would be treated as non-top-level and the record
            # (holding parameter references) would stay alive.
            if is_top_level_module:
                for sub_module in self.modules():
                    if hasattr(sub_module, '_params_init_info'):
                        del sub_module._params_init_info

    @master_only
    def _dump_init_info(self, logger_name):
        """Report the recorded initialization info of every parameter.

        The header and one entry per parameter are written directly to the
        stream of each ``FileHandler`` attached to the logger ``logger_name``.
        If the logger has no ``FileHandler``, each entry is emitted through
        ``print_log`` instead. The ``master_only`` decorator presumably
        restricts this to the master process in distributed runs.

        Args:
            logger_name (str): The name of logger.
        """
        logger = get_logger(logger_name)

        with_file_handler = False
        # NOTE(review): this writes to `handler.stream` directly, bypassing
        # the handler's formatter. A FileHandler created with `delay=True`
        # may still have `stream is None` here; confirm no such handler is
        # ever attached.
        for handler in logger.handlers:
            if isinstance(handler, FileHandler):
                handler.stream.write(
                    'Name of parameter - Initialization information\n')
                for name, param in self.named_parameters():
                    handler.stream.write(
                        f'\n{name} - {param.shape}: '
                        f"\n{self._params_init_info[param]['init_info']} \n")
                handler.stream.flush()
                with_file_handler = True
        if not with_file_handler:
            for name, param in self.named_parameters():
                print_log(
                    f'\n{name} - {param.shape}: '
                    f"\n{self._params_init_info[param]['init_info']} \n ",
                    logger=logger_name)

    def __repr__(self):
        """Append ``init_cfg`` (when set) to the standard module repr."""
        s = super().__repr__()
        if self.init_cfg:
            s += f'\ninit_cfg={self.init_cfg}'
        return s
|
|
|
|
|
class Sequential(BaseModule, nn.Sequential):
    """An ``nn.Sequential`` container that also carries an openmmlab
    ``init_cfg``.

    All positional arguments are forwarded unchanged to ``nn.Sequential``.

    Args:
        init_cfg (dict, optional): Initialization config dict, stored by
            ``BaseModule``. Must be given as a keyword argument.
    """

    def __init__(self, *layers, init_cfg=None):
        # Order matters: BaseModule.__init__ is invoked first (it calls
        # super() down the MRO with no layers). The explicit
        # nn.Sequential.__init__ call afterwards registers the real layers.
        BaseModule.__init__(self, init_cfg=init_cfg)
        nn.Sequential.__init__(self, *layers)
|
|
|
|
|
class ModuleList(BaseModule, nn.ModuleList):
    """An ``nn.ModuleList`` that also carries an openmmlab ``init_cfg``.

    Args:
        modules (iterable, optional): Modules to add, forwarded to
            ``nn.ModuleList``.
        init_cfg (dict, optional): Initialization config dict, stored by
            ``BaseModule``.
    """

    def __init__(self, modules=None, init_cfg=None):
        # BaseModule.__init__ has to come first: it calls super() down the
        # MRO with no modules. The explicit nn.ModuleList.__init__ call
        # afterwards is the one that adds `modules`.
        BaseModule.__init__(self, init_cfg=init_cfg)
        nn.ModuleList.__init__(self, modules)
|
|