yitianlian's picture
update demo
24be7a2
raw
history blame
1.11 kB
import glob
import importlib
import logging
import os.path as osp
# Automatically scan and import model modules: collect the base names of all
# files in the directory containing this file whose names end with
# '_model.py'.
model_folder = osp.dirname(osp.abspath(__file__))
# glob.escape keeps a folder path containing glob metacharacters ('[', '*',
# '?') from being misread as part of the pattern. The names are sorted so the
# import order -- and therefore which module wins in create_model() when two
# modules define the same class name -- does not depend on the filesystem's
# directory-listing order.
model_filenames = sorted(
    osp.splitext(osp.basename(path))[0]
    for path in glob.glob(osp.join(glob.escape(model_folder), '*_model.py'))
)
# Import all the model modules as submodules of the 'models' package.
# NOTE(review): this assumes the package is importable under the top-level
# name 'models' -- confirm for every entry point.
_model_modules = [
    importlib.import_module(f'models.{file_name}')
    for file_name in model_filenames
]
def create_model(opt):
    """Create a model instance from a configuration dict.

    The class named by ``opt['model_type']`` is looked up, in order, in each
    module listed in ``_model_modules``. The first module that defines it
    supplies the class, which is then instantiated with ``opt`` as its only
    argument.

    Args:
        opt (dict): Configuration. It contains:
            model_type (str): Name of the model class to instantiate.

    Returns:
        The newly created model instance.

    Raises:
        KeyError: If ``opt`` has no ``'model_type'`` entry.
        ValueError: If no scanned model module defines ``model_type``
            (including the case where no model modules were found at all).
    """
    model_type = opt['model_type']

    # Dynamic instantiation: take the class from the first module that has it.
    # Start from None so an empty module list falls through to the ValueError
    # below instead of raising UnboundLocalError.
    model_cls = None
    for module in _model_modules:
        model_cls = getattr(module, model_type, None)
        if model_cls is not None:
            break
    if model_cls is None:
        raise ValueError(f'Model {model_type} is not found.')

    model = model_cls(opt)
    logger = logging.getLogger('base')
    logger.info('Model [%s] is created.', model.__class__.__name__)
    return model