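"""Training script for a DeepFashion attribute/segmentation model.

Pass a YAML option file via -opt, e.g. (paths illustrative):
    python train.py -opt options/train.yml
"""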
import argparse
import logging
import os
import os.path as osp
import time

import torch

from data.segm_attr_dataset import DeepFashionAttrSegmDataset
from models import create_model
from utils.logger import MessageLogger, get_root_logger, init_tb_logger
from utils.options import dict2str, dict_to_nonedict, parse
from utils.util import make_exp_dirs


def main():
    # options
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt', type=str, help='Path to option YAML file.')
    args = parser.parse_args()
    opt = parse(args.opt, is_train=True)

    # mkdir and loggers
    make_exp_dirs(opt)
    log_file = osp.join(opt['path']['log'], f"train_{opt['name']}.log")
    logger = get_root_logger(
        logger_name='base', log_level=logging.INFO, log_file=log_file)
    logger.info(dict2str(opt))

    # initialize tensorboard logger
    tb_logger = None
    if opt['use_tb_logger'] and 'debug' not in opt['name']:
        tb_logger = init_tb_logger(log_dir='./tb_logger/' + opt['name'])

    # convert to NoneDict, which returns None for missing keys
    opt = dict_to_nonedict(opt)

    # set up data loaders
    train_dataset = DeepFashionAttrSegmDataset(
        img_dir=opt['train_img_dir'],
        segm_dir=opt['segm_dir'],
        pose_dir=opt['pose_dir'],
        ann_dir=opt['train_ann_file'],
        xflip=True)
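    # drop_last keeps every training batch at the full batch size;
    # persistent_workers avoids re-spawning worker processes each epoch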
    train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=opt['batch_size'],
        shuffle=True,
        num_workers=opt['num_workers'],
        persistent_workers=True,
        drop_last=True)
    logger.info(f'Number of train set: {len(train_dataset)}.')
    opt['max_iters'] = opt['num_epochs'] * len(
        train_dataset) // opt['batch_size']
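
    # the val/test sets share the segmentation and pose directories with
    # training; only the annotation files (and the test image dir) differ,
    # and inference runs with single-image batches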
    val_dataset = DeepFashionAttrSegmDataset(
        img_dir=opt['train_img_dir'],
        segm_dir=opt['segm_dir'],
        pose_dir=opt['pose_dir'],
        ann_dir=opt['val_ann_file'])
    val_loader = torch.utils.data.DataLoader(
        dataset=val_dataset, batch_size=1, shuffle=False)
    logger.info(f'Number of val set: {len(val_dataset)}.')

    test_dataset = DeepFashionAttrSegmDataset(
        img_dir=opt['test_img_dir'],
        segm_dir=opt['segm_dir'],
        pose_dir=opt['pose_dir'],
        ann_dir=opt['test_ann_file'])
    test_loader = torch.utils.data.DataLoader(
        dataset=test_dataset, batch_size=1, shuffle=False)
    logger.info(f'Number of test set: {len(test_dataset)}.')
    current_iter = 0
    best_epoch = None
    best_loss = float('inf')

    model = create_model(opt)

    # initialize timers to the current time so the first measured
    # data/iteration times are real elapsed intervals, not timestamps
    data_time, iter_time = time.time(), time.time()

    # create message logger (formatted outputs)
    msg_logger = MessageLogger(opt, current_iter, tb_logger)
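
    # main training loop: one optimizer step per batch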
    for epoch in range(opt['num_epochs']):
        lr = model.update_learning_rate(epoch)

        for batch_data in train_loader:
            data_time = time.time() - data_time

            current_iter += 1

            model.optimize_parameters(batch_data, current_iter)

            iter_time = time.time() - iter_time
            if current_iter % opt['print_freq'] == 0:
                log_vars = {'epoch': epoch, 'iter': current_iter}
                log_vars.update({'lrs': [lr]})
                log_vars.update({'time': iter_time, 'data_time': data_time})
                log_vars.update(model.get_current_log())
                msg_logger(log_vars)

            data_time = time.time()
            iter_time = time.time()
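
        # periodic evaluation: run inference on the val and test sets and
        # save visualizations for this epoch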
        if epoch % opt['val_freq'] == 0:
            save_dir = f'{opt["path"]["visualization"]}/valset/epoch_{epoch:03d}'  # noqa
            os.makedirs(save_dir, exist_ok=opt['debug'])
            val_loss_total = model.inference(val_loader, save_dir)

            save_dir = f'{opt["path"]["visualization"]}/testset/epoch_{epoch:03d}'  # noqa
            os.makedirs(save_dir, exist_ok=opt['debug'])
            test_loss_total = model.inference(test_loader, save_dir)

            logger.info(f'Epoch: {epoch}, '
                        f'val_loss_total: {val_loss_total}, '
                        f'test_loss_total: {test_loss_total}.')
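
            # track the epoch with the lowest total test loss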
            if test_loss_total < best_loss:
                best_epoch = epoch
                best_loss = test_loss_total

            logger.info(f'Best epoch: {best_epoch}, '
                        f'Best test loss: {best_loss:.4f}.')

            # save model
            model.save_network(f'{opt["path"]["models"]}/epoch{epoch}.pth')

if __name__ == '__main__':
    main()