import argparse
import os
import sys

# Make the repo root importable (this script lives one level below it).
sys.path.insert(0, sys.path[0] + "/../")

import gym
import torch
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.tensorboard import SummaryWriter

import envs
import prompts as task_prompts
from envs.translator import InitSummarizer, CurrSummarizer, FutureSummarizer, Translator
from RL_based.utils import (
    Net_GRU_Bert_tianshou,
    Net_Bert_CLS_tianshou,
    Net_Bert_CNN_tianshou,
    Net_GRU_nn_emb_tianshou,
)
from tianshou.data import Collector, VectorReplayBuffer, ReplayBuffer
from tianshou.env import DummyVectorEnv, SubprocVectorEnv
from tianshou.policy import PPOPolicy, ICMPolicy
from tianshou.trainer import onpolicy_trainer
from tianshou.trainer.utils import test_episode
from tianshou.utils import WandbLogger
from tianshou.utils.net.common import ActorCritic
from tianshou.utils.net.discrete import Actor, Critic, IntrinsicCuriosityModule

import warnings
warnings.filterwarnings('ignore')


class MaxStepLimitWrapper(gym.Wrapper):
    """End an episode after a fixed number of steps."""

    def __init__(self, env, max_steps=200):
        super(MaxStepLimitWrapper, self).__init__(env)
        self.max_steps = max_steps
        self.current_step = 0

    def reset(self, **kwargs):
        self.current_step = 0
        return self.env.reset(**kwargs)

    def step(self, action):
        observation, reward, terminated, truncated, info = self.env.step(action)
        self.current_step += 1
        if self.current_step >= self.max_steps:
            # The Gym >= 0.26 convention is to flag a time limit via
            # `truncated`; this wrapper reports `terminated` instead, so the
            # episode counts as fully done once the cap is hit.
            terminated = True
            info['episode_step_limit'] = self.max_steps
        return observation, reward, terminated, truncated, info


class SimpleTextWrapper(gym.Wrapper):
    """Stringify observations so they can be fed to text encoders."""

    def __init__(self, env):
        super(SimpleTextWrapper, self).__init__(env)
        self.env = env

    def reset(self, **kwargs):
        observation, _ = self.env.reset(**kwargs)
        return str(observation), {}

    def step(self, action):
        observation, reward, terminated, truncated, info = self.env.step(action)
        return str(observation), reward, terminated, truncated, info
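# Illustrative usage of the two wrappers above (a sketch, not part of the
# training flow; "Taxi-v3" is just an example env id). SimpleTextWrapper sits
# innermost so the step cap is enforced on the already-stringified env,
# mirroring the taxi branch of make_env() below:
#
#   env = MaxStepLimitWrapper(SimpleTextWrapper(gym.make("Taxi-v3")), max_steps=200)
#   obs, info = env.reset()   # obs is a str
#   obs, reward, terminated, truncated, info = env.step(env.action_space.sample())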
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description='Train or evaluate a PPO policy on a text-translated gym environment.')
    parser.add_argument('--init_summarizer', type=str, required=True,
                        help='The name of the init summarizer to use.')
    parser.add_argument('--curr_summarizer', type=str, required=True,
                        help='The name of the curr summarizer to use.')
    parser.add_argument('--future_summarizer', type=str,
                        help='The name of the future summarizer to use.')
    parser.add_argument('--env', type=str, default='base_env',
                        help='The environment wrapper class to use (a key in envs.REGISTRY).')
    parser.add_argument('--env_name', type=str, default='CartPole-v1',
                        help='The name of the gym environment to use.')
    parser.add_argument('--decider', type=str, default='naive_actor',
                        help='The actor used to select actions.')
    parser.add_argument('--render', type=str, default='rgb_array',
                        help='The render mode.')
    parser.add_argument('--future_horizon', type=int,
                        help='The horizon for looking into the future.')
    parser.add_argument('--prompt_level', type=int, default=1,
                        help='The level of prompts.')
    parser.add_argument('--past_horizon', type=int,
                        help='The horizon for looking back.')
    parser.add_argument('--max_episode_len', type=int, default=200,
                        help='The maximum length of an episode.')
    ### RL training options
    parser.add_argument('--max_length', type=int, default=128,
                        help='The token length of the observation.')
    parser.add_argument('--trans_model_name', type=str, default='bert-base-uncased',
                        help='The name of the pretrained transformer to use.')
    parser.add_argument('--model_name', type=str, default='bert-embedding',
                        help='The name of the model to use.')
    parser.add_argument('--vector_env', type=str, default='dummy',
                        help='The vector env implementation to use ("dummy" or "subproc").')
    parser.add_argument('--eval', action='store_true', default=False,
                        help='Whether to only evaluate the model.')
    parser.add_argument('--policy-path', type=str, default=None,
                        help='The path to the policy to be evaluated.')
    parser.add_argument('--collect_one_episode', action='store_true', default=False,
                        help='Whether to only collect one episode.')
    parser.add_argument('--lr', type=float, default=0.0003,
                        help='The learning rate of the model.')
    parser.add_argument('--step_per_epoch', type=int, default=10000,
                        help='The number of steps per epoch.')
    parser.add_argument('--step_per_collect', type=int, default=2000,
                        help='The number of steps per collect.')
    parser.add_argument('--lr_decay', action='store_true', default=False,
                        help='Whether to decay the learning rate.')
    parser.add_argument('--epoch', type=int, default=400,
                        help='The number of epochs to train.')
    parser.add_argument('--resume_path', type=str, default=None,
                        help='The path to the policy to be resumed.')
    parser.add_argument('--taxi_specific_env', action='store_true', default=False,
                        help='Whether to use the taxi-specific env.')
    args = parser.parse_args()
    args_dict = vars(args)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Resolve the environment wrapper class and summarizers from the registries.
    env_class = envs.REGISTRY[args.env]
    init_summarizer = InitSummarizer(envs.REGISTRY[args.init_summarizer])
    curr_summarizer = CurrSummarizer(envs.REGISTRY[args.curr_summarizer])
    if args.future_summarizer:
        future_summarizer = FutureSummarizer(
            envs.REGISTRY[args.future_summarizer],
            envs.REGISTRY["cart_policies"],
            future_horizon=args.future_horizon,
        )
    else:
        future_summarizer = None

    wandb_log_config = {
        "env": args.env_name,
        "init_summarizer": args.init_summarizer,
        "curr_summarizer": args.curr_summarizer,
        "future_summarizer": args.future_summarizer,
    }
    wandb_log_config.update(args_dict)
    if not args.eval:
        logger = WandbLogger(
            project="LLM-decider-bench-RL",
            entity="llm-bench-team",
            config=wandb_log_config,
        )
        random_name = logger.wandb_run.name
        log_path = os.path.join('/home/ubuntu/LLM-Decider-Bench/RL_based/results',
                                args.env_name, random_name)
        writer = SummaryWriter(log_dir=log_path)
        writer.add_text("args", str(args))
        logger.load(writer)

        def save_best_fn(policy):
            torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

    sampling_env = envs.REGISTRY["sampling_wrapper"](gym.make(args.env_name))
    if args.prompt_level == 5:
        prompts_class = task_prompts.REGISTRY[(args.env_name, args.decider)]()
    else:
        prompts_class = task_prompts.REGISTRY[(args.decider)]()
    translator = Translator(
        init_summarizer, curr_summarizer, future_summarizer, env=sampling_env
    )
    if args.taxi_specific_env:
        environment = gym.make(args.env_name, render_mode=args.render)
    else:
        environment = env_class(
            gym.make(args.env_name, render_mode=args.render), translator
        )

    # Set the translation level; only level 1 (current observation only) is
    # implemented for RL training.
    translate_level = 1
    if args.past_horizon is None and args.future_horizon is None:
        translate_level = 1
    if args.past_horizon and args.future_horizon is None:
        raise NotImplementedError  # translate_level = 2
    if args.past_horizon is None and args.future_horizon:
        raise NotImplementedError  # translate_level = 3
    if args.past_horizon and args.future_horizon:
        raise NotImplementedError  # translate_level = 3.5

    if args.vector_env == 'dummy':
        ThisEnv = DummyVectorEnv
    elif args.vector_env == 'subproc':
        ThisEnv = SubprocVectorEnv
    else:
        raise ValueError(f"Unknown vector env: {args.vector_env}")

    def make_env():
        if args.taxi_specific_env:
            env = MaxStepLimitWrapper(
                SimpleTextWrapper(gym.make(args.env_name, render_mode=args.render)),
                max_steps=200,
            )
        else:
            env = env_class(
                MaxStepLimitWrapper(gym.make(args.env_name, render_mode=args.render),
                                    max_steps=200),
                translator,
            )
        env._max_episode_steps = args.max_episode_len
        return env
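    # DummyVectorEnv steps its sub-environments sequentially in-process, which
    # is simpler to debug; SubprocVectorEnv runs each environment in its own
    # worker process. With BERT-based encoders the policy forward pass usually
    # dominates per-step cost, so 'dummy' is often a reasonable default.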
    train_envs = ThisEnv([make_env for _ in range(20)])
    test_envs = ThisEnv([make_env for _ in range(10)])

    # model & optimizer
    def get_net():
        if args.model_name == "bert-embedding":
            net = Net_GRU_Bert_tianshou(state_shape=environment.observation_space.shape,
                                        hidden_sizes=[64, 64], device=device,
                                        max_length=args.max_length,
                                        trans_model_name=args.trans_model_name)
        elif args.model_name == "bert-CLS-embedding":
            net = Net_Bert_CLS_tianshou(state_shape=environment.observation_space.shape,
                                        hidden_sizes=[256, 128], device=device,
                                        max_length=args.max_length,
                                        trans_model_name=args.trans_model_name)
        elif args.model_name == "bert-CNN-embedding":
            net = Net_Bert_CNN_tianshou(state_shape=environment.observation_space.shape,
                                        hidden_sizes=[256, 128], device=device,
                                        max_length=args.max_length,
                                        trans_model_name=args.trans_model_name)
        elif args.model_name == "nn_embedding":
            net = Net_GRU_nn_emb_tianshou(hidden_sizes=[256, 128], device=device,
                                          max_length=args.max_length,
                                          trans_model_name=args.trans_model_name)
        else:
            raise ValueError(f"Unknown model name: {args.model_name}")
        return net

    net = get_net()
    actor = Actor(net, environment.action_space.n, device=device).to(device)
    critic = Critic(net, device=device).to(device)
    actor_critic = ActorCritic(actor, critic)
    optim = torch.optim.Adam(actor_critic.parameters(), lr=args.lr)

    # PPO policy
    dist = torch.distributions.Categorical
    lr_scheduler = None
    if args.lr_decay:
        # Decay the learning rate linearly to zero over the expected number of
        # policy updates. Note: --step_per_epoch and --step_per_collect only
        # feed this schedule; the trainer below uses hard-coded values.
        max_update_num = args.step_per_epoch // args.step_per_collect * args.epoch
        lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num)
    policy = PPOPolicy(actor, critic, optim, dist,
                       action_space=environment.action_space,
                       lr_scheduler=lr_scheduler).to(device)

    # collector
    train_collector = Collector(policy, train_envs,
                                VectorReplayBuffer(20000, len(train_envs)),
                                exploration_noise=True)
    test_collector = Collector(policy, test_envs, exploration_noise=True)

    if not args.eval:
        # trainer
        if args.resume_path:
            policy.load_state_dict(torch.load(args.resume_path, map_location=device))
            print("Loaded agent from:", args.resume_path)
        # Exercise train_collector and start filling the replay buffer.
        train_collector.collect(n_step=256 * 20)
        result = onpolicy_trainer(
            policy,
            train_collector,
            test_collector,
            max_epoch=args.epoch,
            step_per_epoch=50000,  # the number of transitions collected per epoch
            repeat_per_collect=4,
            episode_per_test=10,
            batch_size=256,
            logger=logger,
            step_per_collect=1000,  # the number of transitions collected before each network update
            save_best_fn=save_best_fn,
            # stop_fn=lambda mean_reward: mean_reward >= environment.spec.reward_threshold,
        )
        print(result)
    else:
        assert args.policy_path is not None, "--policy-path is required with --eval"
        policy.load_state_dict(torch.load(args.policy_path, map_location=device))
        test_collector = Collector(policy, test_envs)
        result = test_episode(policy, test_collector, None, None, n_episode=10)
        print(result)

    if args.collect_one_episode:
        # Roll out a single episode on the (non-vectorized) environment and
        # dump the sampled transitions to a text file.
        replaybuffer = ReplayBuffer(size=1000)
        test_collector_1 = Collector(policy, environment, replaybuffer)
        test_collector_1.reset_env()
        test_collector_1.reset_buffer()
        policy.eval()
        result = test_collector_1.collect(n_episode=1)
        output_path = f"/home/ubuntu/LLM-Decider-Bench/RL_based/checkpoints/{args.env_name}/output.txt"
        print('sample results', output_path)
        sample_result = replaybuffer.sample(0)
        with open(output_path, "w") as f:
            print(sample_result, file=f)
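# Example invocation (a sketch only: the script name and summarizer keys are
# placeholders; the real summarizer names are whatever is registered in
# envs.REGISTRY):
#
#   python train_rl.py \
#       --env base_env --env_name CartPole-v1 \
#       --init_summarizer <init_summarizer_key> \
#       --curr_summarizer <curr_summarizer_key> \
#       --model_name bert-embedding --vector_env dummy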