import os
import cv2
import time
import random
import datetime
import argparse
import numpy as np
from tqdm import tqdm
from piq import ssim, psnr
from itertools import cycle

import torch
import torch.nn as nn
from torch.utils import data
import torch.distributed as dist
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP

from utils import dict2string, mkdir, get_lr, torch2cvimg, second2hours
from loaders import docres_loader
from models import restormer_arch

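# Usage sketch (an assumed launch command, not taken from this file): the script
# reads --local_rank, which matches the legacy torch.distributed.launch launcher, e.g.
#   python -m torch.distributed.launch --nproc_per_node=4 train.py \
#       --experiment_name my_run --batch_size 10 --tboard
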
def seed_torch(seed=1029):
    # Seed every RNG source (Python, NumPy, PyTorch CPU/GPU) and force
    # deterministic cuDNN kernels for reproducible runs.
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


def getBasecoord(h, w):
    # Return an (h, w, 2) float32 grid holding each pixel's (x, y) coordinate,
    # i.e. the identity mapping used as a base for coordinate regression.
    base_coord0 = np.tile(np.arange(h).reshape(h, 1), (1, w)).astype(np.float32)
    base_coord1 = np.tile(np.arange(w).reshape(1, w), (h, 1)).astype(np.float32)
    base_coord = np.concatenate((np.expand_dims(base_coord1, -1), np.expand_dims(base_coord0, -1)), -1)
    return base_coord

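# train() wires up the distributed process group, one loader per restoration
# task, the Restormer backbone, and the mixed-precision training loop.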
def train(args):

    # Distributed setup: one process per GPU, NCCL backend.
    dist.init_process_group(backend='nccl', init_method='env://', timeout=datetime.timedelta(seconds=36000))
    torch.cuda.set_device(args.local_rank)
    device = torch.device('cuda', args.local_rank)
    torch.cuda.manual_seed_all(42)

    # Log-file setup.
    mkdir(args.logdir)
    mkdir(os.path.join(args.logdir, args.experiment_name))
    log_file_path = os.path.join(args.logdir, args.experiment_name, 'log.txt')
    log_file = open(log_file_path, 'a')
    log_file.write('\n--------------- ' + args.experiment_name + ' ---------------\n')
    log_file.close()

    if args.tboard:
        writer = SummaryWriter(os.path.join(args.logdir, args.experiment_name, 'runs'), args.experiment_name)

    # One dataset per restoration task; 'ratio' weights how often each task is sampled.
    datasets_setting = [
        {'task': 'deblurring', 'ratio': 1, 'im_path': '/home/jiaxin/Training_Data/DocRes_data/train/deblurring/', 'json_paths': ['/home/jiaxin/Training_Data/DocRes_data/train/deblurring/tdd/train.json']},
        {'task': 'dewarping', 'ratio': 1, 'im_path': '/home/jiaxin/Training_Data/DocRes_data/train/dewarping/', 'json_paths': ['/home/jiaxin/Training_Data/DocRes_data/train/dewarping/doc3d/train_1_19.json']},
        {'task': 'binarization', 'ratio': 1, 'im_path': '/home/jiaxin/Training_Data/DocRes_data/train/binarization/', 'json_paths': ['/home/jiaxin/Training_Data/DocRes_data/train/binarization/train.json']},
        {'task': 'deshadowing', 'ratio': 1, 'im_path': '/home/jiaxin/Training_Data/DocRes_data/train/deshadowing/', 'json_paths': ['/home/jiaxin/Training_Data/DocRes_data/train/deshadowing/train.json']},
        {'task': 'appearance', 'ratio': 1, 'im_path': '/home/jiaxin/Training_Data/DocRes_data/train/appearance/', 'json_paths': ['/home/jiaxin/Training_Data/DocRes_data/train/appearance/trainv2.json']}
    ]

    ratios = [dataset_setting['ratio'] for dataset_setting in datasets_setting]
    datasets = [docres_loader.DocResTrainDataset(dataset=dataset_setting, img_size=args.im_size) for dataset_setting in datasets_setting]
    # Build each DataLoader once and keep a live iterator next to it, so the
    # StopIteration handler in the loop restarts the very loader it draws from.
    trainloaders = []
    for i in range(len(datasets)):
        loader = data.DataLoader(dataset=datasets[i], sampler=DistributedSampler(datasets[i]), batch_size=args.batch_size, num_workers=2, pin_memory=True, drop_last=True)
        trainloaders.append({'task': datasets_setting[i], 'loader': loader, 'iter_loader': iter(loader)})

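    # Restormer backbone. inp_channels=6 presumably pairs the 3-channel input
    # image with a 3-channel task prompt produced by the loader (an assumption,
    # not stated in this file); out_channels=3 covers the RGB / map outputs.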
    model = restormer_arch.Restormer(
        inp_channels=6,
        out_channels=3,
        dim=48,
        num_blocks=[2, 3, 3, 4],
        num_refinement_blocks=4,
        heads=[1, 2, 4, 8],
        ffn_expansion_factor=2.66,
        bias=False,
        LayerNorm_type='WithBias',
        dual_pixel_task=True
    )
    model = DDP(model.cuda(), device_ids=[args.local_rank], output_device=args.local_rank)

    optimizer = torch.optim.AdamW(model.parameters(), lr=args.l_rate, weight_decay=5e-4)

    # Cosine decay over the full run, annealing from l_rate down to 1e-6.
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.total_iter, eta_min=1e-6, last_epoch=-1)

    iter_start = 0
    if args.resume is not None:
        print("Loading model and optimizer from checkpoint '{}'".format(args.resume))
        checkpoint = torch.load(args.resume, map_location=device)
        model.load_state_dict(checkpoint['model_state'], strict=False)
        # Restore the optimizer as well, matching the keys saved below.
        optimizer.load_state_dict(checkpoint['optimizer_state'])
        iter_start = checkpoint['iters']
        print("Loaded checkpoint '{}' (iter {})".format(args.resume, iter_start))

    # AMP scaler plus the per-task loss functions.
    scaler = torch.cuda.amp.GradScaler()
    loss_dict = {}
    l2 = nn.MSELoss()
    l1 = nn.L1Loss()
    ce = nn.CrossEntropyLoss()
    bce = nn.BCEWithLogitsLoss()
    m = nn.Sigmoid()
    best = 0
    best_ce = 999

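    # Main loop: each iteration samples one task (weighted by the ratios above),
    # draws a batch from that task's loader, and optimizes that task's loss alone.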
    for iters in range(iter_start, args.total_iter):
        start_time = time.time()
        loader_index = random.choices(list(range(len(trainloaders))), ratios)[0]
        task = trainloaders[loader_index]['task']['task']

        try:
            in_im, gt_im = next(trainloaders[loader_index]['iter_loader'])
        except StopIteration:
            # The sampled loader is exhausted; restart it and draw again.
            trainloaders[loader_index]['iter_loader'] = iter(trainloaders[loader_index]['loader'])
            in_im, gt_im = next(trainloaders[loader_index]['iter_loader'])
        in_im = in_im.float().cuda()
        gt_im = gt_im.float().cuda()

        binarization_loss, appearance_loss, dewarping_loss, deblurring_loss, deshadowing_loss = 0, 0, 0, 0, 0
        with torch.cuda.amp.autocast():
            pred_im = model(in_im, task)
            if task == 'binarization':
                # Two-class segmentation: cross-entropy over the first two output channels.
                gt_im = gt_im.long()
                binarization_loss = ce(pred_im[:, :2, :, :], gt_im[:, 0, :, :])
                loss = binarization_loss
            elif task == 'dewarping':
                # L1 regression on the two-channel coordinate map.
                dewarping_loss = l1(pred_im[:, :2, :, :], gt_im[:, :2, :, :])
                loss = dewarping_loss
            elif task == 'appearance':
                appearance_loss = l1(pred_im, gt_im)
                loss = appearance_loss
            elif task == 'deblurring':
                deblurring_loss = l1(pred_im, gt_im)
                loss = deblurring_loss
            elif task == 'deshadowing':
                deshadowing_loss = l1(pred_im, gt_im)
                loss = deshadowing_loss

        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Only the sampled task produced a tensor loss; the rest stay 0 this step.
        loss_dict['dew_loss'] = dewarping_loss.item() if isinstance(dewarping_loss, torch.Tensor) else 0
        loss_dict['app_loss'] = appearance_loss.item() if isinstance(appearance_loss, torch.Tensor) else 0
        loss_dict['des_loss'] = deshadowing_loss.item() if isinstance(deshadowing_loss, torch.Tensor) else 0
        loss_dict['deb_loss'] = deblurring_loss.item() if isinstance(deblurring_loss, torch.Tensor) else 0
        loss_dict['bin_loss'] = binarization_loss.item() if isinstance(binarization_loss, torch.Tensor) else 0
        end_time = time.time()
        duration = end_time - start_time

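        # Periodic reporting: console + log file every 10 iters, TensorBoard
        # scalars if enabled, and a rank-0 checkpoint every 5000 iters.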
        if (iters + 1) % 10 == 0:
            log_line = 'iters [{}/{}] -- '.format(iters + 1, args.total_iter) + dict2string(loss_dict) + ' -- lr {:6f}'.format(get_lr(optimizer)) + ' -- remaining time {}'.format(second2hours(duration * (args.total_iter - iters)))
            print(log_line)

            if args.tboard:
                for key, value in loss_dict.items():
                    writer.add_scalar('Train ' + key + '/Iterations', value, iters + 1)

            with open(log_file_path, 'a') as f:
                f.write(log_line + '\n')

        if (iters + 1) % 5000 == 0:
            state = {'iters': iters + 1,
                     'model_state': model.state_dict(),
                     'optimizer_state': optimizer.state_dict()}
            os.makedirs(os.path.join(args.logdir, args.experiment_name), exist_ok=True)
            # Only rank 0 writes checkpoints, to avoid concurrent writes to the same file.
            if dist.get_rank() == 0:
                torch.save(state, os.path.join(args.logdir, args.experiment_name, "{}.pkl".format(iters + 1)))

        sched.step()

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Hyperparams')
    parser.add_argument('--im_size', nargs='?', type=int, default=256,
                        help='Height and width of the (square) input images')
    parser.add_argument('--total_iter', nargs='?', type=int, default=100000,
                        help='Total number of training iterations')
    parser.add_argument('--batch_size', nargs='?', type=int, default=10,
                        help='Batch size per GPU')
    parser.add_argument('--l_rate', nargs='?', type=float, default=2e-4,
                        help='Learning rate')
    parser.add_argument('--resume', nargs='?', type=str, default=None,
                        help='Path to a saved checkpoint to resume from')
    parser.add_argument('--logdir', nargs='?', type=str, default='./checkpoints/',
                        help='Directory for logs and checkpoints')
    parser.add_argument('--tboard', dest='tboard', action='store_true',
                        help='Enable TensorBoard visualization (off by default)')
    parser.add_argument('--local_rank', type=int, default=0, metavar='N',
                        help='Local process rank set by the distributed launcher')
    parser.add_argument('--experiment_name', nargs='?', type=str, default='experiment_name',
                        help='Name of this experiment')
    parser.set_defaults(tboard=False)
    args = parser.parse_args()

    train(args)