# File size: 2,664 Bytes | c985ba4
import importlib
import random
import sys
# Raise the interpreter's recursion limit to 10000.
sys.setrecursionlimit(10000)
# Add the current and parent directories to the module search path.
sys.path.extend(['.', '..'])
import torch.multiprocessing as mp
from networks.managers.trainer import Trainer
def main_worker(gpu, cfg, enable_amp=True):
    """Create a training manager for one worker and start training.

    Args:
        gpu: Passed to ``Trainer`` as its ``rank``.
        cfg: Configuration object forwarded to ``Trainer``.
        enable_amp: Forwarded to ``Trainer`` as ``enable_amp``.
    """
    # Build the manager and start its sequential training in one step.
    Trainer(rank=gpu, cfg=cfg, enable_amp=enable_amp).sequential_training()
def main():
    """Parse command-line options, build the engine config, and launch training.

    The config is created by importing ``configs.<stage>`` and instantiating
    its ``EngineConfig`` with the experiment name and model. Options left at
    their sentinel default (``-1``, ``''`` or an empty list) do not override
    the value already held by the config; ``--start_gpu`` is always applied.

    If the resulting ``cfg.TRAIN_GPUS`` is greater than 1,
    ``cfg.TRAIN_GPUS`` worker processes are started with
    ``torch.multiprocessing.spawn``. Otherwise ``cfg.TRAIN_GPUS`` is set to 1
    and a single worker runs in the current process with rank 0.
    """
    import argparse

    parser = argparse.ArgumentParser(description="Train VOS")
    parser.add_argument('--exp_name', type=str, default='')

    parser.add_argument('--stage', type=str, default='pre')
    parser.add_argument('--model', type=str, default='aott')
    # Integer options use an int sentinel so the default matches ``type=int``.
    parser.add_argument('--max_id_num', type=int, default=-1)

    parser.add_argument('--start_gpu', type=int, default=0)
    parser.add_argument('--gpu_num', type=int, default=-1)
    parser.add_argument('--batch_size', type=int, default=-1)
    parser.add_argument('--dist_url', type=str, default='')
    # ``store_true`` already defaults to False when the flag is absent.
    parser.add_argument('--amp', action='store_true')

    parser.add_argument('--pretrained_path', type=str, default='')

    parser.add_argument('--datasets', nargs='+', type=str, default=[])
    parser.add_argument('--lr', type=float, default=-1.)
    parser.add_argument('--total_step', type=int, default=-1)
    parser.add_argument('--start_step', type=int, default=-1)

    args = parser.parse_args()

    # The stage name selects which module under ``configs`` provides the config.
    engine_config = importlib.import_module('configs.' + args.stage)
    cfg = engine_config.EngineConfig(args.exp_name, args.model)

    # Apply command-line overrides; sentinel values keep the config's setting.
    if args.datasets:
        cfg.DATASETS = args.datasets

    cfg.DIST_START_GPU = args.start_gpu

    if args.gpu_num > 0:
        cfg.TRAIN_GPUS = args.gpu_num

    if args.batch_size > 0:
        cfg.TRAIN_BATCH_SIZE = args.batch_size

    if args.pretrained_path != '':
        cfg.PRETRAIN_MODEL = args.pretrained_path

    if args.max_id_num > 0:
        cfg.MODEL_MAX_OBJ_NUM = args.max_id_num

    if args.lr > 0:
        cfg.TRAIN_LR = args.lr

    if args.total_step > 0:
        cfg.TRAIN_TOTAL_STEPS = args.total_step

    if args.start_step > 0:
        cfg.TRAIN_START_STEP = args.start_step

    if args.dist_url == '':
        # No URL given: use localhost with a random port in 12300-12399.
        cfg.DIST_URL = (
            f'tcp://127.0.0.1:123{random.randint(0, 9)}{random.randint(0, 9)}'
        )
    else:
        cfg.DIST_URL = args.dist_url

    if cfg.TRAIN_GPUS > 1:
        # Use torch.multiprocessing.spawn to launch distributed processes;
        # spawn passes the process index as the first argument of main_worker.
        mp.spawn(main_worker, nprocs=cfg.TRAIN_GPUS, args=(cfg, args.amp))
    else:
        cfg.TRAIN_GPUS = 1
        main_worker(0, cfg, args.amp)
# Launch training only when this file is executed as a script.
if __name__ == "__main__":
    main()
|