File size: 2,664 Bytes
c985ba4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import importlib
import random
import sys

sys.setrecursionlimit(10000)
sys.path.append('.')
sys.path.append('..')

import torch.multiprocessing as mp

from networks.managers.trainer import Trainer


def main_worker(gpu, cfg, enable_amp=True):
    # Initiate a training manager
    trainer = Trainer(rank=gpu, cfg=cfg, enable_amp=enable_amp)
    # Start Training
    trainer.sequential_training()


def main():
    import argparse
    parser = argparse.ArgumentParser(description="Train VOS")
    parser.add_argument('--exp_name', type=str, default='')
    parser.add_argument('--stage', type=str, default='pre')
    parser.add_argument('--model', type=str, default='aott')
    parser.add_argument('--max_id_num', type=int, default='-1')

    parser.add_argument('--start_gpu', type=int, default=0)
    parser.add_argument('--gpu_num', type=int, default=-1)
    parser.add_argument('--batch_size', type=int, default=-1)
    parser.add_argument('--dist_url', type=str, default='')
    parser.add_argument('--amp', action='store_true')
    parser.set_defaults(amp=False)

    parser.add_argument('--pretrained_path', type=str, default='')

    parser.add_argument('--datasets', nargs='+', type=str, default=[])
    parser.add_argument('--lr', type=float, default=-1.)
    parser.add_argument('--total_step', type=int, default=-1.)
    parser.add_argument('--start_step', type=int, default=-1.)

    args = parser.parse_args()

    engine_config = importlib.import_module('configs.' + args.stage)

    cfg = engine_config.EngineConfig(args.exp_name, args.model)

    if len(args.datasets) > 0:
        cfg.DATASETS = args.datasets

    cfg.DIST_START_GPU = args.start_gpu
    if args.gpu_num > 0:
        cfg.TRAIN_GPUS = args.gpu_num
    if args.batch_size > 0:
        cfg.TRAIN_BATCH_SIZE = args.batch_size

    if args.pretrained_path != '':
        cfg.PRETRAIN_MODEL = args.pretrained_path

    if args.max_id_num > 0:
        cfg.MODEL_MAX_OBJ_NUM = args.max_id_num

    if args.lr > 0:
        cfg.TRAIN_LR = args.lr

    if args.total_step > 0:
        cfg.TRAIN_TOTAL_STEPS = args.total_step

    if args.start_step > 0:
        cfg.TRAIN_START_STEP = args.start_step

    if args.dist_url == '':
        cfg.DIST_URL = 'tcp://127.0.0.1:123' + str(random.randint(0, 9)) + str(
            random.randint(0, 9))
    else:
        cfg.DIST_URL = args.dist_url
        
    if cfg.TRAIN_GPUS > 1:
        # Use torch.multiprocessing.spawn to launch distributed processes
        mp.spawn(main_worker, nprocs=cfg.TRAIN_GPUS, args=(cfg, args.amp))
    else:
        cfg.TRAIN_GPUS = 1
        main_worker(0, cfg, args.amp)

if __name__ == '__main__':
    main()