"""Evaluate a pre-trained PPO agent on facility-location (p-median) instances.

Loads a saved PPO policy, wraps the evaluation environment, runs one
evaluation episode, and prints the wall-clock evaluation time.
"""
import os
import pickle
import random
import sys
import time
import warnings
from typing import Tuple, Union, Text

import numpy as np
import setproctitle
import torch as th
from absl import app, flags

import gymnasium

# stable-baselines3 still imports the legacy `gym` module name; alias it to
# gymnasium BEFORE any stable_baselines3 import, or those imports fail.
sys.modules["gym"] = gymnasium

from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv, VecEnvWrapper

from facility_location.agent import MaskedFacilityLocationActorCriticPolicy
from facility_location.agent.solver import PMPSolver
from facility_location.env import EvalPMPEnv, MULTIPMP
from facility_location.utils import Config
from facility_location.utils.policy import get_policy_kwargs

warnings.filterwarnings('ignore')

# Any agent this script can evaluate: a classic solver or a trained PPO policy.
AGENT = Union[PMPSolver, PPO]

# Environment types accepted by the helpers below.
ENV = Union[VecEnvWrapper, DummyVecEnv, EvalPMPEnv]

FLAGS = flags.FLAGS
flags.DEFINE_string('data_npy', None, 'Path to the problem-instance data (.npy).')
flags.DEFINE_bool('boost', False, 'Enable the boost mode of the MULTIPMP environment.')


def get_model(cfg: Config, env: ENV, device: str) -> PPO:
    """Build a fresh (untrained) PPO model with the project policy.

    Args:
        cfg: experiment configuration; consumed by ``get_policy_kwargs``.
        env: environment the model is bound to.
        device: torch device string, e.g. ``'cuda:0'`` or ``'cpu'``.

    Returns:
        An initialized but untrained PPO model.
    """
    policy_kwargs = get_policy_kwargs(cfg)
    return PPO(
        MaskedFacilityLocationActorCriticPolicy,
        env,
        verbose=1,
        policy_kwargs=policy_kwargs,
        device=device,
    )


def get_agent(cfg: Config, env: ENV, model_path: Text,
              device: str = 'cuda:0') -> AGENT:
    """Construct an evaluation agent and load trained weights into it.

    A fresh model is created for ``env`` and the parameters of the model
    stored at ``model_path`` are copied into it, so the loaded policy is
    evaluated against the current environment rather than the one it was
    trained with.

    Args:
        cfg: experiment configuration; ``cfg.agent`` selects the agent type.
        env: evaluation environment.
        model_path: path to the trained PPO checkpoint (.zip).
        device: torch device for the fresh model (default ``'cuda:0'``,
            matching the previous hard-coded value).

    Returns:
        The agent ready for evaluation.

    Raises:
        ValueError: if ``cfg.agent`` is not one of the supported RL agents.
    """
    if cfg.agent in ('rl-mlp', 'rl-gnn', 'rl-agnn'):
        test_model = get_model(cfg, env, device=device)
        trained_model = PPO.load(model_path)
        test_model.set_parameters(trained_model.get_parameters())
        return test_model
    raise ValueError(f'Agent {cfg.agent} not supported.')


def evaluate(agent: AGENT, env: ENV, num_cases: int,
             return_episode_rewards: bool):
    """Dispatch evaluation to the handler matching the agent type.

    Args:
        agent: agent to evaluate (currently only PPO is supported).
        env: evaluation environment.
        num_cases: number of evaluation episodes.
        return_episode_rewards: if True, per-episode rewards are returned
            instead of the mean.

    Raises:
        ValueError: if the agent type has no evaluation handler.
    """
    if isinstance(agent, PPO):
        return evaluate_ppo(agent, env, num_cases,
                            return_episode_rewards=return_episode_rewards)
    raise ValueError(f'Agent {agent} not supported.')


def evaluate_ppo(agent: PPO, env: EvalPMPEnv, num_cases: int,
                 return_episode_rewards: bool) -> Union[float, list]:
    """Evaluate a PPO agent with stable-baselines3's ``evaluate_policy``.

    Only the reward component of ``evaluate_policy``'s 2-tuple is returned:
    a list of per-episode rewards when ``return_episode_rewards`` is True,
    otherwise the mean reward (the discarded component is the episode
    lengths / reward std respectively).
    """
    rewards, _ = evaluate_policy(
        agent,
        env,
        n_eval_episodes=num_cases,
        return_episode_rewards=return_episode_rewards,
    )
    return rewards


def main(data_npy, boost=False):
    """Run one timed evaluation episode on the given problem data.

    Args:
        data_npy: problem-instance data passed to the MULTIPMP environment.
        boost: whether to enable the environment's boost mode.
    """
    # Fix all RNG seeds for a reproducible evaluation.
    th.manual_seed(0)
    np.random.seed(0)
    random.seed(0)

    model_path = './facility_location/best_model.zip'
    # NOTE(review): hard-coded experiment root — verify this path is valid
    # on the machine running the script.
    cfg = Config('plot', 0, False, '/data2/suhongyuan/flp', 'rl-gnn',
                 model_path=model_path)

    eval_env = MULTIPMP(cfg, data_npy, boost)
    eval_env = Monitor(eval_env)
    eval_env = DummyVecEnv([lambda: eval_env])

    agent = get_agent(cfg, eval_env, model_path)

    start_time = time.time()
    _ = evaluate(agent, eval_env, 1, return_episode_rewards=True)
    eval_time = time.time() - start_time
    print(f'\t time: {eval_time}')


def _absl_main(argv):
    """absl entry point: ``app.run`` passes argv, which ``main`` does not take.

    The original ``app.run(main)`` handed the raw argv list to ``main`` as
    ``data_npy``; this adapter forwards the parsed flags instead.
    """
    del argv  # Unused; inputs come from FLAGS.
    main(FLAGS.data_npy, FLAGS.boost)


if __name__ == '__main__':
    app.run(_absl_main)