import os | |
from .default import DefaultEngineConfig | |
class EngineConfig(DefaultEngineConfig): | |
def __init__(self, exp_name='default', model='AOTT'): | |
super().__init__(exp_name, model) | |
self.STAGE_NAME = 'PRE_DAV' | |
self.init_dir() | |
self.DATASETS = ['davis2017'] | |
self.TRAIN_TOTAL_STEPS = 50000 | |
pretrain_stage = 'PRE' | |
pretrain_ckpt = 'save_step_100000.pth' | |
self.PRETRAIN_FULL = True # if False, load encoder only | |
self.PRETRAIN_MODEL = os.path.join(self.DIR_ROOT, 'result', | |
self.EXP_NAME, pretrain_stage, | |
'ema_ckpt', pretrain_ckpt) | |