|
import importlib |
|
import random |
|
import sys |
|
|
|
# NOTE(review): recursion limit raised well above the default (1000) —
# presumably some config/model construction recurses deeply; confirm.
sys.setrecursionlimit(10000)

# Make the repo importable whether the script is run from the project
# root or from one directory below it (no package install required).
sys.path.append('.')

sys.path.append('..')
|
|
|
import torch.multiprocessing as mp |
|
|
|
from networks.managers.trainer import Trainer |
|
|
|
|
|
def main_worker(gpu, cfg, enable_amp=True):
    """Run training in a single worker process.

    Args:
        gpu: Process rank / local GPU index for this worker (0 when
            training on a single GPU).
        cfg: Stage configuration object produced by the config module.
        enable_amp: Whether to train with automatic mixed precision.
    """
    Trainer(rank=gpu, cfg=cfg, enable_amp=enable_amp).sequential_training()
|
|
|
|
|
def main():
    """Parse command-line options, build the stage config, and launch training.

    Spawns one process per GPU via ``torch.multiprocessing`` when more than
    one GPU is configured; otherwise trains directly in the current process.
    Negative / empty argument values act as sentinels meaning "keep the
    value from the config file".
    """
    import argparse

    parser = argparse.ArgumentParser(description="Train VOS")
    parser.add_argument('--exp_name', type=str, default='')
    parser.add_argument('--stage', type=str, default='pre')
    parser.add_argument('--model', type=str, default='aott')
    # BUG FIX: default was the string '-1' on a type=int argument; argparse
    # happens to re-parse string defaults, but the int -1 is the correct form.
    parser.add_argument('--max_id_num', type=int, default=-1)

    parser.add_argument('--start_gpu', type=int, default=0)
    parser.add_argument('--gpu_num', type=int, default=-1)
    parser.add_argument('--batch_size', type=int, default=-1)
    parser.add_argument('--dist_url', type=str, default='')
    parser.add_argument('--amp', action='store_true')
    parser.set_defaults(amp=False)

    parser.add_argument('--pretrained_path', type=str, default='')

    parser.add_argument('--datasets', nargs='+', type=str, default=[])
    parser.add_argument('--lr', type=float, default=-1.)
    # BUG FIX: these two had float defaults (-1.) on type=int arguments;
    # argparse does not convert non-string defaults, so when the flag was
    # omitted args.total_step / args.start_step were floats. Use int -1.
    parser.add_argument('--total_step', type=int, default=-1)
    parser.add_argument('--start_step', type=int, default=-1)

    args = parser.parse_args()

    # Each training stage ('pre', ...) ships its own config module under
    # configs/, exposing an EngineConfig class.
    engine_config = importlib.import_module('configs.' + args.stage)
    cfg = engine_config.EngineConfig(args.exp_name, args.model)

    # Command-line overrides of config defaults (sentinel = keep config value).
    if len(args.datasets) > 0:
        cfg.DATASETS = args.datasets

    cfg.DIST_START_GPU = args.start_gpu
    if args.gpu_num > 0:
        cfg.TRAIN_GPUS = args.gpu_num
    if args.batch_size > 0:
        cfg.TRAIN_BATCH_SIZE = args.batch_size

    if args.pretrained_path != '':
        cfg.PRETRAIN_MODEL = args.pretrained_path

    if args.max_id_num > 0:
        cfg.MODEL_MAX_OBJ_NUM = args.max_id_num

    if args.lr > 0:
        cfg.TRAIN_LR = args.lr

    if args.total_step > 0:
        cfg.TRAIN_TOTAL_STEPS = args.total_step

    if args.start_step > 0:
        cfg.TRAIN_START_STEP = args.start_step

    if args.dist_url == '':
        # Pick a pseudo-random rendezvous port in [12300, 12399] so that
        # concurrent runs on the same machine are unlikely to collide.
        cfg.DIST_URL = 'tcp://127.0.0.1:123' + str(random.randint(0, 9)) + str(
            random.randint(0, 9))
    else:
        cfg.DIST_URL = args.dist_url

    if cfg.TRAIN_GPUS > 1:
        # One worker process per GPU; mp.spawn passes the rank as the
        # first positional argument to main_worker.
        mp.spawn(main_worker, nprocs=cfg.TRAIN_GPUS, args=(cfg, args.amp))
    else:
        cfg.TRAIN_GPUS = 1
        main_worker(0, cfg, args.amp)
|
|
|
# Script entry point: parse CLI arguments and launch training.
if __name__ == '__main__':

    main()
|
|