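# Test script: builds the DeepFashion attribute/segmentation test set and runs
# model inference over it, using settings from an option YAML file (-opt).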
import argparse | |
import logging | |
import os.path as osp | |
import random | |
import torch | |
from data.segm_attr_dataset import DeepFashionAttrSegmDataset | |
from models import create_model | |
from utils.logger import get_root_logger | |
from utils.options import dict2str, dict_to_nonedict, parse | |
from utils.util import make_exp_dirs, set_random_seed | |
def main():
    """Run inference over the test dataset described by an option file.

    Steps, in order:
      1. Parse the ``-opt`` command-line argument and load the options via
         ``parse(..., is_train=False)``.
      2. Create the experiment directories and a root logger that writes to
         ``test_<name>.log`` under ``opt['path']['log']``.
      3. Seed the random number generators, using ``opt['manual_seed']`` or,
         when that is missing, a random integer in [1, 10000].
      4. Build ``DeepFashionAttrSegmDataset`` and wrap it in a ``DataLoader``.
      5. Create the model and call ``model.inference`` with the loader and
         ``opt['path']['results_root']``.

    Raises:
        SystemExit: if ``-opt`` is not supplied (argparse usage error).
    """
    # options
    parser = argparse.ArgumentParser()
    # required=True: without it a missing -opt silently becomes None and is
    # forwarded to parse(), which then fails far from the real cause.
    parser.add_argument(
        '-opt', type=str, required=True, help='Path to option YAML file.')
    args = parser.parse_args()
    opt = parse(args.opt, is_train=False)

    # mkdir and loggers
    make_exp_dirs(opt)
    log_file = osp.join(opt['path']['log'], f"test_{opt['name']}.log")
    logger = get_root_logger(
        logger_name='base', log_level=logging.INFO, log_file=log_file)
    logger.info(dict2str(opt))

    # convert to NoneDict, which returns None for missing keys
    # (done after the lookups above, so 'path' and 'name' must be present)
    opt = dict_to_nonedict(opt)

    # random seed
    seed = opt['manual_seed']
    if seed is None:
        seed = random.randint(1, 10000)
    logger.info(f'Random seed: {seed}')
    set_random_seed(seed)

    test_dataset = DeepFashionAttrSegmDataset(
        img_dir=opt['test_img_dir'],
        segm_dir=opt['segm_dir'],
        pose_dir=opt['pose_dir'],
        ann_dir=opt['test_ann_file'])
    # shuffle=False keeps the sample order fixed across runs.
    test_loader = torch.utils.data.DataLoader(
        dataset=test_dataset, batch_size=4, shuffle=False)
    logger.info(f'Number of test set: {len(test_dataset)}.')

    model = create_model(opt)
    # The return value of inference() is intentionally discarded.
    _ = model.inference(test_loader, opt['path']['results_root'])
# Run the test pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()