Spaces:
Running
on
Zero
Running
on
Zero
import os | |
import time | |
import torch | |
from collections import OrderedDict | |
from copy import deepcopy | |
from torch.nn.parallel import DataParallel, DistributedDataParallel | |
from basicsr.models import lr_scheduler as lr_scheduler | |
from basicsr.utils import get_root_logger | |
from basicsr.utils.dist_util import master_only | |
class BaseModel():
    """Base model.

    Stores the option dict and target device, and owns the lists of
    optimizers and learning-rate schedulers that concrete models fill in.
    Provides shared helpers for:

    - device placement;
    - optimizer and scheduler construction;
    - learning-rate warm-up;
    - EMA updates;
    - checkpoint saving and loading;
    - resuming training state;
    - loss reduction.

    ``feed_data``, ``optimize_parameters``, ``get_current_visuals`` and
    ``save`` are no-op stubs here, meant to be overridden. ``validation``
    dispatches to ``dist_validation`` / ``nondist_validation``, which this
    class does not define.

    NOTE(review): this module imports ``master_only`` but no method below
    applies it, so under distributed training the save/print helpers run on
    every rank. Confirm whether ``@master_only`` decorators were intended.
    """

    def __init__(self, opt):
        self.opt = opt
        # Any non-zero GPU count selects CUDA; 0 means CPU.
        self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu')
        self.is_train = opt['is_train']
        self.schedulers = []
        self.optimizers = []

    def feed_data(self, data):
        """Consume one batch of data. No-op here; override in subclasses."""
        pass

    def optimize_parameters(self):
        """Run one optimization step. No-op here; override in subclasses."""
        pass

    def get_current_visuals(self):
        """Return current visual results. No-op here; override in subclasses."""
        pass

    def save(self, epoch, current_iter):
        """Save networks and training state. No-op here; override in subclasses."""
        pass

    def validation(self, dataloader, current_iter, tb_logger, save_img=False):
        """Validation function.

        Dispatches to ``dist_validation`` when ``opt['dist']`` is truthy,
        otherwise to ``nondist_validation``.

        Args:
            dataloader (torch.utils.data.DataLoader): Validation dataloader.
            current_iter (int): Current iteration.
            tb_logger (tensorboard logger): Tensorboard logger.
            save_img (bool): Whether to save images. Default: False.
        """
        if self.opt['dist']:
            self.dist_validation(dataloader, current_iter, tb_logger, save_img)
        else:
            self.nondist_validation(dataloader, current_iter, tb_logger, save_img)

    def _initialize_best_metric_results(self, dataset_name):
        """Initialize the best-metric record for ``dataset_name``.

        Does nothing if a record for this dataset already exists. For every
        metric in ``opt['val']['metrics']`` a record
        ``{'better': 'higher'|'lower', 'val': ±inf, 'iter': -1}`` is created.
        ``better`` defaults to ``'higher'``.
        """
        if not hasattr(self, 'best_metric_results'):
            self.best_metric_results = dict()
        elif dataset_name in self.best_metric_results:
            return

        record = dict()
        for metric, content in self.opt['val']['metrics'].items():
            better = content.get('better', 'higher')
            # Start from the worst possible value so the first result always wins.
            init_val = float('-inf') if better == 'higher' else float('inf')
            record[metric] = dict(better=better, val=init_val, iter=-1)
        self.best_metric_results[dataset_name] = record

    def _update_best_metric_result(self, dataset_name, metric, val, current_iter):
        """Record ``val`` as the best value if it is at least as good as the current best.

        Ties also update the record, so the latest iteration wins.
        """
        record = self.best_metric_results[dataset_name][metric]
        if record['better'] == 'higher':
            improved = val >= record['val']
        else:
            improved = val <= record['val']
        if improved:
            record['val'] = val
            record['iter'] = current_iter

    def model_ema(self, decay=0.999):
        """Update ``self.net_g_ema`` in place as an exponential moving average of ``self.net_g``.

        For every parameter: ``ema = decay * ema + (1 - decay) * g``.
        ``self.net_g`` is unwrapped from (Distributed)DataParallel first;
        ``self.net_g_ema`` is used as-is. Both must be set by the subclass.

        Args:
            decay (float): EMA decay factor. Default: 0.999.
        """
        net_g = self.get_bare_model(self.net_g)
        net_g_params = dict(net_g.named_parameters())
        net_g_ema_params = dict(self.net_g_ema.named_parameters())
        for k in net_g_ema_params.keys():
            net_g_ema_params[k].data.mul_(decay).add_(net_g_params[k].data, alpha=1 - decay)

    def get_current_log(self):
        """Return ``self.log_dict`` (populated by subclasses)."""
        return self.log_dict

    def model_to_device(self, net):
        """Move a model to ``self.device`` and wrap it for multi-GPU use if configured.

        The model is wrapped in DistributedDataParallel when ``opt['dist']``
        is set. Otherwise it is wrapped in DataParallel when
        ``opt['num_gpu'] > 1``.

        Args:
            net (nn.Module): Model to move.

        Returns:
            nn.Module: The moved (and possibly wrapped) model.
        """
        net = net.to(self.device)
        if self.opt['dist']:
            find_unused_parameters = self.opt.get('find_unused_parameters', False)
            net = DistributedDataParallel(
                net, device_ids=[torch.cuda.current_device()], find_unused_parameters=find_unused_parameters)
        elif self.opt['num_gpu'] > 1:
            net = DataParallel(net)
        return net

    def get_optimizer(self, optim_type, params, lr, **kwargs):
        """Build a ``torch.optim`` optimizer by name.

        Args:
            optim_type (str): One of 'Adam', 'AdamW', 'Adamax', 'SGD', 'ASGD',
                'RMSprop', 'Rprop'.
            params (iterable): Parameters (or param groups) to optimize.
            lr (float): Learning rate.
            **kwargs: Extra keyword arguments for the optimizer constructor.

        Returns:
            torch.optim.Optimizer: The constructed optimizer.

        Raises:
            NotImplementedError: If ``optim_type`` is not supported.
        """
        supported = {
            'Adam': torch.optim.Adam,
            'AdamW': torch.optim.AdamW,
            'Adamax': torch.optim.Adamax,
            'SGD': torch.optim.SGD,
            'ASGD': torch.optim.ASGD,
            'RMSprop': torch.optim.RMSprop,
            'Rprop': torch.optim.Rprop,
        }
        if optim_type not in supported:
            raise NotImplementedError(f'optimizer {optim_type} is not supported yet.')
        return supported[optim_type](params, lr, **kwargs)

    def setup_schedulers(self):
        """Set up one learning-rate scheduler per optimizer.

        ``opt['train']['scheduler']['type']`` selects the scheduler; the
        remaining keys are passed to its constructor. The option dict is
        copied first, so ``self.opt`` is not modified and the method can be
        called more than once.

        Raises:
            NotImplementedError: If the scheduler type is not supported.
        """
        scheduler_opt = dict(self.opt['train']['scheduler'])
        scheduler_type = scheduler_opt.pop('type')
        if scheduler_type in ['MultiStepLR', 'MultiStepRestartLR']:
            scheduler_cls = lr_scheduler.MultiStepRestartLR
        elif scheduler_type == 'CosineAnnealingRestartLR':
            scheduler_cls = lr_scheduler.CosineAnnealingRestartLR
        else:
            raise NotImplementedError(f'Scheduler {scheduler_type} is not implemented yet.')
        for optimizer in self.optimizers:
            self.schedulers.append(scheduler_cls(optimizer, **scheduler_opt))

    def get_bare_model(self, net):
        """Return the underlying module if ``net`` is wrapped in DataParallel or DistributedDataParallel."""
        if isinstance(net, (DataParallel, DistributedDataParallel)):
            net = net.module
        return net

    def print_network(self, net):
        """Log the class name, parameter count and ``str()`` of a network.

        Args:
            net (nn.Module): Network, possibly wrapped in (Distributed)DataParallel.
        """
        if isinstance(net, (DataParallel, DistributedDataParallel)):
            net_cls_str = f'{net.__class__.__name__} - {net.module.__class__.__name__}'
        else:
            net_cls_str = f'{net.__class__.__name__}'

        net = self.get_bare_model(net)
        net_str = str(net)
        net_params = sum(p.numel() for p in net.parameters())

        logger = get_root_logger()
        logger.info(f'Network: {net_cls_str}, with parameters: {net_params:,d}')
        logger.info(net_str)

    def _set_lr(self, lr_groups_l):
        """Set learning rate for warm-up.

        Args:
            lr_groups_l (list): List of lr lists, one per optimizer, with one
                lr per param group.
        """
        for optimizer, lr_groups in zip(self.optimizers, lr_groups_l):
            for param_group, lr in zip(optimizer.param_groups, lr_groups):
                param_group['lr'] = lr

    def _get_init_lr(self):
        """Return each optimizer's per-group ``'initial_lr'`` (set when a scheduler is attached)."""
        return [[group['initial_lr'] for group in optimizer.param_groups] for optimizer in self.optimizers]

    def update_learning_rate(self, current_iter, warmup_iter=-1):
        """Update learning rate.

        Steps every scheduler, except at ``current_iter <= 1``. Then, while
        ``current_iter < warmup_iter``, overrides the lr with a linear ramp:
        ``initial_lr * current_iter / warmup_iter``.

        Args:
            current_iter (int): Current iteration.
            warmup_iter (int): Warm-up iter numbers. -1 for no warm-up.
                Default: -1.
        """
        if current_iter > 1:
            for scheduler in self.schedulers:
                scheduler.step()
        # set up warm-up learning rate (currently only linear warm-up is supported)
        if current_iter < warmup_iter:
            warm_up_lr_l = [[v / warmup_iter * current_iter for v in init_lr_g]
                            for init_lr_g in self._get_init_lr()]
            self._set_lr(warm_up_lr_l)

    def get_current_learning_rate(self):
        """Return the current lr of every param group of the first optimizer."""
        return [param_group['lr'] for param_group in self.optimizers[0].param_groups]

    def _save_with_retry(self, obj, save_path, error_msg, max_retry=3):
        """``torch.save`` an object, retrying to avoid occasional writing errors.

        Failures are logged as warnings and never re-raised, so saving stays
        best-effort.

        Args:
            obj: Object passed to ``torch.save``.
            save_path (str): Destination path.
            error_msg (str): Prefix of the warning logged for each failed attempt.
            max_retry (int): Maximum number of attempts. Default: 3.

        Returns:
            bool: True if the object was saved, False if every attempt failed.
        """
        for attempt in range(1, max_retry + 1):
            try:
                torch.save(obj, save_path)
            except Exception as e:  # deliberately broad: saving is best-effort
                remaining = max_retry - attempt
                get_root_logger().warning(f'{error_msg}: {e}, remaining retry times: {remaining}')
                if remaining > 0:
                    time.sleep(1)
            else:
                return True
        # Only reached when every attempt raised.
        get_root_logger().warning(f'Still cannot save {save_path}. Just ignore it.')
        # raise IOError(f'Cannot save {save_path}.')
        return False

    def save_network(self, net, net_label, current_iter, param_key='params'):
        """Save networks.

        Writes ``{param_key: state_dict, ...}`` to
        ``<opt['path']['models']>/<net_label>_<current_iter>.pth``. Tensors are
        moved to CPU and a leading ``'module.'`` is stripped from keys.

        Args:
            net (nn.Module | list[nn.Module]): Network(s) to be saved.
            net_label (str): Network label.
            current_iter (int): Current iter number. -1 saves as 'latest'.
            param_key (str | list[str]): The parameter key(s) to save network.
                Default: 'params'.
        """
        if current_iter == -1:
            current_iter = 'latest'
        save_filename = f'{net_label}_{current_iter}.pth'
        save_path = os.path.join(self.opt['path']['models'], save_filename)

        net = net if isinstance(net, list) else [net]
        param_key = param_key if isinstance(param_key, list) else [param_key]
        assert len(net) == len(param_key), 'The lengths of net and param_key should be the same.'

        save_dict = {}
        for net_, param_key_ in zip(net, param_key):
            net_ = self.get_bare_model(net_)
            state_dict = net_.state_dict()
            # Iterate over a snapshot: renaming keys adds/removes entries.
            for key, param in list(state_dict.items()):
                if key.startswith('module.'):  # remove unnecessary 'module.'
                    state_dict.pop(key)
                    key = key[7:]
                state_dict[key] = param.cpu()
            save_dict[param_key_] = state_dict

        self._save_with_retry(save_dict, save_path, 'Save model error')

    def _print_different_keys_loading(self, crt_net, load_net, strict=True):
        """Print keys with different name or different size when loading models.

        1. Print keys with different names.
        2. If strict=False, print the same key but with different tensor size.
            Such keys are renamed in ``load_net`` (in place) to ``<key>.ignore``
            so a non-strict ``load_state_dict`` skips them.

        Args:
            crt_net (torch model): Current network.
            load_net (dict): Loaded network state dict; may be modified.
            strict (bool): Whether strictly loaded. Default: True.
        """
        crt_net = self.get_bare_model(crt_net).state_dict()
        crt_net_keys = set(crt_net.keys())
        load_net_keys = set(load_net.keys())

        logger = get_root_logger()
        if crt_net_keys != load_net_keys:
            logger.warning('Current net - loaded net:')
            for v in sorted(crt_net_keys - load_net_keys):
                logger.warning(f'  {v}')
            logger.warning('Loaded net - current net:')
            for v in sorted(load_net_keys - crt_net_keys):
                logger.warning(f'  {v}')

        # check the size for the same keys
        if not strict:
            for k in sorted(crt_net_keys & load_net_keys):
                if crt_net[k].size() != load_net[k].size():
                    logger.warning(f'Size different, ignore [{k}]: crt_net: '
                                   f'{crt_net[k].shape}; load_net: {load_net[k].shape}')
                    load_net[k + '.ignore'] = load_net.pop(k)

    def load_network(self, net, load_path, strict=True, param_key='params'):
        """Load network.

        Args:
            net (nn.Module): Network (unwrapped from (Distributed)DataParallel
                before loading).
            load_path (str): The path of networks to be loaded.
            strict (bool): Whether strictly loaded.
            param_key (str): The key of the state dict inside the checkpoint.
                Falls back to 'params' if absent. If None, the whole loaded
                object is used as the state dict. Default: 'params'.
        """
        logger = get_root_logger()
        net = self.get_bare_model(net)
        # Map every storage to CPU regardless of the device it was saved from.
        load_net = torch.load(load_path, map_location=lambda storage, loc: storage)
        if param_key is not None:
            if param_key not in load_net and 'params' in load_net:
                logger.info(f'Loading: {param_key} does not exist, use params.')
                param_key = 'params'
            load_net = load_net[param_key]
        logger.info(f'Loading {net.__class__.__name__} model from {load_path}, with param key: [{param_key}].')
        # remove unnecessary 'module.' (snapshot the items; no need to deep-copy tensors)
        for k, v in list(load_net.items()):
            if k.startswith('module.'):
                load_net[k[7:]] = v
                load_net.pop(k)
        self._print_different_keys_loading(net, load_net, strict)
        net.load_state_dict(load_net, strict=strict)

    def save_training_state(self, epoch, current_iter):
        """Save training states during training, which will be used for resuming.

        Writes ``{'epoch', 'iter', 'optimizers', 'schedulers'}`` to
        ``<opt['path']['training_states']>/<current_iter>.state``. Nothing is
        saved when ``current_iter == -1``.

        Args:
            epoch (int): Current epoch.
            current_iter (int): Current iteration.
        """
        if current_iter == -1:
            return
        state = {
            'epoch': epoch,
            'iter': current_iter,
            'optimizers': [o.state_dict() for o in self.optimizers],
            'schedulers': [s.state_dict() for s in self.schedulers],
        }
        save_filename = f'{current_iter}.state'
        save_path = os.path.join(self.opt['path']['training_states'], save_filename)
        self._save_with_retry(state, save_path, 'Save training state error')

    def resume_training(self, resume_state):
        """Reload the optimizers and schedulers for resumed training.

        Args:
            resume_state (dict): Resume state, as written by ``save_training_state``.
        """
        resume_optimizers = resume_state['optimizers']
        resume_schedulers = resume_state['schedulers']
        assert len(resume_optimizers) == len(self.optimizers), 'Wrong lengths of optimizers'
        assert len(resume_schedulers) == len(self.schedulers), 'Wrong lengths of schedulers'
        for optimizer, state in zip(self.optimizers, resume_optimizers):
            optimizer.load_state_dict(state)
        for scheduler, state in zip(self.schedulers, resume_schedulers):
            scheduler.load_state_dict(state)

    def reduce_loss_dict(self, loss_dict):
        """Reduce loss dict.

        In distributed training, losses are summed onto rank 0 and divided by
        ``opt['world_size']`` there, so only rank 0 gets the averaged values.
        Each loss is then reduced to a Python float via ``.mean().item()``.

        Args:
            loss_dict (OrderedDict): Loss dict of tensors.

        Returns:
            OrderedDict: Loss name -> float.
        """
        with torch.no_grad():
            if self.opt['dist']:
                keys = list(loss_dict.keys())
                losses = torch.stack([loss_dict[k] for k in keys], 0)
                torch.distributed.reduce(losses, dst=0)
                if self.opt['rank'] == 0:
                    losses /= self.opt['world_size']
                loss_dict = {key: loss for key, loss in zip(keys, losses)}

            log_dict = OrderedDict()
            for name, value in loss_dict.items():
                log_dict[name] = value.mean().item()
            return log_dict