"""Evaluate a saved model checkpoint on the test split of a dataset.

Settings are read from a JSON file (``--config``) and merged over the
command-line arguments; the merged dict is then used to build the model,
load its checkpoint, and run ``model.eval_task`` on a test ``DataLoader``.
"""
import argparse
import copy
import json
import logging
import os
import sys

import numpy as np
import torch
import torch.multiprocessing
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import DataLoader

from utils import factory
from utils.data_manager import DataManager
from utils.toolkit import count_parameters

# Module-level side effect kept from the original script: tensors shared
# between DataLoader worker processes go through the file system strategy.
torch.multiprocessing.set_sharing_strategy('file_system')


def _set_device(args):
    """Replace ``args["device"]`` with a list of ``torch.device`` objects.

    Args:
        args: Settings dict. ``args["device"]`` must be an iterable of device
            ids; ``-1`` selects the CPU, any other id ``n`` selects
            ``cuda:n``. The dict is modified in place.
    """
    devices = []
    for device_id in args["device"]:
        # Ids loaded from JSON may be strings; accept "-1" as well as -1 so a
        # CPU request never turns into the invalid device string "cuda:-1".
        if device_id == -1 or str(device_id).strip() == "-1":
            devices.append(torch.device("cpu"))
        else:
            devices.append(torch.device("cuda:{}".format(device_id)))
    args["device"] = devices


def get_methods(object, spacing=20):  # noqa: A002 - name kept for callers
    """Print every callable attribute of ``object`` with a docstring preview.

    Debugging helper. Each output line is the attribute name left-justified
    to ``spacing`` columns, followed by the first 90 characters of its
    ``__doc__`` with runs of whitespace collapsed to single spaces.

    Args:
        object: Any object to inspect.
        spacing: Column width used to pad the attribute names.
    """
    method_names = []
    for attr_name in dir(object):
        try:
            if callable(getattr(object, attr_name)):
                method_names.append(str(attr_name))
        except Exception:
            # Best effort: attributes whose lookup raises are still listed;
            # the printing loop below reports them as failed.
            method_names.append(str(attr_name))

    for name in method_names:
        try:
            doc_preview = str(getattr(object, name).__doc__)[0:90]
            print(str(name.ljust(spacing)) + ' ' + ' '.join(doc_preview.split()))
        except Exception:
            print(name.ljust(spacing) + ' ' + ' getattr() failed')


def load_model(args):
    """Build the model named in ``args`` and load its checkpoint.

    Args:
        args: Settings dict. Reads ``"device"`` (converted in place by
            ``_set_device``), ``"model_name"`` and ``"checkpoint"``.

    Returns:
        The object returned by ``factory.get_model`` after
        ``load_checkpoint(args["checkpoint"])`` has been called on it.
    """
    _set_device(args)
    model = factory.get_model(args["model_name"], args)
    model.load_checkpoint(args["checkpoint"])
    return model


def evaluate(args):
    """Evaluate the checkpointed model on the test data and print the result.

    Side effects: creates the log directory, configures the root logger to
    write to both a ``.log`` file and stdout, seeds torch, and stores
    ``"logfilename"`` and ``"csv_name"`` in ``args`` (``"device"`` is also
    rewritten by ``load_model``).

    Args:
        args: Settings dict. Reads ``model_name``, ``dataset``, ``data``,
            ``init_cls``, ``increment``, ``prefix``, ``seed``,
            ``convnet_type``, ``batch_size``, ``device`` and ``checkpoint``.
    """
    # NOTE(review): args['data'] is the data folder path and is embedded in
    # the log directory name verbatim, so path separators in it create
    # nested directories. Kept as in the original.
    logs_name = "logs/{}/{}_{}/{}/{}".format(
        args["model_name"],
        args["dataset"],
        args['data'],
        args['init_cls'],
        args['increment'],
    )
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(logs_name, exist_ok=True)

    csv_name = "{}_{}_{}".format(
        args["prefix"],
        args["seed"],
        args["convnet_type"],
    )
    logfilename = "{}/{}".format(logs_name, csv_name)

    # 'logfilename' intentionally holds the log *directory*, as before.
    args['logfilename'] = logs_name
    args['csv_name'] = csv_name

    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s [%(filename)s] => %(message)s",
        handlers=[
            logging.FileHandler(filename=logfilename + ".log"),
            logging.StreamHandler(sys.stdout),
        ],
    )

    _set_random()
    print_args(args)

    model = load_model(args)
    data_manager = DataManager(
        args["dataset"],
        False,  # presumably the DataManager shuffle flag - TODO confirm
        args["seed"],
        args["init_cls"],
        args["increment"],
        path=args["data"],
    )
    test_dataset = data_manager.get_dataset(
        model.class_list, source="test", mode="test"
    )
    loader = DataLoader(
        test_dataset,
        batch_size=args['batch_size'],
        shuffle=True,
        num_workers=8,
    )

    # The two results are presumably the classifier accuracy and the
    # nearest-mean-of-exemplars accuracy - verify against eval_task.
    cnn_acc, nme_acc = model.eval_task(loader, group=1, mode="test")
    print(cnn_acc, nme_acc)


def main():
    """Parse the command line, merge in the JSON settings, and evaluate."""
    args = setup_parser().parse_args()
    param = load_json(args.config)
    args = vars(args)  # Converting argparse Namespace to a dict.
    args.update(param)  # Add parameters from json (they override CLI values).

    evaluate(args)


def load_json(settings_path):
    """Return the parsed contents of the JSON file at ``settings_path``."""
    with open(settings_path, encoding="utf-8") as data_file:
        return json.load(data_file)


def _set_random(seed=1):
    """Seed torch (CPU and all CUDA devices) and make cuDNN deterministic.

    Args:
        seed: Seed passed to the torch generators. Defaults to 1, the value
            that was previously hard-coded.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def setup_parser():
    """Build the command-line parser (``--config``, ``--data``, ``--checkpoint``)."""
    parser = argparse.ArgumentParser(description='Reproduce of multiple continual learning algorthms.')
    parser.add_argument('--config', type=str, default='./exps/finetune.json',
                        help='Json file of settings.')
    parser.add_argument('-d', '--data', type=str, help='Path of the data folder')
    parser.add_argument('-c', '--checkpoint', type=str, help='Path of checkpoint file if resume training')
    return parser


def print_args(args):
    """Log every ``key: value`` pair of ``args`` at INFO level."""
    for key, value in args.items():
        logging.info("%s: %s", key, value)


if __name__ == '__main__':
    main()