# AiOS / main.py
# ttxskk
# update
# d7e58f0
import argparse
import datetime
import json
import random
import time
from pathlib import Path
import os, sys
from util.get_param_dicts import get_param_dict
from util.logger import setup_logger
import numpy as np
import torch
import util.misc as utils
from detrsmpl.data.datasets import build_dataloader
from mmcv.parallel import MMDistributedDataParallel
from engine import evaluate, train_one_epoch, inference
from util.config import DictAction
from util.utils import ModelEma
import shutil
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter
import config.config as cfg
from datasets.dataset import MultipleDatasets
def get_args_parser():
    """Return the base argument parser for training / evaluation / inference.

    The parser is created with ``add_help=False`` so it can be used as a
    ``parents=[...]`` entry of another ``ArgumentParser``. Most model and
    optimisation hyper-parameters are not defined here; they come from the
    config file given with ``--config_file`` (optionally overridden via
    ``--options key=value ...``).
    """
    parser = argparse.ArgumentParser('Set transformer detector',
                                     add_help=False)
    parser.add_argument('--config_file', '-c', type=str, required=True)
    parser.add_argument(
        '--options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file.')

    # training parameters
    parser.add_argument('--output_dir',
                        default='',
                        help='path where to save, empty for no saving')
    parser.add_argument('--device',
                        default='cuda',
                        help='device to use for training / testing')
    parser.add_argument('--seed', default=42, type=int)
    parser.add_argument('--resume', default='', help='resume from checkpoint')
    parser.add_argument('--pretrain_model_path',
                        help='load from other checkpoint')
    parser.add_argument('--finetune_ignore', type=str, nargs='+')
    parser.add_argument('--start_epoch',
                        default=0,
                        type=int,
                        metavar='N',
                        help='start epoch')
    parser.add_argument('--eval', action='store_true')
    parser.add_argument('--num_workers', default=0, type=int)
    parser.add_argument('--test', action='store_true')
    parser.add_argument('--debug', action='store_true')
    parser.add_argument('--find_unused_params', action='store_true')
    parser.add_argument('--save_log', action='store_true')
    parser.add_argument('--to_vid', action='store_true')
    parser.add_argument('--inference', action='store_true')

    # distributed training parameters
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--dist_url', default='env://',
                        help='url used to set up distributed training')
    # The help text previously duplicated --world_size's description.
    parser.add_argument('--rank', default=0, type=int,
                        help='rank of the current process among all '
                        'distributed processes')
    parser.add_argument('--local_rank', default=0, type=int,
                        help='local rank for DistributedDataParallel')
    parser.add_argument('--amp', action='store_true',
                        help='Train with mixed precision')
    parser.add_argument('--inference_input', default=None, type=str)
    return parser
def build_model_main(args, cfg):
    """Build the model and its companions through the model registry.

    Looks up ``args.modelname`` in ``models.registry.MODULE_BUILD_FUNCS`` and
    calls the registered builder with ``(args, cfg)``.

    Returns:
        ``(model, criterion, postprocessors, postprocessors_aios)`` exactly as
        returned by the registered build function.

    Raises:
        ValueError: if ``args.modelname`` is not registered.
    """
    print(args.modelname)
    from models.registry import MODULE_BUILD_FUNCS
    registered = MODULE_BUILD_FUNCS._module_dict
    # Explicit check instead of ``assert``: it still fires under ``python -O``
    # and tells the user which names are actually available.
    if args.modelname not in registered:
        raise ValueError('Unknown modelname {!r}; registered models: {}'.format(
            args.modelname, sorted(registered)))
    build_func = MODULE_BUILD_FUNCS.get(args.modelname)
    model, criterion, postprocessors, postprocessors_aios = build_func(
        args, cfg)
    return model, criterion, postprocessors, postprocessors_aios
def _import_dataset_class(name):
    """Return the attribute ``name`` of module ``datasets.<name>``.

    Equivalent to ``from datasets.<name> import <name>``. Unlike the previous
    ``exec``/``eval`` pair inside a function, it does not rely on both calls
    sharing one ``locals()`` dict; that stopped being true in Python 3.13
    (PEP 667).
    """
    import importlib
    return getattr(importlib.import_module('datasets.' + name), name)


def _build_lr_scheduler(args, optimizer, data_loader_train):
    """Create the LR scheduler selected by ``args``.

    ``OneCycleLR`` needs the number of iterations per epoch, which only exists
    when a training loader was built. In eval mode (``data_loader_train`` is
    ``None``) it cannot be constructed. The scheduler is never stepped or
    restored there, so ``None`` is returned instead of crashing.
    """
    if args.onecyclelr:
        if data_loader_train is None:
            return None
        return torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=args.lr,
            steps_per_epoch=len(data_loader_train),
            epochs=args.epochs,
            pct_start=0.2)
    if args.multi_step_lr:
        return torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=args.lr_drop_list)
    return torch.optim.lr_scheduler.StepLR(optimizer, args.lr_drop)


def _load_pretrained_weights(args, model_without_ddp, logger):
    """Non-strictly load ``args.pretrain_model_path``'s ``'model'`` entry.

    Any parameter whose name contains one of the substrings in
    ``args.finetune_ignore`` is dropped before loading. The dropped names and
    the ``load_state_dict`` result (missing / unexpected keys) are logged.
    """
    from collections import OrderedDict
    state_dict = torch.load(args.pretrain_model_path,
                            map_location='cpu')['model']
    ignore_keywords = args.finetune_ignore if args.finetune_ignore else []
    kept = OrderedDict()
    ignored = []
    for key, value in utils.clean_state_dict(state_dict).items():
        if any(keyword in key for keyword in ignore_keywords):
            ignored.append(key)
        else:
            kept[key] = value
    logger.info('Ignore keys: {}'.format(json.dumps(ignored, indent=2)))
    load_output = model_without_ddp.load_state_dict(kept, strict=False)
    print('loading')
    logger.info(str(load_output))


def main(args):
    """Set up config, model and data, then either evaluate/infer or train.

    ``args`` is the parsed command line. Every config-file key that is not
    already an attribute of ``args`` is copied onto it. Later code reads
    hyper-parameters such as ``lr``, ``epochs`` and ``batch_size`` from
    ``args`` that way.
    """
    utils.init_distributed_mode(args)
    print('Loading config file from {}'.format(args.config_file))
    # The chosen config is copied over config/aios_smplx.py before importing
    # config.config, presumably because that module loads this fixed path.
    # NOTE(review): config.config is already imported at module level
    # (``import config.config as cfg``). If it reads the file at import time,
    # this copy is not picked up in the current process — confirm.
    shutil.copy2(args.config_file, 'config/aios_smplx.py')
    from config.config import cfg
    if args.options is not None:
        cfg.merge_from_dict(args.options)
    if args.rank == 0:
        save_cfg_path = os.path.join(args.output_dir, 'config_cfg.py')
        cfg.dump(save_cfg_path)
        save_json_path = os.path.join(args.output_dir, 'config_args_raw.json')
        with open(save_json_path, 'w') as f:
            json.dump(vars(args), f, indent=2)

    # Merge config values into args. Command-line args take precedence:
    # keys that already exist on args are left untouched.
    cfg_dict = cfg._cfg_dict.to_dict()
    args_vars = vars(args)
    for k, v in cfg_dict.items():
        if k not in args_vars:
            setattr(args, k, v)

    # update some new args temporally
    if not getattr(args, 'use_ema', None):
        args.use_ema = False
    if not getattr(args, 'debug', None):
        args.debug = False

    # setup logger
    if args.output_dir:
        # os.makedirs('') raises, and the default --output_dir is ''.
        os.makedirs(args.output_dir, exist_ok=True)
    logger = setup_logger(output=os.path.join(args.output_dir, 'info.txt'),
                          distributed_rank=args.rank,
                          color=False,
                          name='detr')
    logger.info('git:\n {}\n'.format(utils.get_sha()))
    logger.info('Command: ' + ' '.join(sys.argv))
    writer = None
    if args.rank == 0:
        writer = SummaryWriter(args.output_dir)
        save_json_path = os.path.join(args.output_dir, 'config_args_all.json')
        with open(save_json_path, 'w') as f:
            json.dump(vars(args), f, indent=2)
        logger.info('Full config saved to {}'.format(save_json_path))
    logger.info('world size: {}'.format(args.world_size))
    logger.info('rank: {}'.format(args.rank))
    logger.info('local_rank: {}'.format(args.local_rank))
    logger.info('args: ' + str(args) + '\n')

    if args.frozen_weights is not None:
        assert args.masks, 'Frozen training is meant for segmentation only'
    device = torch.device(args.device)

    # fix the seed for reproducibility (offset per rank)
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    # build model
    model, criterion, postprocessors, _ = build_model_main(args, cfg)
    wo_class_error = False
    model.to(device)

    # ema
    ema_m = ModelEma(model, args.ema_decay) if args.use_ema else None

    model_without_ddp = model
    if args.distributed:
        model = MMDistributedDataParallel(
            model,
            device_ids=[args.gpu],
            find_unused_parameters=args.find_unused_params)
        model_without_ddp = model.module
    n_parameters = sum(p.numel() for p in model.parameters()
                       if p.requires_grad)
    logger.info('number of params:' + str(n_parameters))
    logger.info('params:\n' + json.dumps(
        {n: p.numel()
         for n, p in model.named_parameters() if p.requires_grad},
        indent=2))

    param_dicts = get_param_dict(args, model_without_ddp)
    optimizer = torch.optim.AdamW(param_dicts,
                                  lr=args.lr,
                                  weight_decay=args.weight_decay)

    logger.info('Creating dataset...')
    data_loader_train = None
    if not args.eval:
        # Dataset class names come from the config; each lives in a module of
        # the same name under the ``datasets`` package.
        trainset = [
            _import_dataset_class(name)(transforms.ToTensor(), 'train')
            for name in cfg.trainset_partition
        ]
        trainset_loader = MultipleDatasets(trainset,
                                           make_same_len=False,
                                           partition=cfg.trainset_partition)
        data_loader_train = build_dataloader(
            trainset_loader,
            args.batch_size,
            # NOTE(review): passes 0 workers when ``workers_per_gpu`` IS
            # configured (ignoring its value) and 1 otherwise. This looks
            # inverted, and ``--num_workers`` is never consulted — confirm.
            0 if 'workers_per_gpu' in args else 1,
            dist=args.distributed)

    testset_cls = _import_dataset_class(cfg.testset)
    if not args.inference:
        dataset_val = testset_cls(transforms.ToTensor(), 'test')
    else:
        dataset_val = testset_cls(args.inference_input, args.output_dir)
    data_loader_val = build_dataloader(
        dataset_val,
        args.batch_size,
        0 if 'workers_per_gpu' in args else 2,
        dist=args.distributed,
        shuffle=False)

    lr_scheduler = _build_lr_scheduler(args, optimizer, data_loader_train)

    if args.frozen_weights is not None:
        checkpoint = torch.load(args.frozen_weights, map_location='cpu')
        model_without_ddp.detr.load_state_dict(checkpoint['model'])

    output_dir = Path(args.output_dir)
    # Auto-resume: an existing checkpoint in output_dir overrides --resume.
    if os.path.exists(os.path.join(args.output_dir, 'checkpoint.pth')):
        args.resume = os.path.join(args.output_dir, 'checkpoint.pth')
    if args.resume:
        if args.resume.startswith('https'):
            checkpoint = torch.hub.load_state_dict_from_url(args.resume,
                                                            map_location='cpu',
                                                            check_hash=True)
        else:
            checkpoint = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        if args.use_ema:
            if 'ema_model' in checkpoint:
                ema_m.module.load_state_dict(
                    utils.clean_state_dict(checkpoint['ema_model']))
            else:
                del ema_m
                ema_m = ModelEma(model, args.ema_decay)
        if (not args.eval and 'optimizer' in checkpoint
                and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint):
            optimizer.load_state_dict(checkpoint['optimizer'])
            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
            args.start_epoch = checkpoint['epoch'] + 1

    if (not args.resume) and args.pretrain_model_path:
        _load_pretrained_weights(args, model_without_ddp, logger)
        if args.use_ema:
            # Only the checkpoint's 'model' entry is read above, so there are
            # no EMA weights to restore. Re-seed the EMA from the freshly
            # loaded model instead.
            del ema_m
            ema_m = ModelEma(model, args.ema_decay)

    if args.eval:
        os.environ['EVAL_FLAG'] = 'TRUE'
        if args.inference_input is not None and args.inference:
            inference(model,
                      criterion,
                      postprocessors,
                      data_loader_val,
                      device,
                      args.output_dir,
                      wo_class_error=wo_class_error,
                      args=args)
        else:
            cfg.result_dir = args.output_dir
            cfg.exp_name = args.pretrain_model_path
            evaluate(model,
                     criterion,
                     postprocessors,
                     data_loader_val,
                     device,
                     args.output_dir,
                     wo_class_error=wo_class_error,
                     args=args)
        if writer is not None:
            writer.close()
        return

    print('Start training')
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        epoch_start_time = time.time()
        train_stats = train_one_epoch(
            model,
            criterion,
            data_loader_train,
            optimizer,
            device,
            epoch,
            args.clip_max_norm,
            wo_class_error=wo_class_error,
            lr_scheduler=lr_scheduler,
            args=args,
            logger=(logger if args.save_log else None),
            ema_m=ema_m,
            tf_writer=writer)
        # OneCycleLR is not stepped per epoch here; presumably
        # train_one_epoch (which receives the scheduler) steps it per iteration.
        if not args.onecyclelr:
            lr_scheduler.step()

        if args.output_dir:
            checkpoint_paths = [output_dir / 'checkpoint.pth']
            # Extra numbered snapshot when (epoch + 1) is a multiple of
            # lr_drop or of save_checkpoint_interval.
            if ((epoch + 1) % args.lr_drop == 0
                    or (epoch + 1) % args.save_checkpoint_interval == 0):
                checkpoint_paths.append(output_dir /
                                        f'checkpoint{epoch:04}.pth')
            weights = {
                'model': model_without_ddp.state_dict(),
                'optimizer': optimizer.state_dict(),
                'lr_scheduler': lr_scheduler.state_dict(),
                'epoch': epoch,
                'args': args,
            }
            if args.use_ema:
                weights['ema_model'] = ema_m.module.state_dict()
            for checkpoint_path in checkpoint_paths:
                utils.save_on_master(weights, checkpoint_path)

        log_stats = {f'train_{k}': v for k, v in train_stats.items()}
        log_stats['epoch'] = epoch
        log_stats['n_parameters'] = n_parameters
        log_stats['now_time'] = str(datetime.datetime.now())
        epoch_time = time.time() - epoch_start_time
        log_stats['epoch_time'] = str(
            datetime.timedelta(seconds=int(epoch_time)))
        if args.output_dir and utils.is_main_process():
            with (output_dir / 'log.txt').open('a') as f:
                f.write(json.dumps(log_stats) + '\n')

    if writer is not None:
        writer.close()
    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
if __name__ == '__main__':
    # NOTE(review): placeholder ``__spec__`` string, presumably a workaround
    # for tools (e.g. debuggers) that leave ``__main__.__spec__`` unusable
    # when worker processes are spawned — confirm before removing.
    __spec__ = "ModuleSpec(name='builtins', loader=<class '_frozen_importlib.BuiltinImporter'>)"
    parser = argparse.ArgumentParser(
        'DETR training and evaluation script',
        parents=[get_args_parser()],
    )
    args = parser.parse_args()
    # Create the output directory up front; an empty value means "no saving".
    if args.output_dir:
        Path(args.output_dir).mkdir(exist_ok=True, parents=True)
    main(args)