# NOTE(review): removed extraction artifacts (file-size line, commit hash,
# and a dump of line numbers from a source viewer) that were not part of the
# original Python file and would be syntax errors if executed.
import importlib
import sys
sys.path.append('.')
sys.path.append('..')
import torch
import torch.multiprocessing as mp
from networks.managers.evaluator import Evaluator
def main_worker(gpu, cfg, seq_queue=None, info_queue=None, enable_amp=False):
    """Run evaluation for a single GPU rank.

    Args:
        gpu: GPU index for this worker; also used as the evaluator's rank.
        cfg: engine configuration object (see configs.<stage>.EngineConfig).
        seq_queue: optional multiprocessing queue of sequences shared by workers.
        info_queue: optional multiprocessing queue for reporting back results.
        enable_amp: if True, run the evaluation loop under CUDA autocast
            (automatic mixed precision).
    """
    # Build the evaluating manager bound to this rank.
    evaluator = Evaluator(rank=gpu,
                          cfg=cfg,
                          seq_queue=seq_queue,
                          info_queue=info_queue)

    if not enable_amp:
        evaluator.evaluating()
        return

    # Mixed-precision path: wrap the whole evaluation loop in autocast.
    with torch.cuda.amp.autocast(enabled=True):
        evaluator.evaluating()
def main():
    """Parse command-line options, build the engine config, and launch evaluation.

    Spawns one worker process per GPU when --gpu_num > 1 (workers share a
    sequence queue and an info queue); otherwise runs a single worker in-process.
    """
    import argparse
    parser = argparse.ArgumentParser(description="Eval VOS")
    parser.add_argument('--exp_name', type=str, default='default')
    parser.add_argument('--stage', type=str, default='pre')
    parser.add_argument('--model', type=str, default='aott')

    # For the options below, a non-positive value means "keep the value
    # from the stage config" (see the override guards further down).
    parser.add_argument('--lstt_num', type=int, default=-1)
    parser.add_argument('--lt_gap', type=int, default=-1)
    parser.add_argument('--st_skip', type=int, default=-1)
    # FIX: default was the string '-1' (argparse only re-parsed it through
    # type=int by accident); use the int -1 like the sibling sentinels above.
    parser.add_argument('--max_id_num', type=int, default=-1)

    parser.add_argument('--gpu_id', type=int, default=0)
    parser.add_argument('--gpu_num', type=int, default=1)
    parser.add_argument('--ckpt_path', type=str, default='')
    parser.add_argument('--ckpt_step', type=int, default=-1)
    parser.add_argument('--dataset', type=str, default='')
    parser.add_argument('--split', type=str, default='')

    parser.add_argument('--ema', action='store_true')
    parser.set_defaults(ema=False)
    parser.add_argument('--flip', action='store_true')
    parser.set_defaults(flip=False)
    parser.add_argument('--ms', nargs='+', type=float, default=[1.])
    parser.add_argument('--max_resolution', type=float, default=480 * 1.3)
    parser.add_argument('--amp', action='store_true')
    parser.set_defaults(amp=False)
    args = parser.parse_args()

    # Load the stage-specific config module (e.g. configs.pre) and instantiate it.
    engine_config = importlib.import_module('configs.' + args.stage)
    cfg = engine_config.EngineConfig(args.exp_name, args.model)

    cfg.TEST_EMA = args.ema
    cfg.TEST_GPU_ID = args.gpu_id
    cfg.TEST_GPU_NUM = args.gpu_num

    # Only override config entries when the user supplied a real value.
    if args.lstt_num > 0:
        cfg.MODEL_LSTT_NUM = args.lstt_num
    if args.lt_gap > 0:
        cfg.TEST_LONG_TERM_MEM_GAP = args.lt_gap
    if args.st_skip > 0:
        cfg.TEST_SHORT_TERM_MEM_SKIP = args.st_skip
    if args.max_id_num > 0:
        cfg.MODEL_MAX_OBJ_NUM = args.max_id_num
    if args.ckpt_path != '':
        cfg.TEST_CKPT_PATH = args.ckpt_path
    if args.ckpt_step > 0:
        cfg.TEST_CKPT_STEP = args.ckpt_step
    if args.dataset != '':
        cfg.TEST_DATASET = args.dataset
    if args.split != '':
        cfg.TEST_DATASET_SPLIT = args.split

    cfg.TEST_FLIP = args.flip
    cfg.TEST_MULTISCALE = args.ms
    if cfg.TEST_MULTISCALE != [1.]:
        # Multi-scale evaluation: cap the short edge to prevent OOM.
        cfg.TEST_MAX_SHORT_EDGE = args.max_resolution
    else:
        # Default resolution setting of CFBI and AOT.
        cfg.TEST_MAX_SHORT_EDGE = None
    cfg.TEST_MAX_LONG_EDGE = args.max_resolution * 800. / 480.

    if args.gpu_num > 1:
        # Multi-GPU: spawn one worker per GPU, fed by shared queues.
        mp.set_start_method('spawn')
        seq_queue = mp.Queue()
        info_queue = mp.Queue()
        mp.spawn(main_worker,
                 nprocs=cfg.TEST_GPU_NUM,
                 args=(cfg, seq_queue, info_queue, args.amp))
    else:
        main_worker(0, cfg, enable_amp=args.amp)
# Script entry point: only launch evaluation when run directly, not on import.
if __name__ == '__main__':
    main()
|