"""Example driver built on the DeepVisualInsight (DVI) PyTorch implementation.

The script reads ``config.json`` from the content path, builds a
``NormalDataProvider`` for the subject model, optionally runs the
preprocessing steps, and then saves the subject-model evaluation for every
epoch in the requested range.
"""
########################################################################################################################
#                                                          IMPORT                                                      #
########################################################################################################################
import torch
import sys
import os
import json
import time
import numpy as np
import argparse

from torch.utils.data import DataLoader
from torch.utils.data import WeightedRandomSampler
from umap.umap_ import find_ab_params

from singleVis.custom_weighted_random_sampler import CustomWeightedRandomSampler
from singleVis.SingleVisualizationModel import VisModel
from singleVis.losses import UmapLoss, ReconstructionLoss, TemporalLoss, DVILoss, SingleVisLoss, DummyTemporalLoss
from singleVis.edge_dataset import DVIDataHandler
from singleVis.trainer import DVITrainer,DVIALTrainer
from singleVis.eval.evaluator import Evaluator
from singleVis.data import NormalDataProvider
# from singleVis.spatial_edge_constructor import SingleEpochSpatialEdgeConstructor
from singleVis.spatial_skeleton_edge_constructor import ProxyBasedSpatialEdgeConstructor

from singleVis.projector import DVIProjector
from singleVis.utils import find_neighbor_preserving_rate

from trustVis.skeleton_generator import SkeletonGenerator,CenterSkeletonGenerator,HierarchicalClusteringProxyGenerator

########################################################################################################################
#                                                     DVI PARAMETERS                                                   #
########################################################################################################################
VIS_METHOD = "DVI"  # DeepVisualInsight; also the top-level key read from config.json


def _str2bool(value):
    """Convert a command-line token into a real boolean.

    ``argparse`` with ``type=bool`` calls ``bool("False")``, which is ``True``
    because any non-empty string is truthy. This parser maps the usual
    spellings explicitly and rejects anything it does not recognise.

    Args:
        value: The raw token from the command line (or an actual ``bool``).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ("1", "true", "t", "yes", "y", "on"):
        return True
    if token in ("0", "false", "f", "no", "n", "off", ""):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got {!r}".format(value))


def _resolve_attribute(root, dotted_name):
    """Return the attribute reached by following ``dotted_name`` from ``root``.

    Used instead of ``eval`` so that a name coming from config.json can only
    select an existing attribute rather than execute arbitrary code.
    """
    target = root
    for part in dotted_name.split("."):
        target = getattr(target, part)
    return target


########################################################################################################################
#                                                     LOAD PARAMETERS                                                  #
########################################################################################################################
parser = argparse.ArgumentParser(description='Process hyperparameters...')

# Default workspace: ./training_dynamic under the current working directory.
current_path = os.getcwd()
new_path = os.path.join(current_path, 'training_dynamic')

parser.add_argument('--content_path', type=str, default=new_path)
parser.add_argument('--start', type=int, default=1)
parser.add_argument('--end', type=int, default=3)
# parser.add_argument('--epoch_end', type=int)
parser.add_argument('--epoch_period', type=int, default=1)
parser.add_argument('--preprocess', type=int, default=0)
# Accepts a bare "--base" as well as "--base True" / "--base False".
parser.add_argument('--base', type=_str2bool, nargs='?', const=True, default=False)
args = parser.parse_args()

CONTENT_PATH = args.content_path
# Makes the "Model" package stored inside the content directory importable below.
sys.path.append(CONTENT_PATH)
with open(os.path.join(CONTENT_PATH, "config.json"), "r") as f:
    config = json.load(f)
config = config[VIS_METHOD]

# record output information
# now = time.strftime("%Y-%m-%d-%H_%M_%S", time.localtime(time.time()))
# sys.stdout = open(os.path.join(CONTENT_PATH, now+".txt"), "w")

SETTING = config["SETTING"]
CLASSES = config["CLASSES"]
DATASET = config["DATASET"]
PREPROCESS = config["VISUALIZATION"]["PREPROCESS"]
GPU_ID = config["GPU"]
# NOTE(review): the configured GPU id is immediately overridden with 0 - confirm this is intended.
GPU_ID = 0

# The epoch range is read from config.json and then superseded by the CLI
# arguments, which always carry a value because they have defaults.
EPOCH_START = config["EPOCH_START"]
EPOCH_END = config["EPOCH_END"]
EPOCH_PERIOD = config["EPOCH_PERIOD"]

EPOCH_START = args.start
EPOCH_END = args.end
EPOCH_PERIOD = args.epoch_period

# Training parameter (subject model)
TRAINING_PARAMETER = config["TRAINING"]
NET = TRAINING_PARAMETER["NET"]
LEN = TRAINING_PARAMETER["train_num"]

# Training parameter (visualization model)
VISUALIZATION_PARAMETER = config["VISUALIZATION"]
LAMBDA1 = VISUALIZATION_PARAMETER["LAMBDA1"]
LAMBDA2 = VISUALIZATION_PARAMETER["LAMBDA2"]
B_N_EPOCHS = VISUALIZATION_PARAMETER["BOUNDARY"]["B_N_EPOCHS"]
L_BOUND = VISUALIZATION_PARAMETER["BOUNDARY"]["L_BOUND"]
ENCODER_DIMS = VISUALIZATION_PARAMETER["ENCODER_DIMS"]
DECODER_DIMS = VISUALIZATION_PARAMETER["DECODER_DIMS"]

S_N_EPOCHS = VISUALIZATION_PARAMETER["S_N_EPOCHS"]
N_NEIGHBORS = VISUALIZATION_PARAMETER["N_NEIGHBORS"]
PATIENT = VISUALIZATION_PARAMETER["PATIENT"]
MAX_EPOCH = VISUALIZATION_PARAMETER["MAX_EPOCH"]

VIS_MODEL_NAME = 'proxy'  ### saved_as

v_base = args.base
if v_base:
    VIS_MODEL_NAME = 'v_base'

EVALUATION_NAME = VISUALIZATION_PARAMETER["EVALUATION_NAME"]

# Define hyperparameters
DEVICE = torch.device("cuda:{}".format(GPU_ID) if torch.cuda.is_available() else "cpu")

# Imported here (not at the top) because it relies on the sys.path entry added above.
import Model.model as subject_model

# Look the network factory up by name instead of eval()-ing a string built from config.json.
net = _resolve_attribute(subject_model, NET)()

########################################################################################################################
#                                                    TRAINING SETTING                                                  #
########################################################################################################################
# Define data_provider
data_provider = NormalDataProvider(CONTENT_PATH, net, EPOCH_START, EPOCH_END, EPOCH_PERIOD, device=DEVICE, epoch_name='Epoch', classes=CLASSES, verbose=1)
# The CLI flag supersedes the PREPROCESS value read from config.json.
PREPROCESS = args.preprocess
if PREPROCESS:
    data_provider._meta_data()
    if B_N_EPOCHS > 0:
        data_provider._estimate_boundary(LEN // 10, l_bound=L_BOUND)

# Define visualization models
model = VisModel(ENCODER_DIMS, DECODER_DIMS)

# Define Losses
negative_sample_rate = 5
min_dist = .1
_a, _b = find_ab_params(1.0, min_dist)
umap_loss_fn = UmapLoss(negative_sample_rate, DEVICE, _a, _b, repulsion_strength=1.0)
recon_loss_fn = ReconstructionLoss(beta=1.0)
single_loss_fn = SingleVisLoss(umap_loss_fn, recon_loss_fn, lambd=LAMBDA1)

# Define Projector
projector = DVIProjector(vis_model=model, content_path=CONTENT_PATH, vis_model_name=VIS_MODEL_NAME, device=DEVICE)

evaluator = Evaluator(data_provider, projector)

# File name passed to the evaluator; distinct from EVALUATION_NAME read from config.json.
Evaluation_NAME = 'subject_model_eval'
for i in range(EPOCH_START, EPOCH_END+1, EPOCH_PERIOD):
    evaluator.save_epoch_eval_for_subject_model(i, file_name=Evaluation_NAME)