from networks.engines.aot_engine import AOTEngine, AOTInferEngine | |
from networks.engines.deaot_engine import DeAOTEngine, DeAOTInferEngine | |
def build_engine(name, phase='train', **kwargs):
    """Construct the engine selected by ``name`` and ``phase``.

    Args:
        name: Engine identifier, either ``'aotengine'`` or ``'deaotengine'``.
        phase: ``'train'`` selects ``AOTEngine`` / ``DeAOTEngine``;
            ``'eval'`` selects ``AOTInferEngine`` / ``DeAOTInferEngine``.
        **kwargs: Forwarded unchanged to the selected engine's constructor.

    Returns:
        A new instance of the selected engine class.

    Raises:
        NotImplementedError: If ``name`` or ``phase`` is not one of the
            supported values listed above.
    """
    # Validate the name first, then the phase, so the error message says
    # exactly which argument was rejected.
    if name not in ('aotengine', 'deaotengine'):
        raise NotImplementedError(
            "Unsupported engine name {!r}; expected 'aotengine' or "
            "'deaotengine'.".format(name))
    if phase not in ('train', 'eval'):
        raise NotImplementedError(
            "Unsupported phase {!r} for engine {!r}; expected 'train' or "
            "'eval'.".format(phase, name))

    # Both arguments are known-good here, so each branch has exactly two
    # possible outcomes.
    is_training = phase == 'train'
    if name == 'aotengine':
        engine_cls = AOTEngine if is_training else AOTInferEngine
    else:
        engine_cls = DeAOTEngine if is_training else DeAOTInferEngine
    return engine_cls(**kwargs)