File size: 4,501 Bytes
24be7a2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import argparse
import logging
import os
import os.path as osp
import random
import time

import torch

from data.segm_attr_dataset import DeepFashionAttrSegmDataset
from models import create_model
from utils.logger import MessageLogger, get_root_logger, init_tb_logger
from utils.options import dict2str, dict_to_nonedict, parse
from utils.util import make_exp_dirs


def main():
    """Run the training script end to end.

    Reads the path of a YAML option file from the ``-opt`` command-line
    argument, creates the experiment directories and loggers, builds the
    train/val/test data loaders and the model, then trains for
    ``opt['num_epochs']`` epochs.

    Whenever ``epoch % opt['val_freq'] == 0`` the model runs inference on
    the val and test loaders, both losses are logged, the epoch with the
    lowest test loss seen so far is tracked, and the network is saved to
    ``opt['path']['models']``.
    """
    # options
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt', type=str, help='Path to option YAML file.')
    args = parser.parse_args()
    opt = parse(args.opt, is_train=True)

    # mkdir and loggers
    make_exp_dirs(opt)
    log_file = osp.join(opt['path']['log'], f"train_{opt['name']}.log")
    logger = get_root_logger(
        logger_name='base', log_level=logging.INFO, log_file=log_file)
    logger.info(dict2str(opt))
    # initialize tensorboard logger (skipped for experiments named "debug")
    tb_logger = None
    if opt['use_tb_logger'] and 'debug' not in opt['name']:
        tb_logger = init_tb_logger(log_dir='./tb_logger/' + opt['name'])

    # convert to NoneDict, which returns None for missing keys
    opt = dict_to_nonedict(opt)

    # set up data loader
    train_dataset = DeepFashionAttrSegmDataset(
        img_dir=opt['train_img_dir'],
        segm_dir=opt['segm_dir'],
        pose_dir=opt['pose_dir'],
        ann_dir=opt['train_ann_file'],
        xflip=True)
    # DataLoader raises ValueError for persistent_workers=True when
    # num_workers is 0, so only enable it when workers are actually used.
    num_workers = opt['num_workers'] or 0
    train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=opt['batch_size'],
        shuffle=True,
        num_workers=num_workers,
        persistent_workers=num_workers > 0,
        drop_last=True)
    logger.info(f'Number of train set: {len(train_dataset)}.')
    # NOTE(review): with drop_last=True the final partial batch of every
    # epoch is skipped, so the number of iterations actually run can be
    # lower than this value -- confirm how max_iters is consumed.
    opt['max_iters'] = opt['num_epochs'] * len(
        train_dataset) // opt['batch_size']

    # NOTE(review): the val set reads images from opt['train_img_dir'];
    # presumably intentional (shared image folder) -- confirm.
    val_dataset = DeepFashionAttrSegmDataset(
        img_dir=opt['train_img_dir'],
        segm_dir=opt['segm_dir'],
        pose_dir=opt['pose_dir'],
        ann_dir=opt['val_ann_file'])
    val_loader = torch.utils.data.DataLoader(
        dataset=val_dataset, batch_size=1, shuffle=False)
    logger.info(f'Number of val set: {len(val_dataset)}.')

    test_dataset = DeepFashionAttrSegmDataset(
        img_dir=opt['test_img_dir'],
        segm_dir=opt['segm_dir'],
        pose_dir=opt['pose_dir'],
        ann_dir=opt['test_ann_file'])
    test_loader = torch.utils.data.DataLoader(
        dataset=test_dataset, batch_size=1, shuffle=False)
    logger.info(f'Number of test set: {len(test_dataset)}.')

    current_iter = 0
    best_epoch = None
    # Start from +inf so the first evaluated loss always becomes the best.
    best_loss = float('inf')

    model = create_model(opt)

    # create message logger (formatted outputs)
    msg_logger = MessageLogger(opt, current_iter, tb_logger)

    for epoch in range(opt['num_epochs']):
        lr = model.update_learning_rate(epoch)

        # Restart both timers at the top of every epoch so the first
        # iteration does not report a bogus duration (time since the Unix
        # epoch, or the time spent in the previous validation pass).
        data_start = iter_start = time.time()

        for batch_data in train_loader:
            # time spent waiting for the loader to produce this batch
            data_time = time.time() - data_start

            current_iter += 1

            model.optimize_parameters(batch_data, current_iter)

            # full iteration time: data loading + optimization step
            iter_time = time.time() - iter_start
            if current_iter % opt['print_freq'] == 0:
                log_vars = {'epoch': epoch, 'iter': current_iter}
                log_vars.update({'lrs': [lr]})
                log_vars.update({'time': iter_time, 'data_time': data_time})
                log_vars.update(model.get_current_log())
                msg_logger(log_vars)

            data_start = iter_start = time.time()

        if epoch % opt['val_freq'] == 0:
            # exist_ok follows opt['debug']: outside debug mode an already
            # existing output directory makes os.makedirs raise.
            save_dir = f'{opt["path"]["visualization"]}/valset/epoch_{epoch:03d}'  # noqa
            os.makedirs(save_dir, exist_ok=bool(opt['debug']))
            val_loss_total = model.inference(val_loader, save_dir)

            save_dir = f'{opt["path"]["visualization"]}/testset/epoch_{epoch:03d}'  # noqa
            os.makedirs(save_dir, exist_ok=bool(opt['debug']))
            test_loss_total = model.inference(test_loader, save_dir)

            logger.info(f'Epoch: {epoch}, '
                        f'val_loss_total: {val_loss_total}, '
                        f'test_loss_total: {test_loss_total}.')

            # The best epoch is selected on the test loss, not the val loss.
            if test_loss_total < best_loss:
                best_epoch = epoch
                best_loss = test_loss_total

            logger.info(f'Best epoch: {best_epoch}, '
                        f'Best test loss: {best_loss: .4f}.')

            # save model
            model.save_network(f'{opt["path"]["models"]}/epoch{epoch}.pth')


# Script entry point: start training only when this file is executed
# directly, not when it is imported as a module.
if __name__ == '__main__':
    main()