"""Evaluate facility-location (p-median) agents: exact solvers, heuristics,
metaheuristics, a genetic algorithm, and trained RL policies."""
import os
import pickle
import setproctitle
from absl import app, flags
import time
import random
from typing import Tuple, Union, Text

import numpy as np
import torch as th

import sys
import gymnasium
# Register gymnasium under the name "gym" so stable-baselines3 picks it up.
sys.modules["gym"] = gymnasium

from stable_baselines3 import PPO
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv, VecEnvWrapper
from stable_baselines3.common.callbacks import BaseCallback  # currently unused

from facility_location.agent.solver import PMPSolver
from facility_location.agent.ga import PMPGA
from facility_location.agent.heuristic import (
    HeuristicRandom, HeuristicGreedy, HeuristicFastInterchange)
from facility_location.agent.metaheuristic import TabuSearch, POPSTAR, VNS
from facility_location.env import EvalPMPEnv, MULTIPMP
from facility_location.utils import Config
from facility_location.agent import MaskedFacilityLocationActorCriticPolicy
from facility_location.utils.policy import get_policy_kwargs

import warnings
warnings.filterwarnings('ignore')

flags.DEFINE_string('cfg', None, 'Configuration file.')
flags.DEFINE_integer('global_seed', None,
                     'Used in env and weight initialization, does not impact action sampling.')
flags.DEFINE_string('root_dir', '/data2/suhongyuan/flp',
                    'Root directory for writing logs/summaries/checkpoints.')
flags.DEFINE_bool('tmp', False, 'Whether to use temporary storage.')
flags.DEFINE_enum('agent', None,
                  ['solver-gurobi', 'solver-gurobi-cmd', 'solver-pulp-cbc-cmd',
                   'solver-glpk-cmd', 'solver-mosek',
                   'heuristic-random', 'heuristic-greedy', 'heuristic-fastinterchange',
                   'metaheuristic-ts', 'metaheuristic-vns', 'metaheuristic-popstar',
                   'ga', 'ppo-random', 'rl-mlp', 'rl-gnn', 'rl-agnn'],
                  'Agent type.')
flags.DEFINE_string('model_path', None, 'Path to saved model to evaluate.')

FLAGS = flags.FLAGS

# Type aliases for all evaluable agents and for the non-RL baselines.
AGENT = Union[PMPSolver, HeuristicRandom, HeuristicGreedy, HeuristicFastInterchange,
              TabuSearch, VNS, POPSTAR, PMPGA, PPO]
BASELINE = Union[PMPSolver, HeuristicRandom, HeuristicGreedy, HeuristicFastInterchange,
                 TabuSearch, VNS, POPSTAR, PMPGA]


def get_model(cfg: Config,
              env: Union[VecEnvWrapper, DummyVecEnv, EvalPMPEnv],
              device: str) -> PPO:
    """Build a PPO model with the masked facility-location actor-critic policy."""
    policy_kwargs = get_policy_kwargs(cfg)
    model = PPO(MaskedFacilityLocationActorCriticPolicy, env, verbose=1,
                policy_kwargs=policy_kwargs, device=device)
    return model


def get_agent(cfg: Config,
              env: Union[VecEnvWrapper, DummyVecEnv, EvalPMPEnv],
              model_path: Text) -> AGENT:
    """Instantiate the agent named by cfg.agent."""
    if cfg.agent.startswith('solver'):
        if cfg.agent == 'solver-gurobi':
            agent = PMPSolver('GUROBI', env)
        elif cfg.agent == 'solver-gurobi-cmd':
            agent = PMPSolver('GUROBI_CMD', env)
        elif cfg.agent == 'solver-pulp-cbc-cmd':
            agent = PMPSolver('PULP_CBC_CMD', env)
        elif cfg.agent == 'solver-glpk-cmd':
            agent = PMPSolver('GLPK_CMD', env)
        elif cfg.agent == 'solver-mosek':
            agent = PMPSolver('MOSEK', env)
        else:
            raise ValueError(f'Agent {cfg.agent} not supported.')
    elif cfg.agent.startswith('heuristic'):
        if cfg.agent == 'heuristic-random':
            agent = HeuristicRandom(cfg.seed, env)
        elif cfg.agent == 'heuristic-greedy':
            agent = HeuristicGreedy(env)
        elif cfg.agent == 'heuristic-fastinterchange':
            agent = HeuristicFastInterchange(env)
        else:
            raise ValueError(f'Agent {cfg.agent} not supported.')
    elif cfg.agent.startswith('metaheuristic'):
        if cfg.agent == 'metaheuristic-ts':
            agent = TabuSearch(cfg, env)
        elif cfg.agent == 'metaheuristic-vns':
            agent = VNS(env)
        elif cfg.agent == 'metaheuristic-popstar':
            agent = POPSTAR(cfg, env)
        else:
            raise ValueError(f'Agent {cfg.agent} not supported.')
    elif cfg.agent == 'ga':
        agent = PMPGA(cfg, env)
    elif cfg.agent == 'ppo-random':
        # Untrained PPO baseline with the default multi-input policy.
        agent = PPO("MultiInputPolicy", env, verbose=1)
    elif cfg.agent in ['rl-mlp', 'rl-gnn', 'rl-agnn']:
        # Rebuild the policy for this env, then load the trained parameters.
        test_model = get_model(cfg, env, device='cuda:3')
        trained_model = PPO.load(model_path)
        test_model.set_parameters(trained_model.get_parameters())
        agent = test_model
    else:
        raise ValueError(f'Agent {cfg.agent} not supported.')
    return agent


def evaluate(agent: AGENT,
             env: Union[VecEnvWrapper, DummyVecEnv, EvalPMPEnv],
             num_cases: int,
             return_episode_rewards: bool):
    """Dispatch to the PPO or baseline evaluation loop."""
    if isinstance(agent, PPO):
        return evaluate_ppo(agent, env, num_cases,
                            return_episode_rewards=return_episode_rewards)
    else:
        return evaluate_baseline(agent, env, num_cases)


def evaluate_ppo(agent: PPO, env: EvalPMPEnv, num_cases: int,
                 return_episode_rewards: bool) -> Tuple[float, float]:
    # With return_episode_rewards=True, evaluate_policy returns per-episode
    # rewards and lengths; only the rewards are kept here.
    rewards, _ = evaluate_policy(agent, env, n_eval_episodes=num_cases,
                                 return_episode_rewards=return_episode_rewards)
    return rewards


def evaluate_baseline(agent: BASELINE, env: EvalPMPEnv, num_cases: int):
    """Solve each evaluation case with a non-RL agent and record its reward."""
    rewards = np.zeros(num_cases)
    for case_idx in range(num_cases):
        env.reset()
        solution = agent.solve()
        reward = env.evaluate(solution)
        rewards[case_idx] = reward
    return rewards


def calculate_gap(gurobi_obj, method_obj):
    """Relative optimality gap of a method's objectives against Gurobi's."""
    method_obj = np.array(method_obj)
    gap = (method_obj - gurobi_obj) / gurobi_obj
    mean_gap = np.mean(gap)
    std_gap = np.std(gap)
    return mean_gap, std_gap


def main(_):
    setproctitle.setproctitle('rl@suhy')
    th.manual_seed(FLAGS.global_seed)
    np.random.seed(FLAGS.global_seed)
    random.seed(FLAGS.global_seed)
    cfg = Config(FLAGS.cfg, FLAGS.global_seed, FLAGS.tmp, FLAGS.root_dir,
                 FLAGS.agent, model_path=FLAGS.model_path)

    # if cfg.eval_specs['region'] is None:
    #     eval_np = cfg.eval_specs['test_np']
    # else:
    #     eval_path = './data/{}/pkl'.format(cfg.eval_specs['region'])
    #     files = os.listdir(eval_path)
    #     eval_np = []
    #     for f in files:
    #         eval_np.append(tuple(map(int, f.split('.')[0].split('_'))))
    #     eval_np = sorted(eval_np, key=lambda x: (x[0], x[1]))

    eval_env = MULTIPMP(cfg)
    if cfg.agent in ['rl-mlp', 'rl-gnn', 'rl-agnn']:
        eval_env = Monitor(eval_env)
        eval_env = DummyVecEnv([lambda: eval_env])
        model_path = os.path.join(cfg.root_dir, 'output', FLAGS.model_path)
    else:
        model_path = None
    agent = get_agent(cfg, eval_env, model_path)

    start_time = time.time()
    episode_rewards = evaluate(agent, eval_env, 1, return_episode_rewards=True)
    eval_time = time.time() - start_time

    # if cfg.agent == 'solver-gurobi':
    #     pickle.dump(episode_rewards, open(f'gurobi_result/{n}_{p}.pkl', 'wb'))
    # else:
    #     gurobi_obj = pickle.load(open(f'gurobi_result/{n}_{p}.pkl', 'rb'))
    #     mean_gap, std_gap = calculate_gap(gurobi_obj, episode_rewards)
    #     print(f'\t mean gap: {mean_gap}')
    #     print(f'\t std gap: {std_gap}')
    print(f'\t reward: {episode_rewards}')
    print(f'\t time: {eval_time}')


if __name__ == '__main__':
    flags.mark_flags_as_required(['cfg', 'global_seed', 'agent'])
    app.run(main)
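
# Example invocation (illustrative sketch only; the script name, config path,
# and checkpoint name below are placeholders, not files shipped with this repo).
# --cfg, --global_seed, and --agent are required; --model_path is only needed
# for the trained RL agents (rl-mlp, rl-gnn, rl-agnn) and is resolved relative
# to <root_dir>/output:
#
#   python <this_script>.py \
#       --cfg <path/to/config.yaml> \
#       --global_seed 0 \
#       --agent rl-gnn \
#       --model_path <run_name>/<checkpoint>.zip
#
# Baselines such as --agent solver-gurobi or --agent heuristic-greedy can be
# run the same way without --model_path.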