|
|
|
|
|
|
|
|
|
import os, sys, math, time, random, datetime, functools |
|
import lpips |
|
import numpy as np |
|
from pathlib import Path |
|
from loguru import logger |
|
from copy import deepcopy |
|
from omegaconf import OmegaConf |
|
from collections import OrderedDict |
|
from einops import rearrange |
|
from contextlib import nullcontext |
|
|
|
from datapipe.datasets import create_dataset |
|
|
|
from utils import util_net |
|
from utils import util_common |
|
from utils import util_image |
|
|
|
from basicsr.utils import DiffJPEG, USMSharp |
|
from basicsr.utils.img_process_util import filter2D |
|
from basicsr.data.transforms import paired_random_crop |
|
from basicsr.data.degradations import random_add_gaussian_noise_pt, random_add_poisson_noise_pt |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.cuda.amp as amp |
|
import torch.nn.functional as F |
|
import torch.utils.data as udata |
|
import torch.distributed as dist |
|
import torch.multiprocessing as mp |
|
import torchvision.utils as vutils |
|
|
|
from torch.nn.parallel import DistributedDataParallel as DDP |
|
|
|
|
|
class TrainerBase: |
|
def __init__(self, configs):
    """Keep the configuration and run the distributed and seeding setup.

    Args:
        configs: configuration object, stored on ``self.configs``.
    """
    self.configs = configs

    # The distributed setup runs first: it defines ``self.rank``, which
    # ``setup_seed`` reads when per-rank seeding is requested.
    self.setup_dist()

    self.setup_seed()
|
|
|
def setup_dist(self):
    """Initialise distributed training and record the GPU count and rank.

    When more than one CUDA device is visible, the process selects the device
    given by the ``LOCAL_RANK`` environment variable and joins an NCCL process
    group (``env://`` initialisation, 3600 s timeout). Otherwise no process
    group is created and the rank is 0.

    Sets ``self.num_gpus`` and ``self.rank``.
    """
    gpu_count = torch.cuda.device_count()
    multi_gpu = gpu_count > 1

    if multi_gpu:
        if mp.get_start_method(allow_none=True) is None:
            mp.set_start_method('spawn')
        local_rank = int(os.environ['LOCAL_RANK'])
        torch.cuda.set_device(local_rank % gpu_count)
        dist.init_process_group(
            timeout=datetime.timedelta(seconds=3600),
            backend='nccl',
            init_method='env://',
        )

    self.num_gpus = gpu_count
    if multi_gpu:
        self.rank = int(os.environ['LOCAL_RANK'])
    else:
        self.rank = 0
|
|
|
def setup_seed(self, seed=None, global_seeding=None):
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed: base seed; when ``None`` it is read from ``configs.train``
            (key ``seed``, default 12345).
        global_seeding: when true the same seed is used on every rank and all
            CUDA devices are seeded; when false ``self.rank`` is added to the
            seed and only the current CUDA device is seeded. When ``None`` it
            is read from ``configs.train.global_seeding`` and must be a bool.
    """
    if seed is None:
        seed = self.configs.train.get('seed', 12345)
    if global_seeding is None:
        global_seeding = self.configs.train.global_seeding
        assert isinstance(global_seeding, bool)

    if global_seeding:
        torch.cuda.manual_seed_all(seed)
    else:
        # Offset by the rank so every process draws a different stream.
        seed = seed + self.rank
        torch.cuda.manual_seed(seed)

    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
|
|
|
def init_logger(self):
    """Create the experiment directories and configure all logging sinks.

    The experiment directory is the grandparent of the ``.pth`` file named by
    ``configs.resume`` when resuming, otherwise
    ``configs.save_dir/<YYYY-mm-dd-HH-MM>``. Directories, the loguru sinks and
    the TensorBoard writer are only created on rank 0; the path and flag
    attributes (``ckpt_dir``, ``ema_ckpt_dir``, ``ema_rate``, ``tf_logging``,
    ``local_logging``) are set on every rank.
    """
    is_main = self.rank == 0

    # Resolve the experiment directory.
    if self.configs.resume:
        assert self.configs.resume.endswith(".pth")
        save_dir = Path(self.configs.resume).parents[1]
    else:
        project_id = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M")
        save_dir = Path(self.configs.save_dir) / project_id
    if is_main and not save_dir.exists():
        save_dir.mkdir(parents=True)

    # Step counters used as the x-axis for scalar and image logging.
    if is_main:
        self.log_step = {phase: 1 for phase in ['train', 'val']}
        self.log_step_img = {phase: 1 for phase in ['train', 'val']}

    # Text logging: append to training.log and mirror to stdout.
    log_path = save_dir / 'training.log'
    if is_main:
        if log_path.exists():
            # An existing log file is only expected when resuming a run.
            assert self.configs.resume
        self.logger = logger
        self.logger.remove()
        self.logger.add(log_path, format="{message}", mode='a', level='INFO')
        self.logger.add(sys.stdout, format="{message}")

    # TensorBoard logging.
    log_dir = save_dir / 'tf_logs'
    self.tf_logging = self.configs.train.tf_logging
    if is_main and self.tf_logging:
        # The file-level imports do not bring in SummaryWriter, so import it
        # here; this also keeps tensorboard optional when tf_logging is off.
        from torch.utils.tensorboard import SummaryWriter
        if not log_dir.exists():
            log_dir.mkdir()
        self.writer = SummaryWriter(str(log_dir))

    # Checkpoint directories (regular and EMA weights).
    ckpt_dir = save_dir / 'ckpts'
    self.ckpt_dir = ckpt_dir
    if is_main and (not ckpt_dir.exists()):
        ckpt_dir.mkdir()
    if 'ema_rate' in self.configs.train:
        self.ema_rate = self.configs.train.ema_rate
        assert isinstance(self.ema_rate, float), "Ema rate must be a float number"
        ema_ckpt_dir = save_dir / 'ema_ckpts'
        self.ema_ckpt_dir = ema_ckpt_dir
        if is_main and (not ema_ckpt_dir.exists()):
            ema_ckpt_dir.mkdir()

    # Local image dumps. Each phase directory is created independently so a
    # pre-existing but incomplete ``images`` directory is repaired.
    self.local_logging = self.configs.train.local_logging
    if is_main and self.local_logging:
        image_dir = save_dir / 'images'
        for phase in ('train', 'val'):
            (image_dir / phase).mkdir(parents=True, exist_ok=True)
        self.image_dir = image_dir

    # Record the full configuration at the top of the log.
    if is_main:
        self.logger.info(OmegaConf.to_yaml(self.configs))
|
|
|
def close_logger(self):
    """Close the TensorBoard writer on rank 0 when TensorBoard logging is on."""
    if self.rank != 0 or not self.tf_logging:
        return
    self.writer.close()
|
|
|
def resume_from_ckpt(self):
    """Restore the training state from ``configs.resume``, if it is set.

    Restores the model weights, replays the learning-rate schedule up to the
    saved iteration, and restores the logging counters and EMA weights
    (rank 0 only) and the AMP scaler state when present. Finally the RNGs are
    reseeded with the resumed iteration number. Without a resume path,
    ``self.iters_start`` is set to 0.
    """
    def _load_ema_state(ema_state, ckpt):
        # Copy checkpoint tensors into ``ema_state``, tolerating a ``module.``
        # prefix mismatch between the two key sets in either direction.
        for key in ema_state.keys():
            if key not in ckpt and key.startswith('module'):
                # Strip the 7-character ``module.`` prefix from the key.
                # (The checkpoint dict itself must not be sliced.)
                ema_state[key] = deepcopy(ckpt[key[7:]].detach().data)
            elif key not in ckpt and (not key.startswith('module')):
                ema_state[key] = deepcopy(ckpt['module.'+key].detach().data)
            else:
                ema_state[key] = deepcopy(ckpt[key].detach().data)

    if not self.configs.resume:
        self.iters_start = 0
        return

    assert self.configs.resume.endswith(".pth") and os.path.isfile(self.configs.resume)

    if self.rank == 0:
        self.logger.info(f"=> Loaded checkpoint from {self.configs.resume}")
    ckpt = torch.load(self.configs.resume, map_location=f"cuda:{self.rank}")
    util_net.reload_model(self.model, ckpt['state_dict'])
    torch.cuda.empty_cache()

    # Replay the learning-rate schedule up to the resumed iteration.
    self.iters_start = ckpt['iters_start']
    for ii in range(1, self.iters_start+1):
        self.adjust_lr(ii)

    # Logging counters live on rank 0 only.
    if self.rank == 0:
        self.log_step = ckpt['log_step']
        self.log_step_img = ckpt['log_step_img']

    # EMA weights are stored next to the checkpoint as ``ema_<name>``.
    if self.rank == 0 and hasattr(self, 'ema_rate'):
        ema_ckpt_path = self.ema_ckpt_dir / ("ema_"+Path(self.configs.resume).name)
        self.logger.info(f"=> Loaded EMA checkpoint from {str(ema_ckpt_path)}")
        ema_ckpt = torch.load(ema_ckpt_path, map_location=f"cuda:{self.rank}")
        _load_ema_state(self.ema_state, ema_ckpt)
    torch.cuda.empty_cache()

    # AMP scaler state.
    if self.amp_scaler is not None:
        if "amp_scaler" in ckpt:
            self.amp_scaler.load_state_dict(ckpt["amp_scaler"])
            if self.rank == 0:
                self.logger.info("Loading scaler from resumed state...")

    # Reseed so the resumed run does not replay the random stream from step 0.
    self.setup_seed(seed=self.iters_start)
|
|
|
def setup_optimizaton(self):
    """Create the AdamW optimizer and, when AMP is enabled, the gradient scaler.

    Sets ``self.optimizer`` and ``self.amp_scaler`` (``None`` without AMP).
    """
    train_cfg = self.configs.train
    self.optimizer = torch.optim.AdamW(
        self.model.parameters(),
        lr=train_cfg.lr,
        weight_decay=train_cfg.weight_decay,
    )

    if self.configs.train.use_amp:
        self.amp_scaler = amp.GradScaler()
    else:
        self.amp_scaler = None
|
|
|
def build_model(self):
    """Instantiate the network, optionally load weights, compile and wrap it.

    The model class is resolved from ``configs.model.target`` and built with
    ``configs.model.params``. With more than one GPU the model is wrapped in
    ``DistributedDataParallel``. On rank 0, when ``configs.train`` has an
    ``ema_rate``, an EMA copy of the model and its state dict are created.

    Sets ``self.model`` and, on rank 0 with EMA, ``self.ema_model``,
    ``self.ema_state`` and ``self.ema_ignore_keys``.
    """
    # The fallback must be an empty mapping: the ``dict`` *type* cannot be
    # unpacked with ``**`` and would raise TypeError.
    params = self.configs.model.get('params', dict())
    model = util_common.get_obj_from_str(self.configs.model.target)(**params)
    model.cuda()

    if self.configs.model.ckpt_path is not None:
        ckpt_path = self.configs.model.ckpt_path
        if self.rank == 0:
            self.logger.info(f"Initializing model from {ckpt_path}")
        ckpt = torch.load(ckpt_path, map_location=f"cuda:{self.rank}")
        # Checkpoints may be a bare state dict or wrap it under 'state_dict'.
        if 'state_dict' in ckpt:
            ckpt = ckpt['state_dict']
        util_net.reload_model(model, ckpt)

    if self.configs.train.compile.flag:
        if self.rank == 0:
            self.logger.info("Begin compiling model...")
        model = torch.compile(model, mode=self.configs.train.compile.mode)
        if self.rank == 0:
            self.logger.info("Compiling Done")

    if self.num_gpus > 1:
        self.model = DDP(model, device_ids=[self.rank,], static_graph=False)
    else:
        self.model = model

    # EMA copy, kept on rank 0 only.
    if self.rank == 0 and hasattr(self.configs.train, 'ema_rate'):
        self.ema_model = deepcopy(model).cuda()
        self.ema_state = OrderedDict(
            {key: deepcopy(value.data) for key, value in self.model.state_dict().items()}
        )
        # These entries are copied from the live model rather than averaged.
        self.ema_ignore_keys = [
            x for x in self.ema_state.keys()
            if ('running_' in x or 'num_batches_tracked' in x)
        ]

    self.print_model_info()
|
|
|
def build_dataloader(self):
    """Build the datasets and data loaders.

    The training loader is wrapped in an endless generator, so the training
    loop pulls batches with ``next``. The per-process batch size is
    ``configs.train.batch[0] // num_gpus``; with several GPUs a
    ``DistributedSampler`` replaces shuffling. A validation dataset and
    loader are built on rank 0 only, and only when ``configs.data`` has a
    ``val`` entry.

    Sets ``self.datasets``, ``self.dataloaders`` and ``self.sampler``.
    """
    def _wrap_loader(loader):
        # Restart the loader each time it is exhausted.
        while True:
            yield from loader

    # Fall back to an empty mapping (not the ``dict`` type) when unset.
    datasets = {'train': create_dataset(self.configs.data.get('train', dict())), }
    if hasattr(self.configs.data, 'val') and self.rank == 0:
        datasets['val'] = create_dataset(self.configs.data.get('val', dict()))
    if self.rank == 0:
        for phase in datasets.keys():
            length = len(datasets[phase])
            self.logger.info('Number of images in {:s} data set: {:d}'.format(phase, length))

    if self.num_gpus > 1:
        sampler = udata.distributed.DistributedSampler(
            datasets['train'],
            num_replicas=self.num_gpus,
            rank=self.rank,
        )
    else:
        sampler = None

    num_workers = min(self.configs.train.num_workers, 4)
    train_kwargs = dict(
        batch_size=self.configs.train.batch[0] // self.num_gpus,
        shuffle=False if self.num_gpus > 1 else True,
        drop_last=True,
        num_workers=num_workers,
        pin_memory=True,
        worker_init_fn=my_worker_init_fn,
        sampler=sampler,
    )
    # DataLoader only accepts ``prefetch_factor`` with worker processes;
    # passing it with ``num_workers == 0`` raises ValueError.
    if num_workers > 0:
        train_kwargs['prefetch_factor'] = self.configs.train.get('prefetch_factor', 2)
    dataloaders = {'train': _wrap_loader(udata.DataLoader(datasets['train'], **train_kwargs))}

    if hasattr(self.configs.data, 'val') and self.rank == 0:
        dataloaders['val'] = udata.DataLoader(
            datasets['val'],
            batch_size=self.configs.train.batch[1],
            shuffle=False,
            drop_last=False,
            num_workers=0,
            pin_memory=True,
        )

    self.datasets = datasets
    self.dataloaders = dataloaders
    self.sampler = sampler
|
|
|
def print_model_info(self):
    """Log the model's parameter count, in millions, on rank 0."""
    if self.rank != 0:
        return
    params_in_millions = util_net.calculate_parameters(self.model) / 1000**2
    self.logger.info(f"Number of parameters: {params_in_millions:.2f}M")
|
|
|
def prepare_data(self, data, dtype=torch.float32, phase='train'):
    """Move every tensor of a batch to the GPU and cast it to ``dtype``.

    Args:
        data: dict mapping names to tensors.
        dtype: target dtype of the returned tensors.
        phase: not used by the base implementation.

    Returns:
        A new dict with the same keys and the converted tensors.
    """
    moved = {}
    for name, tensor in data.items():
        moved[name] = tensor.cuda().to(dtype=dtype)
    return moved
|
|
|
def validation(self):
    """Validation hook; does nothing in the base trainer."""
    return None
|
|
|
def train(self):
    """Run the full training loop.

    Sets up logging, model, optimizer, resume state and data loaders, then
    iterates from ``self.iters_start`` to ``configs.train.iterations``. Each
    iteration draws a batch, runs ``training_step``, optionally validates
    (every ``val_freq`` iterations, default 10000, when a validation loader
    exists), adjusts the learning rate and checkpoints every ``save_freq``
    iterations. ``self.current_iters`` holds the 1-based iteration number.
    """
    self.init_logger()
    self.build_model()
    self.setup_optimizaton()
    self.resume_from_ckpt()
    self.build_dataloader()

    self.model.train()
    num_iters_epoch = math.ceil(len(self.datasets['train']) / self.configs.train.batch[0])
    for ii in range(self.iters_start, self.configs.train.iterations):
        step = ii + 1
        self.current_iters = step

        # Fetch and prepare one batch, then optimise on it.
        batch = self.prepare_data(next(self.dataloaders['train']))
        self.training_step(batch)

        # Periodic validation.
        if 'val' in self.dataloaders and step % self.configs.train.get('val_freq', 10000) == 0:
            self.validation()

        # Learning-rate update.
        self.adjust_lr()

        # Periodic checkpoint.
        if step % self.configs.train.save_freq == 0:
            self.save_ckpt()

        # Tell the distributed sampler a new pass has started (called with
        # the iteration count).
        if step % num_iters_epoch == 0 and self.sampler is not None:
            self.sampler.set_epoch(step)

    self.close_logger()
|
|
|
def training_step(self, data):
    """Per-iteration optimisation hook; does nothing in the base trainer."""
    return None
|
|
|
def adjust_lr(self, current_iters=None):
    """Advance the learning-rate scheduler by one step.

    Args:
        current_iters: accepted for interface compatibility; not used here.

    Requires ``self.lr_scheduler`` to exist.
    """
    assert hasattr(self, 'lr_scheduler')
    scheduler = self.lr_scheduler
    scheduler.step()
|
|
|
def save_ckpt(self):
    """Write the current training state to disk (rank 0 only).

    Saves ``model_<iter>.pth`` in ``self.ckpt_dir`` with the iteration
    number, logging counters, model state dict and, when AMP is used, the
    scaler state. When the trainer has an ``ema_rate``, the EMA state is
    saved as ``ema_model_<iter>.pth`` in ``self.ema_ckpt_dir``.
    """
    if self.rank != 0:
        return

    iters = self.current_iters
    phases = ['train', 'val']
    ckpt = {
        'iters_start': iters,
        'log_step': {phase: self.log_step[phase] for phase in phases},
        'log_step_img': {phase: self.log_step_img[phase] for phase in phases},
        'state_dict': self.model.state_dict(),
    }
    if self.amp_scaler is not None:
        ckpt['amp_scaler'] = self.amp_scaler.state_dict()
    torch.save(ckpt, self.ckpt_dir / 'model_{:d}.pth'.format(iters))

    if hasattr(self, 'ema_rate'):
        torch.save(self.ema_state, self.ema_ckpt_dir / 'ema_model_{:d}.pth'.format(iters))
|
|
|
def reload_ema_model(self):
    """Load the EMA state into ``self.ema_model`` (rank 0 only).

    With more than one GPU the first 7 characters of every key are dropped
    before loading (the ``module.`` prefix of the wrapped model's keys).
    """
    if self.rank != 0:
        return

    if self.num_gpus > 1:
        state = {}
        for name, tensor in self.ema_state.items():
            state[name[7:]] = tensor
    else:
        state = self.ema_state
    self.ema_model.load_state_dict(state)
|
|
|
@torch.no_grad()
def update_ema_model(self):
    """Blend the current model weights into the EMA state.

    All ranks synchronise on a barrier first when running on several GPUs.
    On rank 0 every EMA entry becomes ``rate * ema + (1 - rate) * current``,
    except for the keys in ``self.ema_ignore_keys``, which are replaced by
    the current model's value.
    """
    if self.num_gpus > 1:
        dist.barrier()
    if self.rank != 0:
        return

    live_state = self.model.state_dict()
    decay = self.ema_rate
    for name, averaged in self.ema_state.items():
        if name in self.ema_ignore_keys:
            self.ema_state[name] = live_state[name]
        else:
            averaged.mul_(decay).add_(live_state[name].detach().data, alpha=1-decay)
|
|
|
def logging_image(self, im_tensor, tag, phase, add_global_step=False, nrow=8):
    """
    Log a batch of images as one grid, to disk and/or TensorBoard.

    Args:
        im_tensor: b x c x h x w tensor
        tag: str
        phase: 'train' or 'val'
        add_global_step: increment the image step counter of ``phase`` afterwards
        nrow: number of displays in each row
    """
    assert self.tf_logging or self.local_logging
    grid = vutils.make_grid(im_tensor, nrow=nrow, normalize=True, scale_each=True)

    if self.local_logging:
        im_path = str(self.image_dir / phase / f"{tag}-{self.log_step_img[phase]}.png")
        # Channel-last numpy layout for the image writer.
        util_image.imwrite(grid.cpu().permute(1, 2, 0).numpy(), im_path)

    if self.tf_logging:
        step = self.log_step_img[phase]
        self.writer.add_image(f"{phase}-{tag}-{self.log_step_img[phase]}", grid, step)

    if add_global_step:
        self.log_step_img[phase] += 1
|
|
|
def logging_metric(self, metrics, tag, phase, add_global_step=False):
    """
    Write scalar metric(s) to TensorBoard; a no-op when TensorBoard logging is off.

    Args:
        metrics: dict (written with ``add_scalars``) or a single value
            (written with ``add_scalar``)
        tag: str
        phase: 'train' or 'val'
        add_global_step: increment the scalar step counter of ``phase`` afterwards
    """
    if not self.tf_logging:
        return

    full_tag = f"{phase}-{tag}"
    if isinstance(metrics, dict):
        write = self.writer.add_scalars
    else:
        write = self.writer.add_scalar
    write(full_tag, metrics, self.log_step[phase])

    if add_global_step:
        self.log_step[phase] += 1
|
|
|
def load_model(self, model, ckpt_path=None):
    """Load weights from ``ckpt_path`` into ``model``.

    The checkpoint may be a bare state dict or hold one under the
    ``'state_dict'`` key. Progress messages are logged on rank 0.
    """
    is_main = self.rank == 0
    if is_main:
        self.logger.info(f'Loading from {ckpt_path}...')

    state = torch.load(ckpt_path, map_location=f"cuda:{self.rank}")
    if 'state_dict' in state:
        state = state['state_dict']
    util_net.reload_model(model, state)

    if is_main:
        self.logger.info('Loaded Done')
|
|
|
def freeze_model(self, net):
    """Disable gradient tracking for every parameter of ``net``."""
    for param in net.parameters():
        param.requires_grad_(False)
|
|
|
class TrainerDifIR(TrainerBase): |
|
def setup_optimizaton(self):
    """Create optimizer and scaler, plus a cosine scheduler when configured.

    A ``CosineAnnealingLR`` scheduler spanning the post-warm-up iterations is
    created only when ``configs.train.lr_schedule`` equals ``'cosin'``.
    """
    super().setup_optimizaton()

    train_cfg = self.configs.train
    if train_cfg.lr_schedule != 'cosin':
        return
    self.lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer=self.optimizer,
        T_max=train_cfg.iterations - train_cfg.warmup_iterations,
        eta_min=train_cfg.lr_min,
    )
|
|
|
def build_model(self):
    """Build the denoiser plus the frozen autoencoder, LPIPS net and diffusion.

    Extends the base implementation with:
      * extra EMA-ignored keys (``relative_position_index`` entries),
      * an optional frozen autoencoder (``self.autoencoder``; ``None`` when
        ``configs.autoencoder`` is ``None``),
      * a frozen VGG-based LPIPS network (``self.lpips_loss``),
      * the diffusion object (``self.base_diffusion``).
    """
    super().build_model()
    if self.rank == 0 and hasattr(self.configs.train, 'ema_rate'):
        self.ema_ignore_keys.extend([x for x in self.ema_state.keys() if 'relative_position_index' in x])

    # Autoencoder: loaded strictly, frozen and kept in eval mode.
    if self.configs.autoencoder is not None:
        ckpt = torch.load(self.configs.autoencoder.ckpt_path, map_location=f"cuda:{self.rank}")
        if self.rank == 0:
            self.logger.info(f"Restoring autoencoder from {self.configs.autoencoder.ckpt_path}")
        # Fall back to an empty mapping; the ``dict`` type cannot be unpacked.
        params = self.configs.autoencoder.get('params', dict())
        autoencoder = util_common.get_obj_from_str(self.configs.autoencoder.target)(**params)
        autoencoder.cuda()
        autoencoder.load_state_dict(ckpt, True)
        for param in autoencoder.parameters():
            param.requires_grad_(False)
        autoencoder.eval()
        if self.configs.train.compile.flag:
            if self.rank == 0:
                self.logger.info("Begin compiling autoencoder model...")
            autoencoder = torch.compile(autoencoder, mode=self.configs.train.compile.mode)
            if self.rank == 0:
                self.logger.info("Compiling Done")
        self.autoencoder = autoencoder
    else:
        self.autoencoder = None

    # LPIPS network: frozen, eval mode.
    lpips_loss = lpips.LPIPS(net='vgg').to(f"cuda:{self.rank}")
    for param in lpips_loss.parameters():
        param.requires_grad_(False)
    lpips_loss.eval()
    if self.configs.train.compile.flag:
        if self.rank == 0:
            self.logger.info("Begin compiling LPIPS Metric...")
        lpips_loss = torch.compile(lpips_loss, mode=self.configs.train.compile.mode)
        if self.rank == 0:
            self.logger.info("Compiling Done")
    self.lpips_loss = lpips_loss

    # Diffusion process.
    params = self.configs.diffusion.get('params', dict())
    self.base_diffusion = util_common.get_obj_from_str(self.configs.diffusion.target)(**params)
|
|
|
@torch.no_grad()
def _dequeue_and_enqueue(self):
    """Training pair pool that increases the degradation diversity in a batch.

    Batch processing limits the diversity of synthetic degradations in a batch
    (e.g. all samples share one resize factor). The pool stores past
    ``self.lq`` / ``self.gt`` pairs. While it is filling up, the current
    batch is appended and left unchanged. Once full, the pool is shuffled,
    its first ``b`` pairs replace ``self.lq`` / ``self.gt`` and the current
    batch takes their place in the pool.
    """
    b, c, h, w = self.lq.size()
    if not hasattr(self, 'queue_size'):
        self.queue_size = self.configs.degradation.get('queue_size', b*10)

    # Lazily allocate the pool on first use.
    if not hasattr(self, 'queue_lr'):
        assert self.queue_size % b == 0, f'queue size {self.queue_size} should be divisible by batch size {b}'
        self.queue_lr = torch.zeros(self.queue_size, c, h, w).cuda()
        _, c, h, w = self.gt.size()
        self.queue_gt = torch.zeros(self.queue_size, c, h, w).cuda()
        self.queue_ptr = 0

    if self.queue_ptr != self.queue_size:
        # Pool not full yet: only enqueue.
        start = self.queue_ptr
        stop = start + b
        self.queue_lr[start:stop, :, :, :] = self.lq.clone()
        self.queue_gt[start:stop, :, :, :] = self.gt.clone()
        self.queue_ptr = stop
        return

    # Pool is full: shuffle, then swap the first b entries with the batch.
    order = torch.randperm(self.queue_size)
    self.queue_lr = self.queue_lr[order]
    self.queue_gt = self.queue_gt[order]

    pooled_lq = self.queue_lr[0:b, :, :, :].clone()
    pooled_gt = self.queue_gt[0:b, :, :, :].clone()

    self.queue_lr[0:b, :, :, :] = self.lq.clone()
    self.queue_gt[0:b, :, :, :] = self.gt.clone()

    self.lq = pooled_lq
    self.gt = pooled_gt
|
|
|
@torch.no_grad()
def prepare_data(self, data, dtype=torch.float32, realesrgan=None, phase='train'):
    """Prepare one batch for training or validation.

    Args:
        data: dict of tensors from the data loader.
        dtype: dtype of the returned tensors (validation / plain paths only).
        realesrgan: whether to synthesise the low-quality input on the fly
            with the two-stage blur / resize / noise / JPEG pipeline. When
            ``None`` it is read from the ``type`` of ``configs.data[phase]``.
        phase: ``'train'`` or ``'val'``.

    Returns:
        * training with ``realesrgan``: ``{'lq', 'gt'}`` drawn from the
          training pair pool, both scaled to ``[-1, 1]``;
        * validation: the input tensors cropped (when both sides exceed
          ``val_resolution``) or padded to a multiple of ``val_resolution``
          (default 256), on GPU;
        * otherwise: the input tensors moved to GPU and cast to ``dtype``.
    """
    if realesrgan is None:
        realesrgan = self.configs.data.get(phase, dict).type == 'realesrgan'

    if realesrgan and phase == 'train':
        degradation = self.configs.degradation
        if not hasattr(self, 'jpeger'):
            self.jpeger = DiffJPEG(differentiable=False).cuda()
        if not hasattr(self, 'use_sharpener'):
            self.use_sharpener = USMSharp().cuda()

        im_gt = data['gt'].cuda()
        kernel1 = data['kernel1'].cuda()
        kernel2 = data['kernel2'].cuda()
        sinc_kernel = data['sinc_kernel'].cuda()

        ori_h, ori_w = im_gt.size()[2:4]
        if isinstance(degradation.sf, int):
            sf = degradation.sf
        else:
            assert len(degradation.sf) == 2
            sf = random.uniform(*degradation.sf)

        if degradation.use_sharp:
            im_gt = self.use_sharpener(im_gt)

        # ---------------- first degradation stage ----------------
        # blur
        out = filter2D(im_gt, kernel1)
        # random resize
        updown_type = random.choices(
            ['up', 'down', 'keep'],
            degradation['resize_prob'],
        )[0]
        if updown_type == 'up':
            scale = random.uniform(1, degradation['resize_range'][1])
        elif updown_type == 'down':
            scale = random.uniform(degradation['resize_range'][0], 1)
        else:
            scale = 1
        mode = random.choice(['area', 'bilinear', 'bicubic'])
        out = F.interpolate(out, scale_factor=scale, mode=mode)
        # noise (Gaussian or Poisson)
        gray_noise_prob = degradation['gray_noise_prob']
        if random.random() < degradation['gaussian_noise_prob']:
            out = random_add_gaussian_noise_pt(
                out,
                sigma_range=degradation['noise_range'],
                clip=True,
                rounds=False,
                gray_prob=gray_noise_prob,
            )
        else:
            out = random_add_poisson_noise_pt(
                out,
                scale_range=degradation['poisson_scale_range'],
                gray_prob=gray_noise_prob,
                clip=True,
                rounds=False)
        # JPEG compression
        jpeg_p = out.new_zeros(out.size(0)).uniform_(*degradation['jpeg_range'])
        out = torch.clamp(out, 0, 1)
        out = self.jpeger(out, quality=jpeg_p)

        # ---------------- second degradation stage (optional) ----------------
        if random.random() < degradation['second_order_prob']:
            # blur
            if random.random() < degradation['second_blur_prob']:
                out = filter2D(out, kernel2)
            # random resize around the target low-resolution size
            updown_type = random.choices(
                ['up', 'down', 'keep'],
                degradation['resize_prob2'],
            )[0]
            if updown_type == 'up':
                scale = random.uniform(1, degradation['resize_range2'][1])
            elif updown_type == 'down':
                scale = random.uniform(degradation['resize_range2'][0], 1)
            else:
                scale = 1
            mode = random.choice(['area', 'bilinear', 'bicubic'])
            out = F.interpolate(
                out,
                size=(int(ori_h / sf * scale), int(ori_w / sf * scale)),
                mode=mode,
            )
            # noise (Gaussian or Poisson)
            gray_noise_prob = degradation['gray_noise_prob2']
            if random.random() < degradation['gaussian_noise_prob2']:
                out = random_add_gaussian_noise_pt(
                    out,
                    sigma_range=degradation['noise_range2'],
                    clip=True,
                    rounds=False,
                    gray_prob=gray_noise_prob,
                )
            else:
                out = random_add_poisson_noise_pt(
                    out,
                    scale_range=degradation['poisson_scale_range2'],
                    gray_prob=gray_noise_prob,
                    clip=True,
                    rounds=False,
                )

        # Final [resize to 1/sf + sinc filter] and JPEG compression, applied
        # in a random order.
        # NOTE(review): with a ranged (float) ``sf`` the ``//`` results below
        # are floats - confirm the ranged configuration is actually supported.
        if random.random() < 0.5:
            # resize + sinc filter, then JPEG
            mode = random.choice(['area', 'bilinear', 'bicubic'])
            out = F.interpolate(
                out,
                size=(ori_h // sf, ori_w // sf),
                mode=mode,
            )
            out = filter2D(out, sinc_kernel)
            jpeg_p = out.new_zeros(out.size(0)).uniform_(*degradation['jpeg_range2'])
            out = torch.clamp(out, 0, 1)
            out = self.jpeger(out, quality=jpeg_p)
        else:
            # JPEG, then resize + sinc filter
            jpeg_p = out.new_zeros(out.size(0)).uniform_(*degradation['jpeg_range2'])
            out = torch.clamp(out, 0, 1)
            out = self.jpeger(out, quality=jpeg_p)
            mode = random.choice(['area', 'bilinear', 'bicubic'])
            out = F.interpolate(
                out,
                size=(ori_h // sf, ori_w // sf),
                mode=mode,
            )
            out = filter2D(out, sinc_kernel)

        # Optionally resize the degraded image back to the GT resolution.
        if degradation.resize_back:
            out = F.interpolate(out, size=(ori_h, ori_w), mode='bicubic')
            # LQ and GT now have the same size, so the paired crop must use
            # a scale of 1; any other scale fails its size-consistency check.
            temp_sf = 1
        else:
            temp_sf = degradation['sf']

        # Quantise to 8-bit levels.
        im_lq = torch.clamp((out * 255.0).round(), 0, 255) / 255.

        # Random paired crop, then map both images from [0, 1] to [-1, 1].
        gt_size = degradation['gt_size']
        im_gt, im_lq = paired_random_crop(im_gt, im_lq, gt_size, temp_sf)
        im_lq = (im_lq - 0.5) / 0.5
        im_gt = (im_gt - 0.5) / 0.5
        self.lq, self.gt, flag_nan = replace_nan_in_batch(im_lq, im_gt)
        if flag_nan:
            with open(f"records_nan_rank{self.rank}.log", 'a') as f:
                f.write(f'Find Nan value in rank{self.rank}\n')

        # Training pair pool.
        self._dequeue_and_enqueue()
        self.lq = self.lq.contiguous()

        return {'lq': self.lq, 'gt': self.gt}

    if phase == 'val':
        offset = self.configs.train.get('val_resolution', 256)
        for key, value in data.items():
            h, w = value.shape[2:]
            if h > offset and w > offset:
                # Crop down to the largest multiple of ``offset``.
                h_end = int((h // offset) * offset)
                w_end = int((w // offset) * offset)
                data[key] = value[:, :, :h_end, :w_end]
            else:
                # Pad up to the next multiple of ``offset``.
                h_pad = math.ceil(h / offset) * offset - h
                w_pad = math.ceil(w / offset) * offset - w
                padding_mode = self.configs.train.get('val_padding_mode', 'reflect')
                data[key] = F.pad(value, pad=(0, w_pad, 0, h_pad), mode=padding_mode)
        return {key: value.cuda().to(dtype=dtype) for key, value in data.items()}

    return {key: value.cuda().to(dtype=dtype) for key, value in data.items()}
|
|
|
def backward_step(self, dif_loss_wrapper, micro_data, num_grad_accumulate, tt):
    """Compute the diffusion loss for one micro-batch and back-propagate it.

    Args:
        dif_loss_wrapper: zero-argument callable returning
            ``(losses, z_t, z0_pred)``; ``losses`` must contain ``'mse'``.
        micro_data: the micro-batch (unused here).
        num_grad_accumulate: number of micro-batches per optimizer step; the
            mean loss is divided by it.
        tt: sampled time steps (unused here).

    Returns:
        ``(losses, z0_pred, z_t)``, where ``losses['loss']`` is ``losses['mse']``.
    """
    if self.configs.train.use_amp:
        context = torch.cuda.amp.autocast
    else:
        context = nullcontext

    with context():
        losses, z_t, z0_pred = dif_loss_wrapper()
        losses['loss'] = losses['mse']
        loss = losses['loss'].mean() / num_grad_accumulate

    # With AMP the loss is scaled before the backward pass.
    scaled = loss if self.amp_scaler is None else self.amp_scaler.scale(loss)
    scaled.backward()

    return losses, z0_pred, z_t
|
|
|
def training_step(self, data):
    """Run one optimizer step with gradient accumulation over micro-batches.

    The batch is split into chunks of ``configs.train.microbatch`` samples.
    Each chunk samples random time steps and noise, computes the diffusion
    loss and back-propagates it; with several GPUs, gradient synchronisation
    is skipped for all but the last chunk. Afterwards the optimizer steps
    (through the AMP scaler when enabled), gradients are cleared and the EMA
    weights are updated when ``configs.train`` has an ``ema_rate``.

    Args:
        data: dict with at least ``'gt'`` and ``'lq'`` tensors (optionally
            ``'mask'``), as produced by ``prepare_data``.
    """
    current_batchsize = data['gt'].shape[0]
    micro_batchsize = self.configs.train.microbatch
    num_grad_accumulate = math.ceil(current_batchsize / micro_batchsize)

    # Noise shape. With an autoencoder the diffusion runs in latent space
    # (spatial size reduced by 2 per extra ``ch_mult`` level, ``embed_dim``
    # channels). Without one - ``self.autoencoder`` is ``None`` - it runs
    # directly on the input, so the autoencoder config must not be touched.
    if self.autoencoder is not None:
        ae_params = self.configs.autoencoder.params
        latent_downsamping_sf = 2**(len(ae_params.ddconfig.ch_mult) - 1)
        noise_chn = ae_params.embed_dim
    else:
        latent_downsamping_sf = 1
        noise_chn = data['gt'].shape[1]
    latent_resolution = data['gt'].shape[-1] // latent_downsamping_sf

    for jj in range(0, current_batchsize, micro_batchsize):
        micro_data = {key: value[jj:jj+micro_batchsize,] for key, value in data.items()}
        last_batch = (jj+micro_batchsize >= current_batchsize)
        tt = torch.randint(
            0, self.base_diffusion.num_timesteps,
            size=(micro_data['gt'].shape[0],),
            device=f"cuda:{self.rank}",
        )
        noise = torch.randn(
            size=(micro_data['gt'].shape[0], noise_chn,) + (latent_resolution, ) * 2,
            device=micro_data['gt'].device,
        )
        if self.configs.model.params.cond_lq:
            model_kwargs = {'lq': micro_data['lq'], }
            if 'mask' in micro_data:
                model_kwargs['mask'] = micro_data['mask']
        else:
            model_kwargs = None
        compute_losses = functools.partial(
            self.base_diffusion.training_losses,
            self.model,
            micro_data['gt'],
            micro_data['lq'],
            tt,
            first_stage_model=self.autoencoder,
            model_kwargs=model_kwargs,
            noise=noise,
        )
        if last_batch or self.num_gpus <= 1:
            losses, z0_pred, z_t = self.backward_step(compute_losses, micro_data, num_grad_accumulate, tt)
        else:
            # Skip the gradient all-reduce until the last micro-batch.
            with self.model.no_sync():
                losses, z0_pred, z_t = self.backward_step(compute_losses, micro_data, num_grad_accumulate, tt)

        # Logging uses the last micro-batch only.
        if last_batch:
            self.log_step_train(losses, tt, micro_data, z_t, z0_pred.detach())

    if self.configs.train.use_amp:
        self.amp_scaler.step(self.optimizer)
        self.amp_scaler.update()
    else:
        self.optimizer.step()

    # Clear gradients for the next iteration.
    self.model.zero_grad()

    if hasattr(self.configs.train, 'ema_rate'):
        self.update_ema_model()
|
|
|
def adjust_lr(self, current_iters=None):
    """Apply linear warm-up, then delegate to the scheduler.

    Args:
        current_iters: 1-based iteration to adjust for; defaults to
            ``self.current_iters``.

    Up to and including ``warmup_iterations`` the learning rate of every
    parameter group is ``current_iters / warmup_iterations * lr``. After
    that, ``self.lr_scheduler`` is stepped when it exists.
    """
    if current_iters is None:
        current_iters = self.current_iters
    train_cfg = self.configs.train
    base_lr = train_cfg.lr
    warmup_steps = train_cfg.warmup_iterations

    if current_iters > warmup_steps:
        if hasattr(self, 'lr_scheduler'):
            self.lr_scheduler.step()
        return

    for group in self.optimizer.param_groups:
        group['lr'] = (current_iters / warmup_steps) * base_lr
|
|
|
def log_step_train(self, loss, tt, batch, z_t, z0_pred, phase='train'):
    '''
    Accumulate and periodically report the training losses (rank 0 only).

    Losses are tracked separately for three time steps (first, middle, last).
    Every ``log_freq[0]`` iterations the per-step means are logged, every
    ``log_freq[1]`` iterations images are logged, and every ``save_freq``
    iterations the elapsed time since the start of the window is logged.

    param loss: a dict recording the loss informations
    param tt: 1-D tensor, time steps
    param batch: dict with the 'lq' and 'gt' tensors of the logged micro-batch
    param z_t: diffused latent returned by the loss computation
    param z0_pred: predicted clean latent returned by the loss computation
    param phase: logging phase tag
    '''
    if self.rank != 0:
        return

    num_timesteps = self.base_diffusion.num_timesteps
    record_steps = [1, (num_timesteps // 2) + 1, num_timesteps]
    log_freq = self.configs.train.log_freq
    save_freq = self.configs.train.save_freq

    # Reset the accumulators on the first iteration of each logging window.
    # ``(it - 1) % freq == 0`` also holds for ``freq == 1``, and the hasattr
    # guard covers a run that starts (resumes) in the middle of a window.
    if (self.current_iters - 1) % log_freq[0] == 0 or not hasattr(self, 'loss_mean'):
        self.loss_mean = {key: torch.zeros(size=(len(record_steps),), dtype=torch.float64)
                          for key in loss.keys()}
        self.loss_count = torch.zeros(size=(len(record_steps),), dtype=torch.float64)

    for jj, record in enumerate(record_steps):
        # The mask selects the samples drawn at this time step; it does not
        # depend on the loss key, so build it once per step.
        mask = torch.where(tt == record - 1, torch.ones_like(tt), torch.zeros_like(tt))
        for key, value in loss.items():
            current_loss = torch.sum(value.detach() * mask)
            self.loss_mean[key][jj] += current_loss.item()
        self.loss_count[jj] += mask.sum().item()

    if self.current_iters % log_freq[0] == 0:
        # Avoid dividing by zero for time steps that were never sampled.
        if torch.any(self.loss_count == 0):
            self.loss_count += 1e-4
        for key in loss.keys():
            self.loss_mean[key] /= self.loss_count
        log_str = 'Train: {:06d}/{:06d}, Loss/MSE: '.format(
            self.current_iters,
            self.configs.train.iterations)
        for jj, current_record in enumerate(record_steps):
            log_str += 't({:d}):{:.1e}/{:.1e}, '.format(
                current_record,
                self.loss_mean['loss'][jj].item(),
                self.loss_mean['mse'][jj].item(),
            )
        log_str += 'lr:{:.2e}'.format(self.optimizer.param_groups[0]['lr'])
        self.logger.info(log_str)
        self.logging_metric(self.loss_mean, tag='Loss', phase=phase, add_global_step=True)

    if self.current_iters % log_freq[1] == 0:
        self.logging_image(batch['lq'], tag='lq', phase=phase, add_global_step=False)
        self.logging_image(batch['gt'], tag='gt', phase=phase, add_global_step=False)
        x_t = self.base_diffusion.decode_first_stage(
            self.base_diffusion._scale_input(z_t, tt),
            self.autoencoder,
        )
        self.logging_image(x_t, tag='diffused', phase=phase, add_global_step=False)
        x0_pred = self.base_diffusion.decode_first_stage(
            z0_pred,
            self.autoencoder,
        )
        self.logging_image(x0_pred, tag='x0-pred', phase=phase, add_global_step=True)

    # Wall-clock time per checkpoint window.
    if (self.current_iters - 1) % save_freq == 0 or not hasattr(self, 'tic'):
        self.tic = time.time()
    if self.current_iters % save_freq == 0:
        self.toc = time.time()
        elaplsed = (self.toc - self.tic)
        self.logger.info(f"Elapsed time: {elaplsed:.2f}s")
        self.logger.info("="*100)
|
|
|
def validation(self, phase='val'):
    """Run the sampler on the validation set and log images and metrics.

    Executed on rank 0 only. Uses the EMA model when
    ``configs.train.use_ema_val`` is set, otherwise the training model in
    eval mode. A few evenly spaced sampling steps (always including the last
    one) are decoded and logged as a progress strip. When ground truth is
    available, PSNR and LPIPS of the final sample are accumulated and
    averaged over the dataset size.
    """
    if self.rank != 0:
        return

    if self.configs.train.use_ema_val:
        self.reload_ema_model()
        self.ema_model.eval()
    else:
        self.model.eval()

    # Sampling steps whose intermediate result is decoded for the strip.
    num_steps = self.base_diffusion.num_timesteps
    indices = np.linspace(
        0,
        num_steps,
        num_steps if num_steps < 5 else 4,
        endpoint=False,
        dtype=np.int64,
    ).tolist()
    if (num_steps - 1) not in indices:
        indices.append(num_steps - 1)

    batch_size = self.configs.train.batch[1]
    num_iters_epoch = math.ceil(len(self.datasets[phase]) / batch_size)
    mean_psnr = mean_lpips = 0
    for ii, data in enumerate(self.dataloaders[phase]):
        data = self.prepare_data(data, phase='val')
        im_lq = data['lq']
        if 'gt' in data:
            im_gt = data['gt']
        num_iters = 0

        if self.configs.model.params.cond_lq:
            model_kwargs = {'lq': data['lq'], }
            if 'mask' in data:
                model_kwargs['mask'] = data['mask']
        else:
            model_kwargs = None

        tt = torch.tensor(
            [self.base_diffusion.num_timesteps, ]*im_lq.shape[0],
            dtype=torch.int64,
        ).cuda()
        sampler = self.base_diffusion.p_sample_loop_progressive(
            y=im_lq,
            model=self.ema_model if self.configs.train.use_ema_val else self.model,
            first_stage_model=self.autoencoder,
            noise=None,
            clip_denoised=True if self.autoencoder is None else False,
            model_kwargs=model_kwargs,
            device=f"cuda:{self.rank}",
            progress=False,
        )
        for sample in sampler:
            sample_decode = {}
            if num_iters in indices:
                for key, value in sample.items():
                    if key in ['sample', ]:
                        sample_decode[key] = self.base_diffusion.decode_first_stage(
                            value,
                            self.autoencoder,
                        ).clamp(-1.0, 1.0)
                im_sr_progress = sample_decode['sample']
                # Stack the decoded steps along the channel axis.
                if num_iters + 1 == 1:
                    im_sr_all = im_sr_progress
                else:
                    im_sr_all = torch.cat((im_sr_all, im_sr_progress), dim=1)
            num_iters += 1
            tt -= 1

        # Metrics use the final decoded sample (the last step is always in
        # ``indices``); images are mapped from [-1, 1] to [0, 1] for PSNR.
        if 'gt' in data:
            mean_psnr += util_image.batch_PSNR(
                sample_decode['sample'] * 0.5 + 0.5,
                im_gt * 0.5 + 0.5,
                ycbcr=self.configs.train.val_y_channel,
            )
            mean_lpips += self.lpips_loss(
                sample_decode['sample'],
                im_gt,
            ).sum().item()

        if (ii + 1) % self.configs.train.log_freq[2] == 0:
            self.logger.info(f'Validation: {ii+1:02d}/{num_iters_epoch:02d}...')

            # Unstack the channel-concatenated steps into separate images.
            im_sr_all = rearrange(im_sr_all, 'b (k c) h w -> (b k) c h w', c=im_lq.shape[1])
            self.logging_image(
                im_sr_all,
                tag='progress',
                phase=phase,
                add_global_step=False,
                nrow=len(indices),
            )
            if 'gt' in data:
                self.logging_image(im_gt, tag='gt', phase=phase, add_global_step=False)
            self.logging_image(im_lq, tag='lq', phase=phase, add_global_step=True)

    if 'gt' in data:
        num_samples = len(self.datasets[phase])
        mean_psnr /= num_samples
        mean_lpips /= num_samples
        self.logger.info(f'Validation Metric: PSNR={mean_psnr:5.2f}, LPIPS={mean_lpips:6.4f}...')
        self.logging_metric(mean_psnr, tag='PSNR', phase=phase, add_global_step=False)
        self.logging_metric(mean_lpips, tag='LPIPS', phase=phase, add_global_step=True)

    self.logger.info("="*100)

    # Put the training model back into train mode unless the EMA model was
    # the one being evaluated.
    if not (self.configs.train.use_ema_val and hasattr(self.configs.train, 'ema_rate')):
        self.model.train()
|
|
|
class TrainerDifIRLPIPS(TrainerDifIR):
    '''
    TrainerDifIR variant whose training loss adds an LPIPS term, computed on the
    decoded x0 prediction, to the diffusion MSE loss.

    NOTE(review): the block nesting below was reconstructed from a source whose
    indentation had been flattened; confirm it against the original file.
    '''

    def backward_step(self, dif_loss_wrapper, micro_data, num_grad_accumulate, tt):
        '''
        Forward one micro-batch, assemble the weighted MSE + LPIPS loss and backpropagate it.

        param dif_loss_wrapper: zero-argument callable returning (losses, z_t, z0_pred);
                                losses is a dict that must contain the key "mse"
        param micro_data: dict holding the ground-truth images under 'gt'
        param num_grad_accumulate: divisor applied to the mean loss before backward
        param tt: time steps, not used by this method
        return: (losses, z0_pred, z_t); losses additionally holds the weighted
                "lpips" entry and the combined "loss" entry
        '''
        # loss_coef[0] weights the MSE term, loss_coef[1] the LPIPS term
        loss_coef = self.configs.train.get('loss_coef')
        context = torch.cuda.amp.autocast if self.configs.train.use_amp else nullcontext

        # diffusion loss
        with context():
            losses, z_t, z0_pred = dif_loss_wrapper()
            x0_pred = self.base_diffusion.decode_first_stage(
                z0_pred,
                self.autoencoder,
            )
            # cached so that log_step_train can log the decoded prediction
            self.current_x0_pred = x0_pred.detach()

            # perceptual loss on the decoded prediction, flattened to 1-D
            losses["lpips"] = self.lpips_loss(
                x0_pred.clamp(-1.0, 1.0),
                micro_data['gt'],
            ).to(z0_pred.dtype).view(-1)
            flag_nan = torch.any(torch.isnan(losses["lpips"]))
            if flag_nan:
                # keep the logged values finite; the term itself is skipped below
                losses["lpips"] = torch.nan_to_num(losses["lpips"], nan=0.0)

            losses["mse"] *= loss_coef[0]
            losses["lpips"] *= loss_coef[1]

            assert losses["mse"].shape == losses["lpips"].shape
            if flag_nan:
                # any NaN in LPIPS: optimise the MSE term only for this micro-batch
                losses["loss"] = losses["mse"]
            else:
                losses["loss"] = losses["mse"] + losses["lpips"]
            loss = losses['loss'].mean() / num_grad_accumulate
        if self.amp_scaler is None:
            loss.backward()
        else:
            self.amp_scaler.scale(loss).backward()

        return losses, z0_pred, z_t

    def log_step_train(self, loss, tt, batch, z_t, z0_pred, phase='train'):
        '''
        Accumulate the losses at three reference time steps and periodically log
        them together with sample images. Only rank 0 does any work.

        param loss: a dict recording the loss information; every value must have the same shape as tt
        param tt: 1-D tensor, time steps
        param batch: dict holding the 'lq' and 'gt' image tensors
        param z_t: diffused latent, decoded for the 'diffused' image log
        param z0_pred: not used here; the decoded prediction cached by backward_step is logged instead
        param phase: phase tag forwarded to logging_metric / logging_image
        '''
        if self.rank != 0:
            return

        log_freq = self.configs.train.log_freq
        save_freq = self.configs.train.save_freq
        num_timesteps = self.base_diffusion.num_timesteps
        # first, middle and last step (1-based)
        record_steps = [1, (num_timesteps // 2) + 1, num_timesteps]

        # Reset the accumulators on the first iteration of every logging window.
        # Same as `current_iters % log_freq[0] == 1` for log_freq[0] > 1, but it
        # also fires when log_freq[0] == 1, where `x % 1 == 1` is never true.
        if (self.current_iters - 1) % log_freq[0] == 0:
            self.loss_mean = {
                key: torch.zeros(size=(len(record_steps),), dtype=torch.float64)
                for key in loss.keys()
            }
            self.loss_count = torch.zeros(size=(len(record_steps),), dtype=torch.float64)

        for jj, current_record in enumerate(record_steps):
            # The mask depends only on the time step, so build it once per step.
            index = current_record - 1
            mask = torch.where(tt == index, torch.ones_like(tt), torch.zeros_like(tt))
            for key, value in loss.items():
                assert value.shape == mask.shape
                current_loss = torch.sum(value.detach() * mask)
                self.loss_mean[key][jj] += current_loss.item()
            self.loss_count[jj] += mask.sum().item()

        if self.current_iters % log_freq[0] == 0:
            if torch.any(self.loss_count == 0):
                # avoid dividing by zero for steps that were never sampled
                self.loss_count += 1e-4
            for key in loss.keys():
                self.loss_mean[key] /= self.loss_count
            log_str = 'Train: {:06d}/{:06d}, MSE/LPIPS: '.format(
                self.current_iters,
                self.configs.train.iterations)
            for jj, current_record in enumerate(record_steps):
                log_str += 't({:d}):{:.1e}/{:.1e}, '.format(
                    current_record,
                    self.loss_mean['mse'][jj].item(),
                    self.loss_mean['lpips'][jj].item(),
                )
            log_str += 'lr:{:.2e}'.format(self.optimizer.param_groups[0]['lr'])
            self.logger.info(log_str)
            self.logging_metric(self.loss_mean, tag='Loss', phase=phase, add_global_step=True)
        if self.current_iters % log_freq[1] == 0:
            self.logging_image(batch['lq'], tag='lq', phase=phase, add_global_step=False)
            self.logging_image(batch['gt'], tag='gt', phase=phase, add_global_step=False)
            x_t = self.base_diffusion.decode_first_stage(
                self.base_diffusion._scale_input(z_t, tt),
                self.autoencoder,
            )
            self.logging_image(x_t, tag='diffused', phase=phase, add_global_step=False)
            self.logging_image(self.current_x0_pred, tag='x0-pred', phase=phase, add_global_step=True)

        # Wall-clock time of one save window.
        # NOTE(review): with save_freq == 1 the first condition never holds, so
        # self.tic is not set by this method -- confirm it is set elsewhere.
        if self.current_iters % save_freq == 1:
            self.tic = time.time()
        if self.current_iters % save_freq == 0:
            self.toc = time.time()
            elapsed = (self.toc - self.tic)
            self.logger.info(f"Elapsed time: {elapsed:.2f}s")
            self.logger.info("="*100)
|
|
|
def replace_nan_in_batch(im_lq, im_gt):
    '''
    Drop every sample of a paired batch whose low-quality image contains NaN.

    Only im_lq is inspected; im_gt is filtered with the same indices so that
    the pairs stay aligned.

    Input:
        im_lq, im_gt: b x c x h x w
    Output:
        im_lq, im_gt: the inputs themselves when im_lq holds no NaN, otherwise
                      new tensors with only the NaN-free samples, in their original order
        flag: True when samples were removed, False otherwise
    Raises:
        AssertionError: every sample of im_lq contains NaN
    '''
    nan_mask = torch.isnan(im_lq)
    if not nan_mask.any():
        return im_lq, im_gt, False

    # One reduction over the non-batch dims and one host transfer, instead of a
    # separate NaN test (and device sync) per sample.
    sample_has_nan = nan_mask.reshape(nan_mask.shape[0], -1).any(dim=1).tolist()
    valid_index = [ii for ii, has_nan in enumerate(sample_has_nan) if not has_nan]
    if not valid_index:
        # Explicit raise instead of `assert` so the check survives `python -O`.
        raise AssertionError('every sample in the batch contains NaN')

    im_lq = im_lq.contiguous()
    return im_lq[valid_index,], im_gt[valid_index,], True
|
|
|
def my_worker_init_fn(worker_id):
    '''
    Re-seed NumPy's global RNG with the first word of its current state plus worker_id.

    Presumably passed as `worker_init_fn` to a DataLoader so that each worker
    draws a different NumPy random stream -- confirm against the call site.

    param worker_id: integer offset added to the current base seed
    '''
    # Convert to a Python int first: adding to the raw uint32 either overflows
    # (NumPy >= 2) or yields a value np.random.seed rejects (>= 2**32).
    base_seed = int(np.random.get_state()[1][0])
    # np.random.seed only accepts integers in [0, 2**32 - 1]
    np.random.seed((base_seed + worker_id) % 2**32)
|
|
|
if __name__ == '__main__':
    # Visual check: blend two sample images with Gaussian noise at several
    # mixing weights and display the result as a grid.
    from utils import util_image
    from einops import rearrange

    sample_paths = (
        './testdata/inpainting/val/places/Places365_val_00012685_crop000.png',
        './testdata/inpainting/val/places/Places365_val_00014886_crop000.png',
    )
    im1, im2 = (
        util_image.imread(path, chn='rgb', dtype='float32')
        for path in sample_paths
    )
    # stack the two h x w x c images into one b x c x h x w batch
    im = rearrange(np.stack((im1, im2), 3), 'h w c b -> b c h w')

    # Draw the noise in the order 0.8, 0.4, 0.1, 0 ...
    noisy_versions = []
    for alpha in (0.8, 0.4, 0.1, 0):
        noisy_versions.append(im * alpha + np.random.randn(*im.shape) * (1 - alpha))
    # ... but lay the tiles out noisiest first, with the clean image last.
    im_grid = np.concatenate(noisy_versions[::-1] + [im], 1)

    im_grid = np.clip(im_grid, 0.0, 1.0)
    # split the 5 stacked versions back out of the channel axis
    im_grid = rearrange(im_grid, 'b (k c) h w -> (b k) c h w', k=5)
    grid_tensor = vutils.make_grid(
        torch.from_numpy(im_grid),
        nrow=5,
        normalize=True,
        scale_each=True,
    )
    xx = grid_tensor.numpy()
    util_image.imshow(np.concatenate((im1, im2), 0))
    util_image.imshow(xx.transpose((1, 2, 0)))
|
|
|
|