Jarvis-K committed on
Commit 2a33798 · 1 Parent(s): 64eba11
Files changed (50)
  1. README.md +97 -0
  2. RL_based/test_RL.sh +39 -0
  3. RL_based/train_PPO.py +251 -0
  4. RL_based/train_RL.sh +39 -0
  5. RL_based/utils.py +621 -0
  6. deciders/__init__.py +26 -0
  7. deciders/act.py +248 -0
  8. deciders/cot.py +147 -0
  9. deciders/jarvis.py +177 -0
  10. deciders/jarvis_without_insights.py +179 -0
  11. deciders/jarvis_without_shortmem.py +182 -0
  12. deciders/jarvis_without_suggestions.py +180 -0
  13. deciders/jarvis_without_summary.py +179 -0
  14. deciders/misc.py +21 -0
  15. deciders/pal.py +149 -0
  16. deciders/parser.py +53 -0
  17. deciders/reflexion.py +179 -0
  18. deciders/self_consistency.py +170 -0
  19. deciders/selfask.py +150 -0
  20. deciders/spp.py +142 -0
  21. deciders/utils.py +65 -0
  22. distillers/__init__.py +10 -0
  23. distillers/guidance_summary_few_shot_examples.txt +85 -0
  24. distillers/guider.py +144 -0
  25. distillers/raw_prompt_generator.py +16 -0
  26. distillers/reflexion_few_shot_examples.txt +75 -0
  27. distillers/self_reflection.py +53 -0
  28. distillers/traj_prompt_summarizer.py +46 -0
  29. distillers/traj_summary_few_shot_examples.txt +76 -0
  30. draw_overall_performance.py +59 -0
  31. environment.yml +193 -0
  32. envs/__init__.py +51 -0
  33. envs/base_env.py +97 -0
  34. envs/box2d/LunarLander_policies.py +36 -0
  35. envs/box2d/LunarLander_translator.py +67 -0
  36. envs/box2d/__init__.py +0 -0
  37. envs/box2d/few_shot_examples/lunarlander_l2.json +0 -0
  38. envs/box2d/few_shot_examples/lunarlander_l4.json +0 -0
  39. envs/classic_control/__init__.py +0 -0
  40. envs/classic_control/acrobot_policies.py +36 -0
  41. envs/classic_control/acrobot_translator.py +58 -0
  42. envs/classic_control/cartpole_policies.py +25 -0
  43. envs/classic_control/cartpole_translator.py +57 -0
  44. envs/classic_control/few_shot_examples/acrobot_l2.json +0 -0
  45. envs/classic_control/few_shot_examples/acrobot_l4.json +0 -0
  46. envs/classic_control/few_shot_examples/cartpole_l2.json +0 -0
  47. envs/classic_control/few_shot_examples/cartpole_l4.json +0 -0
  48. envs/classic_control/few_shot_examples/mountaincarContinuous_l2.json +0 -0
  49. envs/classic_control/few_shot_examples/mountaincarContinuous_l4.json +0 -0
  50. envs/classic_control/few_shot_examples/mountaincar_l2.json +0 -0
README.md ADDED
@@ -0,0 +1,97 @@
+ # Bench LLM Deciders with gym translators
+ This project provides a set of translators that convert OpenAI Gym environments into text-based environments. It is designed to investigate the capabilities of large language models in decision-making tasks within these text-based environments.
+
+ ## Summarizer Levels
+ We translate the game with basic-level descriptions: each translator gives a simple description of the current state of the game, suitable for beginners who are just getting familiar with the game.
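+
+ As a purely illustrative sketch (the exact wording is produced by each environment's translator and is not quoted here), a basic-level CartPole description might read:
+
+ ```
+ The cart is at position 0.02, moving right. The pole is tilted slightly to the left and rotating counter-clockwise.
+ ```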
+ ## Environment Categories
+ The environments are categorized by how much information is revealed to agents. We propose *5 levels* of scenarios.
+
+ **L1**: No external information is given, only the abstract game description. (zero-shot)
+
+ **L2**: Agents can use a sampled trajectory of a random policy as external knowledge. (few-shot, off-policy info)
+
+ **L3**: Agents sample trajectories themselves and update with feedback. (few-shot, on-policy info)
+
+ **L4**: Agents can use sampled trajectories of an expert policy. (few-shot, expert info)
+
+ **L5**: Expert teaching. (few-shot, expert info with guidance)
+
+ These five levels mainly concern decision-making from the current perception. Making use of a predicted future world is left to a stage-2 investigation.
+
+ **Perception and Future World**: These environments provide a perception of the current state and also predict future information. The future info is given in the info dict returned by `step` and `reset`.
+
+ Note that the past-memory part should be implemented as a component of the deciders.
+
+ ## Fewshot Examples Generation
+ For the `L1` level, an empty example list `[]` is given.
+ For the `L2` and `L4` levels, we use `gen_few_shots_examples.py` to generate the corresponding examples in JSON format and place them in `envs/*/few_shot_examples/`.
+ For the `L3` level, agents should collect the examples on their own, and only a few methods support this, so we leave it to the agent design.
+ For the `L5` level, we handcraft the few-shot examples with domain knowledge in `prompts/task_relevant`.
+
+ ## Usage
+
+ 1. Create `./deciders/gpt.py` to provide your `gpt` agent:
+ ```python
+ import openai
+
+ class gpt:
+     def __init__(self):
+         openai.api_type = "azure"
+         openai.api_version = "2023-05-15"
+         # Your Azure OpenAI resource's endpoint value.
+         openai.api_base = "https://js-partner.openai.azure.com/"
+         openai.api_key = "your azure openai key"
+ ```
+
+ 2. Install the requirements:
+
+ ```
+ conda env create --file environment.yml
+ ```
+
+ 3. Testing
+ The project can be run using the provided test.sh script. This script runs a series of commands, each of which initializes a Gym environment and applies a different translator to it.
+
+ Here is an example of how to run the script:
+
+ ```
+ ./test.sh
+ ```
+ The commands in test.sh are structured as follows (a concrete example is sketched after the parameter list below):
+
+ ```
+ python main.py --env_name ENV_NAME --init_summarizer INIT_SUMMARIZER --curr_summarizer CURR_SUMMARIZER [--future_summarizer FUTURE_SUMMARIZER --future_horizon FUTURE_HORIZON]
+ ```
+ Where:
+
+ * ENV_NAME: The name of the Gym environment to be used (e.g., CartPole-v0).
+ * INIT_SUMMARIZER: The initial summarizer to be used (e.g., cart_init_translator).
+ * CURR_SUMMARIZER: The current summarizer to be used (e.g., cart_basic_translator).
+ * FUTURE_SUMMARIZER (optional): The future summarizer to be used (e.g., cart_basic_translator).
+ * FUTURE_HORIZON (optional): The number of future steps each policy looks ahead (e.g., 3).
+
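+ For instance, combining the example values above into a concrete run (only the parameters documented above are used; check main.py for the full set of flags it accepts):
+
+ ```
+ python main.py --env_name CartPole-v0 --init_summarizer cart_init_translator --curr_summarizer cart_basic_translator
+ ```
+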
+ ## Supported Environment Translators and LLM Deciders
+
+ | | Acrobot | Cart Pole | Mountain Car | Pendulum | Lunar Lander | Blackjack | Taxi | Cliff Walking | Frozen Lake |
+ |------------------------------|:------------------------:|:----------------------------------:|:------------------------:|:------------------------:|:------------------------:|:------------------------:|:------------------------:|:------------------------:|:------------------------:|
+ | Translator | :heavy_multiplication_x: | :white_check_mark: | :heavy_multiplication_x: | :heavy_multiplication_x: | :white_check_mark: | :heavy_multiplication_x: | :heavy_multiplication_x: | :heavy_multiplication_x: | :heavy_multiplication_x: |
+ | Chain-of-Thought | :heavy_minus_sign: | :white_check_mark:(L1)<br>:gift:<sup>[1]</sup>(~30) | :heavy_minus_sign: | :heavy_minus_sign: | :white_check_mark:(L1)<br/>:gift:<sup>[1]</sup>(-367) | :heavy_minus_sign: | :heavy_minus_sign: | :heavy_minus_sign: | :heavy_minus_sign: |
+ | Program-aided Language Model | :heavy_minus_sign: | :white_check_mark:(L1)<br>:gift:(168) | :heavy_minus_sign: | :heavy_minus_sign: | :white_check_mark:(L1)<br/>:gift:(-68) | :heavy_minus_sign: | :heavy_minus_sign: | :heavy_minus_sign: | :heavy_minus_sign: |
+ | Self-ask Prompting | :heavy_minus_sign: | :white_check_mark:(L1)<br>:gift:(~10) | :heavy_minus_sign: | :heavy_minus_sign: | :heavy_multiplication_x: | :heavy_minus_sign: | :heavy_minus_sign: | :heavy_minus_sign: | :heavy_minus_sign: |
+ | Self-consistency Prompting | :heavy_minus_sign: | :white_check_mark:(L1)<br>:gift:(~30) | :heavy_minus_sign: | :heavy_minus_sign: | :heavy_multiplication_x: | :heavy_minus_sign: | :heavy_minus_sign: | :heavy_minus_sign: | :heavy_minus_sign: |
+ | Reflexion | :heavy_minus_sign: | :heavy_multiplication_x: | :heavy_minus_sign: | :heavy_minus_sign: | :heavy_multiplication_x: | :heavy_minus_sign: | :heavy_minus_sign: | :heavy_minus_sign: | :heavy_minus_sign: |
+ | Solo Performance Prompting | :heavy_minus_sign: | :white_check_mark:(L1)<br/>:gift:(43) | :heavy_minus_sign: | :heavy_minus_sign: | :white_check_mark:(L1)<br/>:gift:(-583) | :heavy_minus_sign: | :heavy_minus_sign: | :heavy_minus_sign: | :heavy_minus_sign: |
+
+ <sup>[1]: Cumulative reward.</sup>
+ ![Image text](https://github.com/mail-ecnu/LLM-Decider-Bench/blob/master/vis/Classic%20Control.png)
+ ![Image text](https://github.com/mail-ecnu/LLM-Decider-Bench/blob/master/vis/Box%202D.png)
+ ![Image text](https://github.com/mail-ecnu/LLM-Decider-Bench/blob/master/vis/Toy%20Text.png)
+
+ >
+ > 1. Except for the Reflexion L3 decider, all other L3 deciders in this task do not have memory.
+ > 2. Reflexion L1 and L3 both have memory.
+ > 3. Reflexion L1 runs 5 trials.
+ > 4. Blackjack, MountainCar, CliffWalking (PAL), CartPole (PAL), Taxi (SPP, PAL), and Frozen Lake use deciders modified at 15:29 on 09.18.
+ > 5. The Frozen Lake translator was updated to add prior knowledge.
+
+ ## Remarks
+ 1. How to use future info
+ We provide future info in the `env_info` part. It is a dict, and you can further convert it into text to make your agent aware of the world model.
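+
+ A minimal sketch of that conversion (illustrative only: the exact keys depend on what your translator puts in the info dict, and `future_info_to_text` is a hypothetical helper, not part of this repository):
+
+ ```python
+ def future_info_to_text(info: dict) -> str:
+     """Render the entries of an env info dict as a prompt-friendly string."""
+     lines = [f"{key}: {value}" for key, value in info.items()]
+     return "Predicted future information:\n" + "\n".join(lines)
+
+ # Example use inside an agent loop:
+ # observation, reward, terminated, truncated, info = env.step(action)
+ # state_text = state_description + "\n" + future_info_to_text(info)
+ ```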
RL_based/test_RL.sh ADDED
@@ -0,0 +1,39 @@
1
+ # # ppo for cartpole-v0
2
+ # CUDA_VISIBLE_DEVICES=1 python RL_based/train_PPO.py --env_name CartPole-v0 --init_summarizer cart_init_translator --curr_summarizer cart_basic_translator\
3
+ # --trans_model_name distilbert-base-uncased --model_name nn_embedding --eval --policy-path RL_based/checkpoints/CartPole-v0/expert/policy.pth --collect_one_episode
4
+
5
+ # # ppo for lunarlander: treasured-music-91 score: 164.66
6
+ # TRANSFORMERS_OFFLINE=1 \
7
+ # CUDA_VISIBLE_DEVICES=2 python RL_based/train_PPO.py --env_name LunarLander-v2 --init_summarizer lunarLander_init_translator --curr_summarizer lunarLander_basic_translator \
8
+ # --trans_model_name /home/ubuntu/LLM-Decider-Bench/RL_based/transformer_offline_distilbert --model_name nn_embedding --max_length 128 --eval --collect_one_episode --policy-path /home/ubuntu/LLM-Decider-Bench/RL_based/checkpoints/LunarLander-v2/expert/policy.pth
9
+
10
+ # ppo for Acrobot-v1: charmed-salad-93 score: -85.8
11
+ # TRANSFORMERS_OFFLINE=1 \
12
+ # CUDA_VISIBLE_DEVICES=0 python RL_based/train_PPO.py --env_name Acrobot-v1 --init_summarizer acrobot_init_translator --curr_summarizer acrobot_basic_translator --decider naive_actor --prompt_level 1\
13
+ # --trans_model_name /home/ubuntu/LLM-Decider-Bench/RL_based/transformer_offline_distilbert --model_name nn_embedding --max_length 128 --eval --collect_one_episode --policy-path /home/ubuntu/LLM-Decider-Bench/RL_based/checkpoints/Acrobot-v1/expert/policy.pth
14
+
15
+ # # # # ppo for MountainCar-v0:
16
+ # TRANSFORMERS_OFFLINE=1 \
17
+ # CUDA_VISIBLE_DEVICES=1 python RL_based/train_PPO.py --env_name MountainCar-v0 --init_summarizer mountaincar_init_translator --curr_summarizer mountaincar_basic_translator --decider naive_actor --prompt_level 1\
18
+ # --trans_model_name /home/ubuntu/LLM-Decider-Bench/RL_based/transformer_offline_distilbert --model_name nn_embedding --max_length 128 --eval --collect_one_episode --policy-path /home/ubuntu/LLM-Decider-Bench/RL_based/checkpoints/MountainCar-v0/expert/policy.pth
19
+
20
+ # # ppo for Blackjack-v1
21
+ # TRANSFORMERS_OFFLINE=1 \
22
+ # CUDA_VISIBLE_DEVICES=2 python RL_based/train_PPO.py --env_name Blackjack-v1 --init_summarizer blackjack_init_translator --curr_summarizer blackjack_basic_translator --decider naive_actor --prompt_level 1\
23
+ # --trans_model_name /home/ubuntu/LLM-Decider-Bench/RL_based/transformer_offline_distilbert --model_name nn_embedding --max_length 128 --eval --collect_one_episode --policy-path /home/ubuntu/LLM-Decider-Bench/RL_based/checkpoints/Blackjack-v1/expert/policy.pth
24
+
25
+ # # # ppo for Taxi-v3
26
+ TRANSFORMERS_OFFLINE=1 \
27
+ CUDA_VISIBLE_DEVICES=6 python RL_based/train_PPO.py --env_name Taxi-v3 --init_summarizer taxi_init_translator --curr_summarizer taxi_basic_translator --decider naive_actor --prompt_level 1\
28
+ --trans_model_name /home/ubuntu/LLM-Decider-Bench/RL_based/transformer_offline_distilbert --model_name nn_embedding --max_length 128 --eval --collect_one_episode --policy-path /home/ubuntu/LLM-Decider-Bench/RL_based/checkpoints/Taxi-v3/expert/policy.pth
29
+
30
+ # # # ppo for CliffWalking-v0
31
+ # TRANSFORMERS_OFFLINE=1 \
32
+ # CUDA_VISIBLE_DEVICES=4 python RL_based/train_PPO.py --env_name CliffWalking-v0 --init_summarizer cliffwalking_init_translator --curr_summarizer cliffwalking_basic_translator --decider naive_actor --prompt_level 1\
33
+ # --trans_model_name /home/ubuntu/LLM-Decider-Bench/RL_based/transformer_offline_distilbert --model_name nn_embedding --max_length 128 --eval --collect_one_episode --policy-path /home/ubuntu/LLM-Decider-Bench/RL_based/checkpoints/CliffWalking-v0/expert/policy.pth
34
+
35
+ # # # ppo for FrozenLake-v1
36
+ # TRANSFORMERS_OFFLINE=1 \
37
+ # CUDA_VISIBLE_DEVICES=5 python RL_based/train_PPO.py --env_name FrozenLake-v1 --init_summarizer frozenlake_init_translator --curr_summarizer frozenlake_basic_translator --decider naive_actor --prompt_level 1\
38
+ # --trans_model_name /home/ubuntu/LLM-Decider-Bench/RL_based/transformer_offline_distilbert --model_name nn_embedding --max_length 128 --eval --collect_one_episode --policy-path /home/ubuntu/LLM-Decider-Bench/RL_based/checkpoints/FrozenLake-v1/expert/policy.pth
39
+
RL_based/train_PPO.py ADDED
@@ -0,0 +1,251 @@
1
+ import argparse
2
+ import sys
3
+ sys.path.insert(0, sys.path[0]+"/../")
4
+ import prompts as task_prompts
5
+ import envs
6
+ import os
7
+ from envs.translator import InitSummarizer, CurrSummarizer, FutureSummarizer, Translator
8
+ import gym
9
+ from torch.optim.lr_scheduler import LambdaLR
10
+ import torch
11
+ from tianshou.data import Collector, VectorReplayBuffer, ReplayBuffer
12
+ from tianshou.env import DummyVectorEnv, SubprocVectorEnv
13
+ from tianshou.policy import PPOPolicy, ICMPolicy
14
+ from tianshou.trainer import onpolicy_trainer
15
+ from tianshou.utils.net.common import ActorCritic
16
+ from tianshou.utils.net.discrete import Actor, Critic, IntrinsicCuriosityModule
17
+ from RL_based.utils import Net_GRU_Bert_tianshou, Net_Bert_CLS_tianshou, Net_Bert_CNN_tianshou, Net_GRU_nn_emb_tianshou
18
+ from tianshou.utils import WandbLogger
19
+ from torch.utils.tensorboard import SummaryWriter
20
+ from tianshou.trainer.utils import test_episode
21
+
22
+ import warnings
23
+ warnings.filterwarnings('ignore')
24
+
25
+ class MaxStepLimitWrapper(gym.Wrapper):
26
+ def __init__(self, env, max_steps=200):
27
+ super(MaxStepLimitWrapper, self).__init__(env)
28
+ self.max_steps = max_steps
29
+ self.current_step = 0
30
+
31
+ def reset(self, **kwargs):
32
+ self.current_step = 0
33
+ return self.env.reset(**kwargs)
34
+
35
+ def step(self, action):
36
+ observation, reward, terminated, truncated, info = self.env.step(action)
37
+ self.current_step += 1
38
+
39
+ if self.current_step >= self.max_steps:
40
+ terminated = True
41
+ info['episode_step_limit'] = self.max_steps
42
+
43
+ return observation, reward, terminated, truncated, info
44
+
45
+ class SimpleTextWrapper(gym.Wrapper):
46
+ def __init__(self, env):
47
+ super(SimpleTextWrapper, self).__init__(env)
48
+ self.env = env
49
+
50
+ def reset(self, **kwargs):
51
+ observation, _ = self.env.reset(**kwargs)
52
+ return str(observation), {}
53
+
54
+ def step(self, action):
55
+ observation, reward, terminated, truncated, info = self.env.step(action)
56
+ return str(observation), reward, terminated, truncated, info
57
+
58
+ if __name__ == "__main__":
59
+ parser = argparse.ArgumentParser(description='Evaluate a translator in a gym environment with a ChatGPT model.')
60
+ parser.add_argument('--init_summarizer', type=str, required=True, help='The name of the init summarizer to use.')
61
+ parser.add_argument('--curr_summarizer', type=str, required=True, help='The name of the curr summarizer to use.')
62
+ parser.add_argument('--future_summarizer', type=str, help='The name of the future summarizer to use.')
63
+ parser.add_argument('--env', type=str, default='base_env', help='The name of the environment wrapper class (from envs.REGISTRY) to use.')
64
+ parser.add_argument('--env_name', type=str, default='CartPole-v1', help='The name of the gym environment to use.')
65
+ parser.add_argument('--decider', type=str, default="naive_actor", help='The actor used to select action')
66
+ parser.add_argument('--render', type=str, default="rgb_array", help='The render mode')
67
+ parser.add_argument('--future_horizon', type=int, help='The horizon of looking to future')
68
+ parser.add_argument(
69
+ "--prompt_level",
70
+ type=int,
71
+ default=1,
72
+ help="The level of prompts",
73
+ )
74
+ parser.add_argument(
75
+ "--past_horizon", type=int, help="The horizon of looking back"
76
+ )
77
+ parser.add_argument(
78
+ "--max_episode_len", type=int, default=200, help="The max length of an episode"
79
+ )
80
+
81
+ ### for RL training
82
+ parser.add_argument('--max_length', type=int, default=128, help='The token length of the observation')
83
+ # trans_model_name
84
+ parser.add_argument('--trans_model_name', type=str, default='bert-base-uncased', help='The name of the pretrained transformer to use.')
85
+ parser.add_argument('--model_name', type=str, default='bert-embedding', help='The name of the model to use.')
86
+ parser.add_argument('--vector_env', type=str, default='dummy', help='The name of the vector env to use.')
87
+ parser.add_argument('--eval', action='store_true', default=False, help='Whether to only eval the model')
88
+ parser.add_argument('--policy-path', type=str, default=None, help='The path to the policy to be evaluated')
89
+ parser.add_argument('--collect_one_episode', action='store_true', default=False, help='Whether to only collect one episode')
90
+ parser.add_argument('--lr', type=float, default=0.0003, help='The learning rate of the model')
91
+ parser.add_argument('--step_per_epoch', type=int, default=10000, help='The number of steps per epoch')
92
+ parser.add_argument('--step_per_collect', type=int, default=2000, help='The number of steps per collect')
93
+ parser.add_argument('--lr_decay', action='store_true', default=False, help='Whether to decay the learning rate')
94
+ parser.add_argument('--epoch', type=int, default=400, help='The number of epochs to train')
95
+ parser.add_argument('--resume_path', type=str, default=None, help='The path to the policy to be resumed')
96
+ parser.add_argument('--taxi_specific_env', action='store_true', default=False, help='Whether to use taxi specific env')
97
+ args = parser.parse_args()
98
+ args_dict = vars(args)
99
+
100
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
101
+ # Get the specified translator, environment, and ChatGPT model
102
+ env_class = envs.REGISTRY[args.env]
103
+ init_summarizer = InitSummarizer(envs.REGISTRY[args.init_summarizer])
104
+ curr_summarizer = CurrSummarizer(envs.REGISTRY[args.curr_summarizer])
105
+ if args.future_summarizer:
106
+ future_summarizer = FutureSummarizer(
107
+ envs.REGISTRY[args.future_summarizer],
108
+ envs.REGISTRY["cart_policies"],
109
+ future_horizon=args.future_horizon,
110
+ )
111
+ else:
112
+ future_summarizer = None
113
+
114
+ wandb_log_config = {
115
+ "env": args.env_name,
116
+ "init_summarizer": args.init_summarizer,
117
+ "curr_summarizer": args.curr_summarizer,
118
+ "future_summarizer": args.future_summarizer,
119
+ }
120
+ wandb_log_config.update(args_dict)
121
+
122
+ if not args.eval:
123
+ logger = WandbLogger(
124
+ project="LLM-decider-bench-RL",
125
+ entity="llm-bench-team",
126
+ config=wandb_log_config,
127
+ )
128
+ random_name = logger.wandb_run.name
129
+ log_path = os.path.join('/home/ubuntu/LLM-Decider-Bench/RL_based/results', args.env_name, random_name)
130
+ writer = SummaryWriter(log_dir=log_path)
131
+ writer.add_text("args", str(args))
132
+ logger.load(writer)
133
+ def save_best_fn(policy):
134
+ torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
135
+
136
+ sampling_env = envs.REGISTRY["sampling_wrapper"](gym.make(args.env_name))
137
+ if args.prompt_level == 5:
138
+ prompts_class = task_prompts.REGISTRY[(args.env_name,args.decider)]()
139
+ else:
140
+ prompts_class = task_prompts.REGISTRY[(args.decider)]()
141
+ translator = Translator(
142
+ init_summarizer, curr_summarizer, future_summarizer, env=sampling_env
143
+ )
144
+ if args.taxi_specific_env:
145
+ environment = gym.make(args.env_name, render_mode=args.render)
146
+ else:
147
+ environment = env_class(
148
+ gym.make(args.env_name, render_mode=args.render), translator
149
+ )
150
+
151
+ # Set the translation level
152
+ translate_level = 1
153
+ if args.past_horizon is None and args.future_horizon is None:
154
+ translate_level = 1
155
+ if args.past_horizon and args.future_horizon is None:
156
+ raise NotImplementedError
157
+ # translate_level = 2
158
+ if args.past_horizon is None and args.future_horizon:
159
+ raise NotImplementedError
160
+ # translate_level = 3
161
+ if args.past_horizon and args.future_horizon:
162
+ raise NotImplementedError
163
+ # translate_level = 3.5
164
+
165
+
166
+ if args.vector_env == 'dummy':
167
+ ThisEnv = DummyVectorEnv
168
+ elif args.vector_env == 'subproc':
169
+ ThisEnv = SubprocVectorEnv
170
+ def make_env():
171
+ if args.taxi_specific_env:
172
+ env = MaxStepLimitWrapper(SimpleTextWrapper(gym.make(args.env_name, render_mode=args.render)), max_steps=200)
173
+ env._max_episode_steps = args.max_episode_len
174
+ else:
175
+ env = env_class(MaxStepLimitWrapper(gym.make(args.env_name, render_mode=args.render), max_steps=200), translator)
176
+ env._max_episode_steps = args.max_episode_len
177
+
178
+ return env
179
+ train_envs = ThisEnv([make_env for _ in range(20)])
180
+ test_envs = ThisEnv([make_env for _ in range(10)])
181
+ # model & optimizer
182
+ def get_net():
183
+ if args.model_name == "bert-embedding":
184
+ net = Net_GRU_Bert_tianshou(state_shape=environment.observation_space.shape, hidden_sizes=[64, 64], device=device, max_length=args.max_length, trans_model_name=args.trans_model_name)
185
+ elif args.model_name == "bert-CLS-embedding":
186
+ net = Net_Bert_CLS_tianshou(state_shape=environment.observation_space.shape, hidden_sizes=[256, 128], device=device, max_length=args.max_length, trans_model_name=args.trans_model_name)
187
+ elif args.model_name == "bert-CNN-embedding":
188
+ net = Net_Bert_CNN_tianshou(state_shape=environment.observation_space.shape, hidden_sizes=[256, 128], device=device, max_length=args.max_length, trans_model_name=args.trans_model_name)
189
+ elif args.model_name == "nn_embedding":
190
+ net = Net_GRU_nn_emb_tianshou(hidden_sizes=[256, 128], device=device, max_length=args.max_length, trans_model_name=args.trans_model_name)
191
+ return net
192
+ net = get_net()
193
+ actor = Actor(net, environment.action_space.n, device=device).to(device)
194
+ critic = Critic(net, device=device).to(device)
195
+ actor_critic = ActorCritic(actor, critic)
196
+ optim = torch.optim.Adam(actor_critic.parameters(), lr=args.lr)
197
+
198
+ # PPO policy
199
+ dist = torch.distributions.Categorical
200
+ lr_scheduler = None
201
+ if args.lr_decay:
202
+ max_update_num = args.step_per_epoch // args.step_per_collect * args.epoch
203
+
204
+ lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num)
205
+ policy = PPOPolicy(actor, critic, optim, dist, action_space=environment.action_space, lr_scheduler=lr_scheduler).to(device)
206
+ # collector
207
+ train_collector = Collector(policy, train_envs, VectorReplayBuffer(20000, len(train_envs)), exploration_noise=True)
208
+ test_collector = Collector(policy, test_envs, exploration_noise=True)
209
+
210
+ if not args.eval:
211
+ # trainer
212
+ # test train_collector and start filling replay buffer
213
+
214
+ if args.resume_path:
215
+ policy.load_state_dict(torch.load(args.resume_path, map_location='cuda'))
216
+ print("Loaded agent from: ", args.resume_path)
217
+
218
+ train_collector.collect(256 * 20)
219
+ result = onpolicy_trainer(
220
+ policy,
221
+ train_collector,
222
+ test_collector,
223
+ max_epoch=args.epoch,
224
+ step_per_epoch=50000, # the number of transitions collected per epoch
225
+ repeat_per_collect=4,
226
+ episode_per_test=10,
227
+ batch_size=256,
228
+ logger=logger,
229
+ step_per_collect=1000, # the number of transitions the collector would collect before the network update
230
+ save_best_fn=save_best_fn,
231
+ # stop_fn=lambda mean_reward: mean_reward >= environment.spec.reward_threshold,
232
+ )
233
+ print(result)
234
+ else:
235
+ assert args.policy_path is not None
236
+ policy.load_state_dict(torch.load(args.policy_path))
237
+ test_collector = Collector(policy, test_envs)
238
+ result = test_episode(policy, test_collector, None, None, n_episode=10)
239
+ print(result)
240
+ if args.collect_one_episode:
241
+ replaybuffer = ReplayBuffer(size=1000)
242
+ test_collector_1 = Collector(policy, environment, replaybuffer)
243
+ test_collector_1.reset_env()
244
+ test_collector_1.reset_buffer()
245
+ policy.eval()
246
+ result = test_collector_1.collect(n_episode=1)
247
+ print('sample results', f"/home/ubuntu/LLM-Decider-Bench/RL_based/checkpoints/{args.env_name}/output.txt")
248
+ sample_result = replaybuffer.sample(0)
249
+ f = open(f"/home/ubuntu/LLM-Decider-Bench/RL_based/checkpoints/{args.env_name}/output.txt", "w")
250
+ print(sample_result, file=f)
251
+ f.close()
RL_based/train_RL.sh ADDED
@@ -0,0 +1,39 @@
1
+ # # ppo for cartpole
2
+ # CUDA_VISIBLE_DEVICES=1 python RL_based/train_PPO.py --env_name CartPole-v0 --init_summarizer cart_init_translator --curr_summarizer cart_basic_translator\
3
+ # --trans_model_name /home/ubuntu/LLM-Decider-Bench/RL_based/transformer_offline_distilbert --model_name nn_embedding
4
+
5
+ # # ppo for lunarlander
6
+ # TRANSFORMERS_OFFLINE=1 \
7
+ # CUDA_VISIBLE_DEVICES=3 python RL_based/train_PPO.py --env_name LunarLander-v2 --init_summarizer lunarLander_init_translator --curr_summarizer lunarLander_basic_translator \
8
+ # --trans_model_name /home/ubuntu/LLM-Decider-Bench/RL_based/transformer_offline_distilbert --model_name nn_embedding --max_length 128 --lr 0.0003 --lr_decay --epoch 500
9
+
10
+ # ppo for Acrobot-v1
11
+ # TRANSFORMERS_OFFLINE=1 \
12
+ # CUDA_VISIBLE_DEVICES=0 python RL_based/train_PPO.py --env_name Acrobot-v1 --init_summarizer acrobot_init_translator --curr_summarizer acrobot_basic_translator --decider naive_actor --prompt_level 1\
13
+ # --trans_model_name /home/ubuntu/LLM-Decider-Bench/RL_based/transformer_offline_distilbert --model_name nn_embedding --max_length 128 --lr 0.0003 --lr_decay --epoch 500 &
14
+
15
+ # # # ppo for MountainCar-v0
16
+ # TRANSFORMERS_OFFLINE=1 \
17
+ # CUDA_VISIBLE_DEVICES=1 python RL_based/train_PPO.py --env_name MountainCar-v0 --init_summarizer mountaincar_init_translator --curr_summarizer mountaincar_basic_translator --decider naive_actor --prompt_level 1\
18
+ # --trans_model_name /home/ubuntu/LLM-Decider-Bench/RL_based/transformer_offline_distilbert --model_name nn_embedding --max_length 300 --lr 0.0003 --lr_decay --epoch 500 &
19
+
20
+ # ppo for Blackjack-v1
21
+ # TRANSFORMERS_OFFLINE=1 \
22
+ # CUDA_VISIBLE_DEVICES=2 python RL_based/train_PPO.py --env_name Blackjack-v1 --init_summarizer blackjack_init_translator --curr_summarizer blackjack_basic_translator --decider naive_actor --prompt_level 1\
23
+ # --trans_model_name /home/ubuntu/LLM-Decider-Bench/RL_based/transformer_offline_distilbert --model_name nn_embedding --max_length 300 --lr 0.0003 --lr_decay --epoch 500 &
24
+
25
+ # # ppo for Taxi-v3
26
+ TRANSFORMERS_OFFLINE=1 \
27
+ CUDA_VISIBLE_DEVICES=6 python RL_based/train_PPO.py --env_name Taxi-v3 --init_summarizer taxi_init_translator --curr_summarizer taxi_basic_translator --decider naive_actor --prompt_level 1\
28
+ --trans_model_name /home/ubuntu/LLM-Decider-Bench/RL_based/transformer_offline_distilbert --model_name nn_embedding --max_length 300 --lr 0.0003 --lr_decay --epoch 500 --taxi_specific_env
29
+
30
+ # # ppo for CliffWalking-v0
31
+ # TRANSFORMERS_OFFLINE=1 \
32
+ # CUDA_VISIBLE_DEVICES=4 python RL_based/train_PPO.py --env_name CliffWalking-v0 --init_summarizer cliffwalking_init_translator --curr_summarizer cliffwalking_basic_translator --decider naive_actor --prompt_level 1\
33
+ # --trans_model_name /home/ubuntu/LLM-Decider-Bench/RL_based/transformer_offline_distilbert --model_name nn_embedding --max_length 300 --lr 0.0003 --lr_decay --epoch 500 &
34
+
35
+ # # ppo for FrozenLake-v1
36
+ # TRANSFORMERS_OFFLINE=1 \
37
+ # CUDA_VISIBLE_DEVICES=5 python RL_based/train_PPO.py --env_name FrozenLake-v1 --init_summarizer frozenlake_init_translator --curr_summarizer frozenlake_basic_translator --decider naive_actor --prompt_level 1\
38
+ # --trans_model_name /home/ubuntu/LLM-Decider-Bench/RL_based/transformer_offline_distilbert --model_name nn_embedding --max_length 300 --lr 0.0003 --lr_decay --epoch 500 &
39
+
RL_based/utils.py ADDED
@@ -0,0 +1,621 @@
1
+ import sys
2
+ import numpy as np
3
+ import torch
4
+ from torch import nn
5
+ sys.path.insert(0, sys.path[0]+"/../")
6
+ from typing import (
7
+ Any,
8
+ Dict,
9
+ List,
10
+ Optional,
11
+ Sequence,
12
+ Tuple,
13
+ Type,
14
+ Union,
15
+ no_type_check,
16
+ )
17
+ import torch.nn as nn
18
+ from tianshou.utils.net.discrete import NoisyLinear
19
+ ModuleType = Type[nn.Module]
20
+ import random
21
+ from collections import namedtuple, deque
22
+ from itertools import count
23
+ import math
24
+ import torch
25
+ import torch.optim as optim
26
+ from transformers import AutoModel, AutoTokenizer
27
+ import torch.nn.functional as F
28
+ from tianshou.utils.net.common import ModuleType, Net, MLP
29
+
30
+
31
+ def bert_embedding(x, max_length=512, device='cuda'):
32
+ from transformers import logging
33
+ logging.set_verbosity_error()
34
+ model_name = 'bert-base-uncased'
35
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
36
+ bert_model = AutoModel.from_pretrained(model_name)
37
+ text = x
38
+ if isinstance(text, np.ndarray):
39
+ text = list(text)
40
+ tokens = tokenizer(text, max_length=max_length, padding='max_length', truncation=True, return_tensors='pt')
41
+ input_ids = tokens['input_ids']
42
+ attention_mask = tokens['attention_mask']
43
+ with torch.no_grad():
44
+ outputs = bert_model(input_ids, attention_mask=attention_mask)
45
+ embeddings = outputs.last_hidden_state
46
+ return embeddings
47
+
48
+ class Net_GRU(nn.Module):
49
+
50
+ def __init__(self, input_size, n_actions, hidden_dim, n_layers, dropout, bidirectional):
51
+ super(Net_GRU, self).__init__()
52
+ self.input_size = input_size
53
+ self.hidden_dim = hidden_dim
54
+ self.num_classes = n_actions
55
+ self.n_layers = n_layers
56
+ self.dropout = dropout
57
+ self.bidirectional = bidirectional
58
+
59
+ # Layers
60
+ self.gru = nn.GRU(self.input_size, self.hidden_dim, self.n_layers,
61
+ batch_first=True, dropout=self.dropout, bidirectional=self.bidirectional)
62
+ self.final_layer = nn.Linear(self.hidden_dim*(1 + int(self.bidirectional)), self.num_classes)
63
+
64
+ def forward(self, x):
65
+ # Input shape: (batch_size, seq_length)
66
+ batch_size, seq_length, emb_size = x.size()
67
+
68
+ gru_out, hidden = self.gru(x)
69
+
70
+ # Use the final state
71
+ # hidden -> (num_direction, batch, hidden_size)
72
+ if self.bidirectional:
73
+ hidden = hidden.view(self.n_layers, 2, batch_size, self.hidden_dim)
74
+ final_hidden = torch.cat((hidden[:, -1, :, :].squeeze(0), hidden[:, 0, :, :].squeeze(0)), 1)
75
+ else:
76
+ final_hidden = hidden.squeeze(0)
77
+
78
+ # final_hidden -> (batch_size, num_classes)
79
+ logits = self.final_layer(final_hidden)
80
+
81
+ return logits
82
+
83
+ class MyGRU(nn.Module):
84
+ def __init__(self, input_size, hidden_dim, n_layers, dropout, bidirectional, output_dim):
85
+ super(MyGRU, self).__init__()
86
+ self.input_size = input_size
87
+ self.hidden_dim = hidden_dim
88
+ self.n_layers = n_layers
89
+ self.dropout = dropout
90
+ self.bidirectional = bidirectional
91
+
92
+ # Layers
93
+ self.gru = nn.GRU(self.input_size, self.hidden_dim, self.n_layers,
94
+ batch_first=True, dropout=self.dropout, bidirectional=self.bidirectional)
95
+ self.final_layer = nn.Linear(self.hidden_dim*(1 + int(self.bidirectional)), output_dim)
96
+
97
+ def forward(self, x):
98
+ batch_size, seq_length, emb_size = x.size()
99
+
100
+ gru_out, hidden = self.gru(x)
101
+
102
+ # Use the final state
103
+ # hidden -> (num_direction, batch, hidden_size)
104
+ if self.bidirectional:
105
+ hidden = hidden.view(self.n_layers, 2, batch_size, self.hidden_dim)
106
+ final_hidden = torch.cat((hidden[:, -1, :, :].squeeze(0), hidden[:, 0, :, :].squeeze(0)), 1)
107
+ else:
108
+ final_hidden = hidden.squeeze(0)
109
+
110
+ # final_hidden -> (batch_size, num_classes)
111
+ logits = self.final_layer(final_hidden)
112
+
113
+ return logits
114
+
115
+ class MyCNN(nn.Module):
116
+ def __init__(self,
117
+ input_dim: int,
118
+ output_dim: int = 0,
119
+ hidden_sizes: Sequence[int] = (),
120
+ norm_layer: Optional[Union[ModuleType, Sequence[ModuleType]]] = None,
121
+ activation: ModuleType = nn.ReLU,
122
+ device: Optional[Union[str, int, torch.device]] = None,
123
+ linear_layer: Type[nn.Linear] = nn.Linear,
124
+ flatten_input: bool = True,) -> None:
125
+ super().__init__()
126
+ self.model = []
127
+ input_dim_temp = input_dim
128
+ for h in hidden_sizes:
129
+ self.model.append(nn.Conv1d(in_channels=input_dim_temp, out_channels=h, kernel_size=3, padding=1))
130
+ self.model.append(activation())
131
+ self.model.append(nn.MaxPool1d(kernel_size=2))
132
+ input_dim_temp = h
133
+ self.model = nn.Sequential(*self.model)
134
+ self.fc = nn.Linear(in_features=input_dim_temp, out_features=output_dim)
135
+
136
+ def forward(self, x):
137
+ x = self.model(x.transpose(1, 2))
138
+ x.transpose_(1, 2)
139
+ x = self.fc(x)
140
+ return x
141
+
142
+ class Net_GRU_Bert_tianshou(Net):
143
+ def __init__(
144
+ self,
145
+ state_shape: Union[int, Sequence[int]],
146
+ action_shape: Union[int, Sequence[int]] = 0,
147
+ hidden_sizes: Sequence[int] = (),
148
+ norm_layer: Optional[ModuleType] = None,
149
+ activation: Optional[ModuleType] = nn.ReLU,
150
+ device: Union[str, int, torch.device] = "cpu",
151
+ softmax: bool = False,
152
+ concat: bool = False,
153
+ num_atoms: int = 1,
154
+ dueling_param: Optional[Tuple[Dict[str, Any], Dict[str, Any]]] = None,
155
+ linear_layer: Type[nn.Linear] = nn.Linear,
156
+ hidden_dim: int = 128,
157
+ bidirectional: bool = True,
158
+ dropout: float = 0.,
159
+ n_layers: int = 1,
160
+ max_length: int = 512,
161
+ trans_model_name: str = 'bert-base-uncased',
162
+ ) -> None:
163
+ nn.Module.__init__(self)
164
+ self.device = device
165
+ self.softmax = softmax
166
+ self.num_atoms = num_atoms
167
+ self.hidden_dim = hidden_dim
168
+ self.bidirectional = bidirectional
169
+ self.dropout = dropout
170
+ self.n_layers = n_layers
171
+ self.trans_model_name = trans_model_name
172
+ self.max_length = max_length
173
+
174
+ input_dim = int(np.prod(state_shape))
175
+ action_dim = int(np.prod(action_shape)) * num_atoms
176
+ if concat:
177
+ input_dim += action_dim
178
+ self.use_dueling = dueling_param is not None
179
+ output_dim = action_dim if not self.use_dueling and not concat else 0
180
+ self.output_dim = output_dim or hidden_dim
181
+ self.model = MyGRU(768, self.hidden_dim, self.n_layers,
182
+ self.dropout, self.bidirectional, self.output_dim)
183
+ if self.use_dueling: # dueling DQN
184
+ q_kwargs, v_kwargs = dueling_param # type: ignore
185
+ q_output_dim, v_output_dim = 0, 0
186
+ if not concat:
187
+ q_output_dim, v_output_dim = action_dim, num_atoms
188
+ q_kwargs: Dict[str, Any] = {
189
+ **q_kwargs, "input_dim": self.output_dim,
190
+ "output_dim": q_output_dim,
191
+ "device": self.device
192
+ }
193
+ v_kwargs: Dict[str, Any] = {
194
+ **v_kwargs, "input_dim": self.output_dim,
195
+ "output_dim": v_output_dim,
196
+ "device": self.device
197
+ }
198
+ self.Q, self.V = MLP(**q_kwargs), MLP(**v_kwargs)
199
+ self.output_dim = self.Q.output_dim
200
+ self.bert_model = AutoModel.from_pretrained(self.trans_model_name).to(self.device)
201
+ self.tokenizer = AutoTokenizer.from_pretrained(trans_model_name)
202
+ from transformers import logging
203
+ logging.set_verbosity_error()
204
+
205
+ def bert_embedding(self, x, max_length=512):
206
+ text = x
207
+ if isinstance(text, np.ndarray):
208
+ text = list(text)
209
+ tokens = self.tokenizer(text, max_length=max_length, padding='max_length', truncation=True, return_tensors='pt')
210
+ input_ids = tokens['input_ids'].to(self.device)
211
+ attention_mask = tokens['attention_mask'].to(self.device)
212
+ with torch.no_grad():
213
+ outputs = self.bert_model(input_ids, attention_mask=attention_mask)
214
+ embeddings = outputs.last_hidden_state
215
+ return embeddings
216
+
217
+ def forward(
218
+ self,
219
+ obs: Union[np.ndarray, torch.Tensor],
220
+ state: Any = None,
221
+ info: Dict[str, Any] = {},
222
+ ) -> Tuple[torch.Tensor, Any]:
223
+ """Mapping: obs -> flatten (inside MLP)-> logits."""
224
+ embedding = self.bert_embedding(obs, max_length=self.max_length)
225
+ logits = self.model(embedding)
226
+ bsz = logits.shape[0]
227
+ if self.use_dueling: # Dueling DQN
228
+ q, v = self.Q(logits), self.V(logits)
229
+ if self.num_atoms > 1:
230
+ q = q.view(bsz, -1, self.num_atoms)
231
+ v = v.view(bsz, -1, self.num_atoms)
232
+ logits = q - q.mean(dim=1, keepdim=True) + v
233
+ elif self.num_atoms > 1:
234
+ logits = logits.view(bsz, -1, self.num_atoms)
235
+ if self.softmax:
236
+ logits = torch.softmax(logits, dim=-1)
237
+ return logits, state
238
+
239
+ class Net_Bert_CLS_tianshou(Net):
240
+ def __init__(
241
+ self,
242
+ state_shape: Union[int, Sequence[int]],
243
+ action_shape: Union[int, Sequence[int]] = 0,
244
+ hidden_sizes: Sequence[int] = (),
245
+ norm_layer: Optional[ModuleType] = None,
246
+ activation: Optional[ModuleType] = nn.ReLU,
247
+ device: Union[str, int, torch.device] = "cpu",
248
+ softmax: bool = False,
249
+ concat: bool = False,
250
+ num_atoms: int = 1,
251
+ dueling_param: Optional[Tuple[Dict[str, Any], Dict[str, Any]]] = None,
252
+ linear_layer: Type[nn.Linear] = nn.Linear,
253
+ hidden_dim: int = 128,
254
+ bidirectional: bool = True,
255
+ dropout: float = 0.,
256
+ n_layers: int = 1,
257
+ max_length: int = 512,
258
+ trans_model_name: str = 'bert-base-uncased',
259
+ ) -> None:
260
+ nn.Module.__init__(self)
261
+ self.device = device
262
+ self.softmax = softmax
263
+ self.num_atoms = num_atoms
264
+ self.hidden_dim = hidden_dim
265
+ self.bidirectional = bidirectional
266
+ self.dropout = dropout
267
+ self.n_layers = n_layers
268
+ self.trans_model_name = trans_model_name
269
+ self.max_length = max_length
270
+
271
+ input_dim = int(np.prod(state_shape))
272
+ action_dim = int(np.prod(action_shape)) * num_atoms
273
+ if concat:
274
+ input_dim += action_dim
275
+ self.use_dueling = dueling_param is not None
276
+ output_dim = action_dim if not self.use_dueling and not concat else 0
277
+ self.output_dim = output_dim or hidden_dim
278
+ self.model = MLP(768, output_dim, hidden_sizes, norm_layer, activation, device, linear_layer)
279
+ if self.use_dueling: # dueling DQN
280
+ q_kwargs, v_kwargs = dueling_param # type: ignore
281
+ q_output_dim, v_output_dim = 0, 0
282
+ if not concat:
283
+ q_output_dim, v_output_dim = action_dim, num_atoms
284
+ q_kwargs: Dict[str, Any] = {
285
+ **q_kwargs, "input_dim": self.output_dim,
286
+ "output_dim": q_output_dim,
287
+ "device": self.device
288
+ }
289
+ v_kwargs: Dict[str, Any] = {
290
+ **v_kwargs, "input_dim": self.output_dim,
291
+ "output_dim": v_output_dim,
292
+ "device": self.device
293
+ }
294
+ self.Q, self.V = MLP(**q_kwargs), MLP(**v_kwargs)
295
+ self.output_dim = self.Q.output_dim
296
+ self.bert_model = AutoModel.from_pretrained(self.trans_model_name).to(self.device)
297
+ self.tokenizer = AutoTokenizer.from_pretrained(trans_model_name)
298
+ from transformers import logging
299
+ logging.set_verbosity_error()
300
+
301
+ def bert_CLS_embedding(self, x, max_length=512):
302
+ text = x
303
+ if isinstance(text, np.ndarray):
304
+ text = list(text)
305
+ tokens = self.tokenizer(text, max_length=max_length, padding='max_length', truncation=True, return_tensors='pt')
306
+ input_ids = tokens['input_ids'].to(self.device)
307
+ attention_mask = tokens['attention_mask'].to(self.device)
308
+ with torch.no_grad():
309
+ outputs = self.bert_model(input_ids, attention_mask=attention_mask)
310
+ embeddings = outputs[0][:, 0, :]
311
+ return embeddings
312
+
313
+ def forward(
314
+ self,
315
+ obs: Union[np.ndarray, torch.Tensor],
316
+ state: Any = None,
317
+ info: Dict[str, Any] = {},
318
+ ) -> Tuple[torch.Tensor, Any]:
319
+ """Mapping: obs -> flatten (inside MLP)-> logits."""
320
+ embedding = self.bert_CLS_embedding(obs, max_length=self.max_length)
321
+ logits = self.model(embedding)
322
+ bsz = logits.shape[0]
323
+ if self.use_dueling: # Dueling DQN
324
+ q, v = self.Q(logits), self.V(logits)
325
+ if self.num_atoms > 1:
326
+ q = q.view(bsz, -1, self.num_atoms)
327
+ v = v.view(bsz, -1, self.num_atoms)
328
+ logits = q - q.mean(dim=1, keepdim=True) + v
329
+ elif self.num_atoms > 1:
330
+ logits = logits.view(bsz, -1, self.num_atoms)
331
+ if self.softmax:
332
+ logits = torch.softmax(logits, dim=-1)
333
+ return logits, state
334
+
335
+
336
+ class Net_Bert_CNN_tianshou(Net_GRU_Bert_tianshou):
337
+ def __init__(
338
+ self,
339
+ state_shape: Union[int, Sequence[int]],
340
+ action_shape: Union[int, Sequence[int]] = 0,
341
+ hidden_sizes: Sequence[int] = (),
342
+ norm_layer: Optional[ModuleType] = None,
343
+ activation: Optional[ModuleType] = nn.ReLU,
344
+ device: Union[str, int, torch.device] = "cpu",
345
+ softmax: bool = False,
346
+ concat: bool = False,
347
+ num_atoms: int = 1,
348
+ dueling_param: Optional[Tuple[Dict[str, Any], Dict[str, Any]]] = None,
349
+ linear_layer: Type[nn.Linear] = nn.Linear,
350
+ hidden_dim: int = 128,
351
+ bidirectional: bool = True,
352
+ dropout: float = 0.,
353
+ n_layers: int = 1,
354
+ max_length: int = 512,
355
+ trans_model_name: str = 'bert-base-uncased',
356
+ ) -> None:
357
+ nn.Module.__init__(self)
358
+ self.device = device
359
+ self.softmax = softmax
360
+ self.num_atoms = num_atoms
361
+ self.hidden_dim = hidden_dim
362
+ self.bidirectional = bidirectional
363
+ self.dropout = dropout
364
+ self.n_layers = n_layers
365
+ self.trans_model_name = trans_model_name
366
+ self.max_length = max_length
367
+
368
+ input_dim = int(np.prod(state_shape))
369
+ action_dim = int(np.prod(action_shape)) * num_atoms
370
+ if concat:
371
+ input_dim += action_dim
372
+ self.use_dueling = dueling_param is not None
373
+ output_dim = action_dim if not self.use_dueling and not concat else 0
374
+ self.output_dim = output_dim or hidden_dim
375
+ self.model = MyCNN(768, output_dim, hidden_sizes, norm_layer, activation, device, linear_layer, flatten_input=False)
376
+ if self.use_dueling: # dueling DQN
377
+ q_kwargs, v_kwargs = dueling_param # type: ignore
378
+ q_output_dim, v_output_dim = 0, 0
379
+ if not concat:
380
+ q_output_dim, v_output_dim = action_dim, num_atoms
381
+ q_kwargs: Dict[str, Any] = {
382
+ **q_kwargs, "input_dim": self.output_dim,
383
+ "output_dim": q_output_dim,
384
+ "device": self.device
385
+ }
386
+ v_kwargs: Dict[str, Any] = {
387
+ **v_kwargs, "input_dim": self.output_dim,
388
+ "output_dim": v_output_dim,
389
+ "device": self.device
390
+ }
391
+ self.Q, self.V = MLP(**q_kwargs), MLP(**v_kwargs)
392
+ self.output_dim = self.Q.output_dim
393
+ self.bert_model = AutoModel.from_pretrained(self.trans_model_name).to(self.device)
394
+ self.tokenizer = AutoTokenizer.from_pretrained(trans_model_name)
395
+ from transformers import logging
396
+ logging.set_verbosity_error()
397
+
398
+ class DQN_GRU(nn.Module):
399
+ """Reference: Human-level control through deep reinforcement learning.
400
+ """
401
+
402
+ def __init__(
403
+ self,
404
+ state_shape: Union[int, Sequence[int]],
405
+ action_shape: Sequence[int],
406
+ device: Union[str, int, torch.device] = "cpu",
407
+ features_only: bool = False,
408
+ output_dim: Optional[int] = None,
409
+ hidden_dim: int = 128,
410
+ n_layers: int = 1,
411
+ dropout: float = 0.,
412
+ bidirectional: bool = True,
413
+ trans_model_name: str = 'bert-base-uncased',
414
+ max_length: int = 512,
415
+ ) -> None:
416
+ super().__init__()
417
+ self.device = device
418
+ self.max_length = max_length
419
+ action_dim = int(np.prod(action_shape))
420
+ self.net = MyGRU(768, hidden_dim, n_layers, dropout, bidirectional,
421
+ hidden_dim)
422
+ if not features_only:
423
+ self.net = MyGRU(768, hidden_dim, n_layers, dropout, bidirectional,
424
+ action_dim)
425
+ self.output_dim = action_dim
426
+ elif output_dim is not None:
427
+ self.net = MyGRU(768, hidden_dim, n_layers, dropout, bidirectional,
428
+ output_dim)
429
+ self.output_dim = output_dim
430
+ else:
431
+ self.net = MyGRU(768, hidden_dim, n_layers, dropout, bidirectional,
432
+ hidden_dim)
433
+ self.output_dim = hidden_dim
434
+ self.trans_model_name = trans_model_name
435
+ self.bert_model = AutoModel.from_pretrained(self.trans_model_name).to(self.device)
436
+ self.tokenizer = AutoTokenizer.from_pretrained(trans_model_name)
437
+ from transformers import logging
438
+ logging.set_verbosity_error()
439
+
440
+ def bert_embedding(self, x, max_length=512):
441
+ text = x
442
+ if isinstance(text, np.ndarray):
443
+ text = list(text)
444
+ tokens = self.tokenizer(text, max_length=max_length, padding='max_length', truncation=True, return_tensors='pt')
445
+ input_ids = tokens['input_ids'].to(self.device)
446
+ attention_mask = tokens['attention_mask'].to(self.device)
447
+ with torch.no_grad():
448
+ outputs = self.bert_model(input_ids, attention_mask=attention_mask)
449
+ embeddings = outputs.last_hidden_state
450
+ return embeddings
451
+
452
+ def forward(
453
+ self,
454
+ obs: Union[np.ndarray, torch.Tensor],
455
+ state: Optional[Any] = None,
456
+ info: Dict[str, Any] = {},
457
+ ) -> Tuple[torch.Tensor, Any]:
458
+ r"""Mapping: s -> Q(s, \*)."""
459
+ embedding = self.bert_embedding(obs, max_length=self.max_length)
460
+ return self.net(embedding), state
461
+
462
+ class Rainbow_GRU(DQN_GRU):
463
+ """Reference: Rainbow: Combining Improvements in Deep Reinforcement Learning.
464
+ """
465
+
466
+ def __init__(
467
+ self,
468
+ state_shape: Union[int, Sequence[int]],
469
+ action_shape: Sequence[int],
470
+ num_atoms: int = 51,
471
+ noisy_std: float = 0.5,
472
+ device: Union[str, int, torch.device] = "cpu",
473
+ is_dueling: bool = True,
474
+ is_noisy: bool = True,
475
+ output_dim: Optional[int] = None,
476
+ hidden_dim: int = 128,
477
+ n_layers: int = 1,
478
+ dropout: float = 0.,
479
+ bidirectional: bool = True,
480
+ trans_model_name: str = 'bert-base-uncased',
481
+ max_length: int = 512,
482
+ ) -> None:
483
+ super().__init__(state_shape, action_shape, device, features_only=True,
484
+ output_dim=output_dim, hidden_dim=hidden_dim, n_layers=n_layers,
485
+ dropout=dropout, bidirectional=bidirectional, trans_model_name=trans_model_name)
486
+ self.action_num = np.prod(action_shape)
487
+ self.num_atoms = num_atoms
488
+
489
+ def linear(x, y):
490
+ if is_noisy:
491
+ return NoisyLinear(x, y, noisy_std)
492
+ else:
493
+ return nn.Linear(x, y)
494
+
495
+ self.Q = nn.Sequential(
496
+ linear(self.output_dim, 512), nn.ReLU(inplace=True),
497
+ linear(512, self.action_num * self.num_atoms)
498
+ )
499
+ self._is_dueling = is_dueling
500
+ if self._is_dueling:
501
+ self.V = nn.Sequential(
502
+ linear(self.output_dim, 512), nn.ReLU(inplace=True),
503
+ linear(512, self.num_atoms)
504
+ )
505
+ self.output_dim = self.action_num * self.num_atoms
506
+
507
+ def forward(
508
+ self,
509
+ obs: Union[np.ndarray, torch.Tensor],
510
+ state: Optional[Any] = None,
511
+ info: Dict[str, Any] = {},
512
+ ) -> Tuple[torch.Tensor, Any]:
513
+ r"""Mapping: x -> Z(x, \*)."""
514
+ obs, state = super().forward(obs)
515
+ q = self.Q(obs)
516
+ q = q.view(-1, self.action_num, self.num_atoms)
517
+ if self._is_dueling:
518
+ v = self.V(obs)
519
+ v = v.view(-1, 1, self.num_atoms)
520
+ logits = q - q.mean(dim=1, keepdim=True) + v
521
+ else:
522
+ logits = q
523
+ probs = logits.softmax(dim=2)
524
+ return probs, state
525
+
526
+ class Net_GRU_nn_emb_tianshou(Net):
527
+
528
+ def __init__(
529
+ self,
530
+ action_shape: Union[int, Sequence[int]] = 0,
531
+ hidden_sizes: Sequence[int] = (),
532
+ norm_layer: Optional[ModuleType] = None,
533
+ activation: Optional[ModuleType] = nn.ReLU,
534
+ device: Union[str, int, torch.device] = "cpu",
535
+ softmax: bool = False,
536
+ concat: bool = False,
537
+ num_atoms: int = 1,
538
+ dueling_param: Optional[Tuple[Dict[str, Any], Dict[str, Any]]] = None,
539
+ linear_layer: Type[nn.Linear] = nn.Linear,
540
+ hidden_dim: int = 128,
541
+ bidirectional: bool = True,
542
+ dropout: float = 0.,
543
+ n_layers: int = 1,
544
+ max_length: int = 512,
545
+ trans_model_name: str = 'bert-base-uncased',
546
+ word_emb_dim: int = 128,
547
+ ) -> None:
548
+ nn.Module.__init__(self)
549
+ self.device = device
550
+ self.softmax = softmax
551
+ self.num_atoms = num_atoms
552
+ self.hidden_dim = hidden_dim
553
+ self.bidirectional = bidirectional
554
+ self.dropout = dropout
555
+ self.n_layers = n_layers
556
+ self.trans_model_name = trans_model_name
557
+ self.max_length = max_length
558
+
559
+ action_dim = int(np.prod(action_shape)) * num_atoms
560
+ self.use_dueling = dueling_param is not None
561
+ output_dim = action_dim if not self.use_dueling and not concat else 0
562
+ self.output_dim = output_dim or hidden_dim
563
+
564
+ self.tokenizer = AutoTokenizer.from_pretrained(trans_model_name)
565
+ from transformers import logging
566
+ logging.set_verbosity_error()
567
+ self.vocab_size = self.tokenizer.vocab_size
568
+ self.embedding = nn.Embedding(self.vocab_size, word_emb_dim)
569
+ self.model = MyGRU(word_emb_dim, self.hidden_dim, self.n_layers,
570
+ self.dropout, self.bidirectional, self.output_dim)
571
+ if self.use_dueling: # dueling DQN
572
+ q_kwargs, v_kwargs = dueling_param # type: ignore
573
+ q_output_dim, v_output_dim = 0, 0
574
+ if not concat:
575
+ q_output_dim, v_output_dim = action_dim, num_atoms
576
+ q_kwargs: Dict[str, Any] = {
577
+ **q_kwargs, "input_dim": self.output_dim,
578
+ "output_dim": q_output_dim,
579
+ "device": self.device
580
+ }
581
+ v_kwargs: Dict[str, Any] = {
582
+ **v_kwargs, "input_dim": self.output_dim,
583
+ "output_dim": v_output_dim,
584
+ "device": self.device
585
+ }
586
+ self.Q, self.V = MLP(**q_kwargs), MLP(**v_kwargs)
587
+ self.output_dim = self.Q.output_dim
588
+
589
+
590
+ def forward(
591
+ self,
592
+ obs: Union[np.ndarray, torch.Tensor],
593
+ state: Any = None,
594
+ info: Dict[str, Any] = {},
595
+ ) -> Tuple[torch.Tensor, Any]:
596
+ """Mapping: obs -> flatten (inside MLP)-> logits."""
597
+ if isinstance(obs, np.ndarray):
598
+ text = list(obs)
599
+ else:
600
+ text = obs
601
+ tokens = self.tokenizer(text, max_length=self.max_length, padding='max_length', truncation=True, return_tensors='pt')
602
+ input_ids = tokens['input_ids'].to(self.device)
603
+ attention_mask = tokens['attention_mask'].to(self.device)
604
+ embedding = self.embedding(input_ids)
605
+ mask = attention_mask.unsqueeze(-1).expand(embedding.size()).float()
606
+ embedding = embedding * mask
607
+ logits = self.model(embedding)
608
+ bsz = logits.shape[0]
609
+ if self.use_dueling: # Dueling DQN
610
+ q, v = self.Q(logits), self.V(logits)
611
+ if self.num_atoms > 1:
612
+ q = q.view(bsz, -1, self.num_atoms)
613
+ v = v.view(bsz, -1, self.num_atoms)
614
+ logits = q - q.mean(dim=1, keepdim=True) + v
615
+ elif self.num_atoms > 1:
616
+ logits = logits.view(bsz, -1, self.num_atoms)
617
+ if self.softmax:
618
+ logits = torch.softmax(logits, dim=-1)
619
+ return logits, state
620
+
621
+
deciders/__init__.py ADDED
@@ -0,0 +1,26 @@
1
+
2
+ from .act import NaiveAct, RandomAct
3
+ from .selfask import SelfAskAct
4
+ from .pal import PAL
5
+ from .cot import ChainOfThought
6
+ from .self_consistency import SelfConsistency
7
+ from .spp import SPP
8
+ from .reflexion import Reflexion
9
+ from .jarvis import Jarvis
10
+ from .jarvis_without_insights import JarvisWithoutInsight
11
+ from .jarvis_without_suggestions import JarvisWithoutSuggestions
12
+ from .jarvis_without_shortmem import JarvisWithoutShortMem
13
+
14
+ REGISTRY = {}
15
+ REGISTRY['random_actor'] = RandomAct
16
+ REGISTRY['naive_actor'] = NaiveAct
17
+ REGISTRY['selfask_actor'] = SelfAskAct
18
+ REGISTRY['pal_actor'] = PAL
19
+ REGISTRY['cot_actor'] = ChainOfThought
20
+ REGISTRY['self_consistency_actor'] = SelfConsistency
21
+ REGISTRY['spp_actor'] = SPP
22
+ REGISTRY['reflexion_actor'] = Reflexion
23
+ REGISTRY['jarvis_actor'] = Jarvis
24
+ REGISTRY['jarvis_actor_woi'] = JarvisWithoutInsight
25
+ REGISTRY['jarvis_actor_wosug'] = JarvisWithoutSuggestions
26
+ REGISTRY['jarvis_actor_wosh'] = JarvisWithoutShortMem
deciders/act.py ADDED
@@ -0,0 +1,248 @@
1
+ # This file contains functions for interacting with the ChatGPT model
2
+
3
+ import openai
4
+ from .gpt import gpt
5
+ from loguru import logger
6
+ from .parser import PARSERS
7
+ from langchain.output_parsers import PydanticOutputParser
8
+ from langchain.output_parsers import OutputFixingParser
9
+ from langchain.chat_models import AzureChatOpenAI, ChatOpenAI
10
+ from memory.env_history import EnvironmentHistory
11
+ import tiktoken
12
+ import json
13
+ import re
14
+ from .utils import run_chain
15
+
16
+ class RandomAct():
17
+ def __init__(self, action_space):
18
+ self.action_space = action_space
19
+
20
+ def act(self, state_description, action_description, env_info, game_description=None, goal_description=None):
21
+ return self.action_space.sample()+1, '', '', '', 0, 0
22
+
23
+ class NaiveAct(gpt):
24
+ def __init__(self, action_space, args, prompts, distiller, temperature=0.0, max_tokens=512, logger=None):
25
+ self.action_space = action_space
26
+ self.temperature = temperature
27
+ self.action_desc_dict = args.action_desc_dict
28
+ self.args = args
29
+ self.prompts = prompts
30
+ self.max_tokens = max_tokens
31
+ self.prompt_level = args.prompt_level
32
+ if args.gpt_version == "gpt-35-turbo":
33
+ model = "gpt-3.5-turbo"
34
+ else:
35
+ model = args.gpt_version
36
+ self.encoding = tiktoken.encoding_for_model(model)
37
+ super().__init__()
38
+ self.distiller = distiller
39
+ self.fewshot_example_initialization(args.prompt_level, args.prompt_path, distiller = self.distiller)
40
+ self.default_action = 1
41
+ self.parser = self._parser_initialization()
42
+ self.irr_game_description = ''
43
+ self.memory = []
44
+ self.env_history = EnvironmentHistory()
45
+ self.first_call = True
46
+ self.logger = logger
47
+ if self.prompt_level in [2, 4]:
48
+ self.memory = self.summarized_fewshot_example
49
+ if args.use_short_mem == 1:
50
+ self.use_short_mem = True
51
+ self.mem_num = self.args.trajectories_num
52
+ else:
53
+ self.use_short_mem = False
54
+ self.mem_num = 0
55
+
56
+ def num_tokens_from_string(self,string: str) -> int:
57
+ """Returns the number of tokens in a text string."""
58
+ num_tokens = len(self.encoding.encode(string))
59
+ return num_tokens
60
+
61
+ def update_mem(self,):
62
+ traj = "Firstly, the description and the goal of the task will be provided. Please pay close attention to comprehend the information presented below.\n"
63
+ traj += "Task Description: " + self.game_description + '\n'
64
+ traj += "Goal Description: " + self.goal_description + '\n'
65
+ traj += self.action_description
66
+ traj += "Below is the historical data for this round of the game, which includes the state and corresponding action for each step.\n"
67
+ traj += str(self.env_history)
68
+ # print(traj)
69
+ self._update_mem(traj)
70
+
71
+ def _update_mem(self, traj):
72
+ my_reflection = self.distiller.generate(traj, self.memory)
73
+ self.memory.append(my_reflection)
74
+ self.env_history.reset()
75
+
76
+ def clear_mem(self):
77
+ self.pre_memory = []
78
+ self.post_memory = []
79
+ self.is_first = True
80
+ self._update_mem(None)
81
+
82
+
83
+ def _parser_initialization(self):
84
+ if hasattr(self.action_space, 'n'):
85
+ assert self.action_space.n in PARSERS.keys(), f'Action space {self.action_space} is not supported.'
86
+ num_action = self.action_space.n
87
+ else:
88
+ num_action = 1
89
+
90
+ # autofixing_chat = AzureChatOpenAI(
91
+ # openai_api_type=openai.api_type,
92
+ # openai_api_version=openai.api_version,
93
+ # openai_api_base=openai.api_base,
94
+ # openai_api_key=openai.api_key,
95
+ # deployment_name="gpt-35-turbo",
96
+ # temperature=self.temperature,
97
+ # max_tokens=self.max_tokens
98
+ # )
99
+ autofixing_chat = ChatOpenAI(temperature=0, openai_api_key=openai.api_key)
100
+
101
+ parser = PydanticOutputParser(pydantic_object=PARSERS[num_action])
102
+ autofixing_parser = OutputFixingParser.from_llm(
103
+ llm=autofixing_chat, parser=parser)
104
+
105
+ return autofixing_parser
106
+
107
+ def fewshot_example_initialization(self, level, path=None, distiller=None):
108
+ self.fewshot_example = []
109
+ self.irr_few_shot_examples = []
110
+ self.prompt_level = level
111
+ self.expert_knowledge = None
112
+ if level in [1,3]:
113
+ self.irr_few_shot_examples = self.prompts.TASK_IRRELEVANT_PROMPTS
114
+ elif level == 5:
115
+ if hasattr(self.prompts, "expert_prompt"):
116
+ self.expert_knowledge = self.prompts.expert_prompt
117
+ self.fewshot_example = self.prompts.PERCEPTRON_BASIC_FS_EXAMPLES
118
+ else:
119
+ self.irr_few_shot_examples = self.prompts.TASK_IRRELEVANT_PROMPTS
120
+ json_file = f'{path}_l{level}.json'
121
+ with open(json_file, 'r') as infile:
122
+ data = json.load(infile)
123
+ max_step_num = 0
124
+ for traj in data:
125
+ traj_text = traj[0]['game_description']
126
+ traj_text += traj[0]['goal_description']
127
+ for i, transition in enumerate(traj):
128
+ traj_text += transition['observation']
129
+ traj_text += f"> {transition['action']}"
130
+ one_traj_token = self.num_tokens_from_string(traj_text)
131
+ if one_traj_token > 5000:
132
+ max_step_num = i+1
133
+ break
134
+ traj_text += f"Your performance is: {transition['cum_reward']}"
135
+ if not max_step_num:
136
+ max_step_num = 200
137
+ self.summarized_fewshot_example = self.distiller.generate_from_file(json_file,max_step_num=max_step_num)
138
+
139
+ def response(self, state_description, action_description, env_info, game_description=None, goal_description=None, fewshot_examples=None):
140
+ if env_info['future_summary']:
141
+ prompt = f"{game_description}\n{goal_description}\n{fewshot_examples}\n{state_description}\n{env_info['future_summary']}\n{action_description} "
142
+ else:
143
+ prompt = f"{game_description}\n{goal_description}\n{fewshot_examples}\nCurrent {state_description}\n{action_description} "
144
+ prompt += "Please select an action based on the current game state and the information you get. You must select the appropriate action from the given action descriptions and cannot refrain from taking action or performing any prohibited actions. Your Action is: "
145
+ print(f"prompt is {prompt}")
146
+ res = openai.Completion.create(
147
+ engine=self.args.gpt_version,
148
+ prompt=prompt,
149
+ temperature=self.temperature,
150
+ max_tokens=self.max_tokens,
151
+ )
152
+ return prompt, res
153
+
154
+ def _add_history_before_action(self, game_description, goal_description, state_description):
155
+ self.game_description = game_description
156
+ self.goal_description = goal_description
157
+ self.env_history.add("observation", state_description)
158
+ # print(self.env_history)
159
+ if len(self.env_history) >= 2:
160
+ one_history_token = self.num_tokens_from_string(self.env_history.get_one_history())
161
+ self.env_history.set_history(6000 // one_history_token)
162
+
163
+ def act(self, state_description, action_description, env_info, game_description=None, goal_description=None, logfile=None):
164
+ self._add_history_before_action(game_description, goal_description, state_description)
165
+ asking_round = 0
166
+ res = None
167
+ action = None
168
+ prompt = None
169
+ if not self.logger:
170
+ logger.remove()
171
+ self.logger = logger.add(logfile, colorize=True, enqueue=True)
172
+
173
+ if self.args.prompt_level == 5:
174
+ my_mem = ""
175
+ if self.fewshot_example:
176
+ my_mem += "Here are some examples of how you should complete a task."
177
+ for examples in self.fewshot_example:
178
+ my_mem += "\nQuestion: \n" + examples['question'] + "Answer: \n" + examples['answer']
179
+ my_mem += '\nNow you are in the task.\n'
180
+ elif self.args.prompt_level in [2,3,4]:
181
+ my_mem = ""
182
+ if self.prompt_level == 2:
183
+ my_mem += 'I have collected a few trajectories from a random policy, and the summaries are listed below.'
184
+ elif self.prompt_level == 3:
185
+ my_mem += 'I have collected a few trajectories before, and the summaries are listed below.'
186
+ elif self.prompt_level == 4:
187
+ my_mem += 'I have collected a few trajectories from an expert policy, and the summaries are listed below.'
188
+ my_mem += self._read_mem()
189
+ else:
190
+ my_mem = ""
191
+
192
+ if self.use_short_mem:
193
+ if len(self.env_history) > 1:
194
+ my_mem += '\nSubsequently, I will offer pertinent guidance or information about the task. Please utilize this instruction to accomplish the given task effectively.'
195
+ my_mem += f"\nBelow are the latest {min(self.args.short_mem_num,len(self.env_history)//2)} historical data entries:\n"
196
+ my_mem += f"{self.env_history.get_histories(self.mem_num)}"
197
+
198
+ while asking_round < 3:
199
+ prompt, res = self.response(state_description, action_description, env_info, game_description, goal_description, my_mem)
200
+ action_str = res.choices[0].text.strip()
201
+ print(f'my answer is {action_str}')
202
+ # import pdb; pdb.set_trace()
203
+ try:
204
+ if "Continuous" in self.args.env_name:
205
+ action = float(re.findall(r"[-+]?\d*\.\d+", action_str)[0])
206
+
207
+ else:
208
+ action = int(re.findall(r"\d+", action_str)[0])
209
+ except:
210
+ action = None
211
+ asking_round += 1
212
+ continue
213
+
214
+ if "Continuous" not in self.args.env_name:
215
+ if (action-1) in self.action_space:
216
+ break
217
+ else:
218
+ asking_round += 1
219
+ action = None
220
+ else:
221
+ if action >= self.action_space.low and action <= self.action_space.high:
222
+ break
223
+ else:
224
+ asking_round += 1
225
+ action = None
226
+
227
+ if action is None:
228
+ print('error selecting action; falling back to the default action')
229
+ action = self.default_action
230
+ self._add_history_after_action(action)
231
+ self.logger.info(f'\n{prompt}')
232
+ self.logger.info(f'The GPT response is: {res}.')
233
+ self.logger.info(f'The optimal action is: {action}.')
234
+ return action, prompt, res, 0, 0
235
+
236
+ def _read_mem(self, ):
237
+ memory = self.memory
238
+ mem_str = ""
239
+ if len(memory) > 5:
240
+ memory = memory[-5:]
241
+ if len(memory) > 0:
242
+ mem_str += '\nYour memory for the task below:'
243
+ for i, m in enumerate(memory):
244
+ mem_str += f'\nTrial {i}:\n{m.strip()}'
245
+ return mem_str
246
+
247
+ def _add_history_after_action(self, action):
248
+ self.env_history.add('action', action)
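The retry loop in NaiveAct.act extracts an action from the model's free-text reply with regular expressions and validates it against the action space (discrete replies are 1-indexed, the environment itself is 0-indexed). A standalone sketch of that extraction step, not the repository's exact code:

import re

def parse_action(action_str: str, n_actions: int, continuous: bool = False):
    if continuous:
        match = re.findall(r"[-+]?\d*\.\d+", action_str)
        return float(match[0]) if match else None
    match = re.findall(r"\d+", action_str)
    if not match:
        return None
    action = int(match[0])
    # Valid discrete replies are 1..n_actions.
    return action if 1 <= action <= n_actions else None

print(parse_action("The optimal action is 2.", n_actions=4))  # 2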
deciders/cot.py ADDED
@@ -0,0 +1,147 @@
1
+ import openai
2
+ from .misc import history_to_str
3
+ from langchain.chat_models import AzureChatOpenAI
4
+ from langchain.prompts.chat import (
5
+ PromptTemplate,
6
+ ChatPromptTemplate,
7
+ SystemMessagePromptTemplate,
8
+ HumanMessagePromptTemplate,
9
+ )
10
+ from langchain.prompts.few_shot import FewShotPromptTemplate
11
+ from langchain import LLMChain
12
+ from loguru import logger
13
+ from langchain.callbacks import FileCallbackHandler
14
+ from langchain.callbacks import get_openai_callback
15
+ from .act import NaiveAct
16
+ from .utils import run_chain
17
+
18
+
19
+ class ChainOfThought(NaiveAct):
20
+ def __init__(self, action_space, args, prompts, distiller, temperature=0.1, max_tokens=None, logger=None):
21
+ super().__init__(action_space, args, prompts, distiller, temperature, max_tokens,logger)
22
+
23
+ def act(
24
+ self,
25
+ state_description,
26
+ action_description,
27
+ env_info,
28
+ game_description,
29
+ goal_description,
30
+ logfile=None,
31
+ ):
32
+ self.action_description = action_description
33
+ self._add_history_before_action(game_description, goal_description, state_description)
34
+ chat = AzureChatOpenAI(
35
+ openai_api_type=openai.api_type,
36
+ openai_api_version=openai.api_version,
37
+ openai_api_base=openai.api_base,
38
+ openai_api_key=openai.api_key,
39
+ deployment_name=self.args.gpt_version,
40
+ temperature=self.temperature,
41
+ max_tokens=self.max_tokens
42
+ )
43
+
44
+ suffix_flag = False
45
+ reply_format_description = \
46
+ "Your response should choose an optimal action from a valid action list and terminate with the following format: "
47
+
48
+ # System Message
49
+ human_template = "Now, you are completing a challenging task. You must carefully understand the Chain-of-Thought method you will use and apply it to the following task.\n"
50
+
51
+ # task-irrelevant SystemMessage
52
+ if self.irr_few_shot_examples:
53
+ human_template += 'In the following examples, I shall present sets of questions and answers with the Chain-of-Thought method. Please adhere to the format and reasoning of the provided responses when addressing the subsequent task.\n'
54
+ for i, examples in enumerate(self.irr_few_shot_examples):
55
+ human_template += f"\nExample {i+1}:\n"
56
+ human_template += "Question: \n" + examples['question'] + "\nAnswer: \n" + examples['answer']
57
+
58
+ # task-irrelevant few shot if have
59
+ if self.irr_few_shot_examples:
60
+ human_template += "\nMoving forward, I will describe the task, the goal, and the actions you may execute. Please pay close attention to comprehend the information presented below.\n"
61
+
62
+ if self.fewshot_example:
63
+ human_template += "I will describe the task, the goal, and the actions you may execute. Please pay close attention to comprehend the information presented below."
64
+ # print(fewshot_example_prompt.format(**fewshot_examples[0]))
65
+ human_template += '\nTask Description: {game_description} \n'
66
+ human_template += 'Goal Description: {goal_description}\n'
67
+ human_template += 'Actions Description: {action_description}\n'
68
+
69
+ if self.fewshot_example:
70
+ human_template += "Here, I will provide you with some guidance to help you better understand the rules of the task. Next are some examples: "
71
+ for i, examples in enumerate(self.fewshot_example):
72
+ human_template += f"\nExample {i+1}:\n"
73
+ human_template += "Question: \n" + examples['question'] + "\nAnswer: \n" + examples['answer']
74
+
75
+ if self.prompt_level in [2, 3, 4]:
76
+ if self.memory:
77
+ human_template += '\nSubsequently, I will offer pertinent guidance or information about the task. Please utilize this instruction to accomplish the given task effectively.\n'
78
+ suffix_flag = True
79
+ if self.prompt_level == 2:
80
+ human_template += 'I have collected a few trajectories from a random policy, and the summaries are listed below.'
81
+ elif self.prompt_level == 3:
82
+ human_template += 'I have collected a few trajectories before, and the summaries are listed below.'
83
+ elif self.prompt_level == 4:
84
+ human_template += 'I have collected a few trajectories from an expert policy, and the summaries are listed below.'
85
+ human_template += self._read_mem() + "\n"
86
+
87
+ if self.use_short_mem:
88
+ if len(self.env_history) > 1:
89
+ if not suffix_flag:
90
+ human_template += '\nSubsequently, I will offer pertinent guidance or information about the task. Please utilize this instruction to accomplish the given task effectively.'
91
+ human_template += f"\nBelow are the latest {self.args.short_mem_num} historical data entries:\n"
92
+ human_template += f"{self.env_history.get_histories(self.mem_num)}"
93
+ human_template += '\nNext is the observation that the agent gets:\nCurrent {state_description}\n'
94
+ human_template += 'Please select an action based on the current game state and the information you get. You must select the appropriate action from the given action descriptions and cannot refrain from taking action or performing any prohibited actions. Here is the action description below:\n{action_description}\n'
95
+ human_template += 'Please note that you need to carefully lay out your thought process on the question, not just give an answer. You need to write the corresponding logic of your thinking following the example above. Also, please keep in mind not to answer with any redundant and irrelevant content.\n'
96
+ human_template += "Finally, you also need to normalize your output according to the reply format description.\n"
97
+ human_template += 'Reply format description: {reply_format_description}{format_instructions}\n'
98
+
99
+ human_message_prompt = PromptTemplate(
100
+ template=human_template,
101
+ input_variables=[
102
+ 'state_description', 'goal_description', 'game_description',
103
+ 'action_description', 'reply_format_description'],
104
+ partial_variables={'format_instructions': self.parser.get_format_instructions()}
105
+ )
106
+
107
+ human_message_prompt = HumanMessagePromptTemplate(prompt=human_message_prompt)
108
+
109
+ chat_prompt = ChatPromptTemplate.from_messages([human_message_prompt])
110
+
111
+ if not self.logger:
112
+ logger.remove()
113
+ self.logger = logger.add(logfile, colorize=True, enqueue=True)
114
+ handler = FileCallbackHandler(logfile)
115
+
116
+ chain = LLMChain(llm=chat, prompt=chat_prompt, callbacks=[handler], verbose=False)
117
+
118
+ text_prompt = chat_prompt.format_messages(
119
+ game_description=game_description,
120
+ state_description=state_description,
121
+ goal_description=goal_description,
122
+ action_description=action_description,
123
+ reply_format_description=reply_format_description
124
+ )
125
+ texts = ""
126
+ for text in text_prompt:
127
+ texts += text.content + "\n"
128
+
129
+ with get_openai_callback() as cb:
130
+ response = run_chain(
131
+ chain,
132
+ game_description=game_description,
133
+ state_description=state_description,
134
+ goal_description=goal_description,
135
+ action_description=action_description,
136
+ reply_format_description=reply_format_description
137
+ )
138
+ total_tokens = cb.total_tokens
139
+ total_cost = cb.total_cost
140
+ action = self.parser.parse(response).action
141
+ self._add_history_after_action(action)
142
+ self.logger.info(f'The GPT response is: {response}.')
143
+ self.logger.info(f'The optimal action is: {action}.')
144
+ if env_info.get('history'):
145
+ self.logger.info(f'History: {history_to_str(env_info["history"])}')
146
+
147
+ return action, texts, response, total_tokens, total_cost
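ChainOfThought, like the other structured deciders, relies on the parser built in NaiveAct._parser_initialization: a Pydantic schema describes the expected reply, and an OutputFixingParser retries malformed replies through a chat model. A sketch of that stack with a hypothetical one-field schema; the real schemas live in deciders/parser.py, and constructing ChatOpenAI assumes an OpenAI key is configured:

from pydantic import BaseModel, Field
from langchain.output_parsers import PydanticOutputParser, OutputFixingParser
from langchain.chat_models import ChatOpenAI

class ActionChoice(BaseModel):
    action: int = Field(description="the chosen action index, starting from 1")

parser = PydanticOutputParser(pydantic_object=ActionChoice)
fixing_parser = OutputFixingParser.from_llm(llm=ChatOpenAI(temperature=0), parser=parser)

print(parser.get_format_instructions())             # injected as {format_instructions}
print(fixing_parser.parse('{"action": 2}').action)  # 2 (no fix-up call needed here)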
deciders/jarvis.py ADDED
@@ -0,0 +1,177 @@
1
+ import openai
2
+ from .misc import history_to_str
3
+ from langchain.chat_models import AzureChatOpenAI, ChatOpenAI
4
+ from langchain.prompts.chat import (
5
+ PromptTemplate,
6
+ ChatPromptTemplate,
7
+ SystemMessagePromptTemplate,
8
+ HumanMessagePromptTemplate,
9
+ )
10
+ from langchain.prompts.few_shot import FewShotPromptTemplate
11
+ from langchain import LLMChain
12
+ from langchain.callbacks import FileCallbackHandler
13
+ from langchain.callbacks import get_openai_callback
14
+ from .act import NaiveAct
15
+ from memory.env_history import EnvironmentHistory
16
+ import tiktoken
17
+ from .utils import run_chain
18
+ from loguru import logger
19
+
20
+
21
+
22
+ class Jarvis(NaiveAct):
23
+ def __init__(self, action_space, args, prompts, distiller, temperature=0., max_tokens=None, logger=None, fixed_suggestion=None, fixed_insight=None):
24
+ super().__init__(action_space, args, prompts, distiller, temperature, max_tokens, logger)
25
+ self.pre_memory = []
26
+ self.post_memory = []
27
+ self.is_first = True
28
+ self.num_trails = args.num_trails
29
+ self.game_description = args.game_description
30
+ self.goal_description = args.goal_description
31
+ self.action_description = args.action_description
32
+ self.action_desc_dict = args.action_desc_dict
33
+ self.mem_num = args.trajectories_num
34
+ self.temperature = temperature
35
+ self.fixed_suggestion = fixed_suggestion
36
+ self.fixed_insight = fixed_insight
37
+ self._update_mem(None)
38
+ self.insight = ""
39
+
40
+ def num_tokens_from_string(self,string: str) -> int:
41
+ """Returns the number of tokens in a text string."""
42
+ num_tokens = len(self.encoding.encode(string))
43
+ return num_tokens
44
+
45
+ def update_mem(self,):
46
+ traj = self.game_description
47
+ traj += self.goal_description
48
+ traj += self.action_description
49
+ traj += str(self.env_history)
50
+ self._update_mem(traj)
51
+
52
+ def clear_mem(self):
53
+ self.pre_memory = []
54
+ self.post_memory = []
55
+ self.is_first = True
56
+ self._update_mem(None)
57
+
58
+ def _update_mem(self, traj):
59
+ if self.memory:
60
+ self.post_memory = self.memory
61
+ self.insight = self.distiller.generate_insight(self.post_memory)
62
+ else:
63
+ if not self.is_first:
64
+ summary = self.distiller.generate_summary(traj, self.post_memory)
65
+ self.post_memory.append(summary)
66
+ self.insight = self.distiller.generate_insight(self.post_memory)
67
+ else:
68
+ self.is_first = False
69
+ self.insight = ""
70
+ suggestion = self.distiller.generate_suggestion(self.game_description, self.goal_description, self.action_description, self.pre_memory, self.post_memory, self.insight, self.num_trails)
71
+ if self.fixed_suggestion:
72
+ suggestion = self.fixed_suggestion
73
+ if self.fixed_insight:
74
+ self.insight = self.fixed_insight
75
+ self.pre_memory.append(suggestion)
76
+ self.env_history.reset()
77
+
78
+ def _read_mem(self, ):
79
+ insight_str = ""
80
+ if self.insight:
81
+ insight_str += "The insights of the game are listed below: "
82
+ insight_str += f"{self.insight}\n"
83
+ suggestion_str = "The suggestions are listed below:" + self.pre_memory[-1]
84
+ return insight_str + suggestion_str
85
+ def act(
86
+ self,
87
+ state_description,
88
+ action_description,
89
+ env_info,
90
+ game_description,
91
+ goal_description,
92
+ logfile=None,
93
+ ):
94
+ self.game_description = game_description
95
+ self.goal_description = goal_description
96
+ self.env_history.add("observation", state_description)
97
+ chat = ChatOpenAI(temperature=0.5, openai_api_key=openai.api_key, model=self.args.gpt_version)
98
+ # print(self.logger)
99
+ reply_format_description = \
100
+ "Your response should choose an optimal action from valid action list, and terminated with following format: "
101
+ # only task relevant examplesA
102
+ template = "Now you are completing a task."
103
+ template += "You need to carefully understand the description of the game. "
104
+ # TODO: few shot example handle
105
+ if self.irr_few_shot_examples:
106
+ template += "Here are some examples of how you should completing a task."
107
+ for examples in self.irr_few_shot_examples:
108
+ template += "\nQuestion: \n" + examples['question'] + "Answer: \n" + examples['answer']
109
+
110
+ template += "\n\nNow you are in the task."
111
+ template += " {game_description} {action_description} {goal_description}"
112
+ template += "You are observing something and " \
113
+ "you need to choose the optimal action acoordingly."
114
+ template += 'Response and interact using the format: {reply_format_description}{format_instructions}\n'
115
+
116
+ template += self._read_mem()
117
+ system_message_prompt = SystemMessagePromptTemplate.from_template(template)
118
+
119
+ short_memory_template = HumanMessagePromptTemplate.from_template("{history}")
120
+ chat_prompt = ChatPromptTemplate.from_messages(
121
+ [system_message_prompt, short_memory_template])
122
+ if self.logger:
123
+ pass
124
+ else:
125
+ if logfile:
126
+ # logger.remove()
127
+ if self.first_call:
128
+ self.logger = logger.add(logfile, colorize=True, enqueue=True, filter=lambda x: '[Reflexion Memory]' not in x['message'])
129
+ self.first_call = False
130
+ handler = FileCallbackHandler(logfile)
131
+ total_tokens, total_cost = 0, 0
132
+ max_think_times = 1
133
+ # TODO: ADD REACT Support
134
+ # print(str(self.env_history))
135
+ if self.use_short_mem:
136
+ my_history = str(self.env_history)
137
+ else:
138
+ my_history = ""
139
+ for i_think in range(max_think_times):
140
+ # chain = LLMChain(llm=chat, prompt=chat_prompt, callbacks=[handler], verbose=True)
141
+ chain = LLMChain(llm=chat, prompt=chat_prompt, callbacks=[handler], verbose=False)
142
+ with get_openai_callback() as cb:
143
+ response = run_chain(
144
+ chain,
145
+ game_description=game_description,
146
+ goal_description=goal_description,
147
+ action_description=action_description,
148
+ # state_description = self.env_history.get_last_history(),
149
+ history=self.env_history.get_histories_with_last(self.mem_num),
150
+ format_instructions=self.parser.get_format_instructions(),
151
+ reply_format_description=reply_format_description,
152
+ max_token=3000
153
+ )
154
+
155
+ total_tokens += cb.total_tokens
156
+ total_cost += cb.total_cost
157
+ action = self.parser.parse(response).action
158
+ self._add_history_after_action(action)
159
+ self.logger.info(f'The GPT response is: {response}.')
160
+ self.logger.info(f'The optimal action is: {action}.')
161
+ if self.pre_memory:
162
+ self.logger.info(f'The suggestion is: {self.pre_memory[-1]}.')
163
+ if self.post_memory:
164
+ self.logger.info(f'The summary is: {self.post_memory[-1]}.')
165
+ if env_info.get('history'):
166
+ self.logger.info(f'History: {history_to_str(env_info["history"])}')
167
+ text_prompt = chat_prompt.format_messages(
168
+ game_description=game_description,
169
+ goal_description=goal_description,
170
+ action_description=action_description,
171
+ # state_description = self.env_history.get_last_history(),
172
+ history=self.env_history.get_histories_with_last(self.mem_num),
173
+ format_instructions=self.parser.get_format_instructions(),
174
+ reply_format_description=reply_format_description,
175
+ )
176
+ text_prompt = f'{text_prompt[0].content}\n{text_prompt[1].content}'
177
+ return action, text_prompt, response, total_tokens, total_cost
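Jarvis maintains two memory streams: post_memory (episode summaries plus a distilled insight) and pre_memory (per-trial suggestions). A condensed sketch of the per-episode cycle implemented by _update_mem above, written as a standalone function; distiller is the summarizer object injected into the decider (see the distillers/ package in this commit):

def refresh_memory(distiller, traj, pre_memory, post_memory,
                   game_desc, goal_desc, action_desc, num_trails):
    # Summarize the finished trajectory, refresh the cross-episode insight,
    # then ask for a fresh suggestion to steer the next episode.
    summary = distiller.generate_summary(traj, post_memory)
    post_memory.append(summary)
    insight = distiller.generate_insight(post_memory)
    suggestion = distiller.generate_suggestion(
        game_desc, goal_desc, action_desc,
        pre_memory, post_memory, insight, num_trails)
    pre_memory.append(suggestion)
    return insight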
deciders/jarvis_without_insights.py ADDED
@@ -0,0 +1,179 @@
1
+ import openai
2
+ from .misc import history_to_str
3
+ from langchain.chat_models import AzureChatOpenAI
4
+ from langchain.prompts.chat import (
5
+ PromptTemplate,
6
+ ChatPromptTemplate,
7
+ SystemMessagePromptTemplate,
8
+ HumanMessagePromptTemplate,
9
+ )
10
+ from langchain.prompts.few_shot import FewShotPromptTemplate
11
+ from langchain import LLMChain
12
+ from loguru import logger
13
+ from langchain.callbacks import FileCallbackHandler
14
+ from langchain.callbacks import get_openai_callback
15
+ from .act import NaiveAct
16
+ from memory.env_history import EnvironmentHistory
17
+ import tiktoken
18
+ from .utils import run_chain
19
+
20
+
21
+ class JarvisWithoutInsight(NaiveAct):
22
+ def __init__(self, action_space, args, prompts, distiller, temperature=0.1, max_tokens=None):
23
+ super().__init__(action_space, args, prompts, distiller, temperature, max_tokens)
24
+ self.pre_memory = []
25
+ self.post_memory = []
26
+ self.is_first = True
27
+ self.num_trails = args.num_trails
28
+ self.game_description = args.game_description
29
+ self.goal_description = args.goal_description
30
+ self.action_description = args.action_description
31
+ self._update_mem(None)
32
+
33
+ def update_mem(self,):
34
+ traj = self.game_description
35
+ traj += self.goal_description
36
+ max_step_num = min(14000 // self.num_tokens_from_string(self.env_history.get_one_history()),200)
37
+ traj += self.env_history.get_histories(max_step_num)
38
+ self._update_mem(traj)
39
+
40
+ def _update_mem(self, traj):
41
+ if not self.is_first:
42
+ summary = self.distiller.generate_summary(traj, self.post_memory)
43
+ self.post_memory.append(summary)
44
+ self.insight = self.distiller.generate_insight(self.post_memory)
45
+ else:
46
+ self.is_first = False
47
+ suggestion = self.distiller.generate_suggestion(self.game_description, self.goal_description, self.action_description, self.pre_memory, self.post_memory, self.num_trails)
48
+ self.pre_memory.append(suggestion)
49
+ self.env_history.reset()
50
+
51
+ def _read_mem(self, ):
52
+ insight_str = ""
53
+ suggestion_str = "The suggestions are listed below:" + self.pre_memory[-1]
54
+ return insight_str + suggestion_str
55
+
56
+ def act(
57
+ self,
58
+ state_description,
59
+ action_description,
60
+ env_info,
61
+ game_description,
62
+ goal_description,
63
+ logfile=None,
64
+ ):
65
+ self.game_description = game_description
66
+ self.goal_description = goal_description
67
+ self.env_history.add("observation", state_description)
68
+ chat = AzureChatOpenAI(
69
+ openai_api_type=openai.api_type,
70
+ openai_api_version=openai.api_version,
71
+ openai_api_base=openai.api_base,
72
+ openai_api_key=openai.api_key,
73
+ deployment_name=self.args.gpt_version,
74
+ temperature=self.temperature,
75
+ max_tokens=self.max_tokens,
76
+ )
77
+ reply_format_description = \
78
+ "Your response should choose an optimal action from valid action list, and terminated with following format: "
79
+ # only task relevant examplesA
80
+ template = "Now you are completing a task. "
81
+ template += "You need to carefully understand the description of the game. "
82
+ # TODO: few shot example handle
83
+ if self.irr_few_shot_examples:
84
+ template += "Here are some examples of how you should completing a task."
85
+ for examples in self.irr_few_shot_examples:
86
+ template += "\nQuestion: \n" + examples['question'] + "Answer: \n" + examples['answer']
87
+
88
+ if self.fewshot_example:
89
+ if self.expert_knowledge:
90
+ template += "Here, I will provide you with some expert knowledge to help you better understand the rules of the task."
91
+ template += self.expert_knowledge + '\n'
92
+ template += "Next are some examples: "
93
+ system_message_prompt = SystemMessagePromptTemplate.from_template(template)
94
+
95
+ human_template = ""
96
+ human_template += "\n\nNow you are in the task.\n"
97
+ human_template += "{game_description}\n{action_description}\n{goal_description}\n"
98
+ human_template += "You are observing something and " \
99
+ "you need to choose the optimal action acoordingly. "
100
+ human_template += 'Response and interact using the format: {reply_format_description}{format_instructions}\n'
101
+ human_template += self._read_mem()
102
+ human_template += "\n\nHere are some history states listed below:\n"
103
+
104
+ fewshot_example_prompt = PromptTemplate(
105
+ input_variables=["question", "answer"],
106
+ template="Question: \n{question}\n{answer}"
107
+ )
108
+ human_message_prompt = FewShotPromptTemplate(
109
+ examples=self.fewshot_example,
110
+ example_prompt=fewshot_example_prompt,
111
+ suffix=human_template,
112
+ input_variables=[
113
+ 'game_description', 'goal_description',
114
+ 'action_description', 'reply_format_description'],
115
+ partial_variables={'format_instructions': self.parser.get_format_instructions()}
116
+ )
117
+ human_message_prompt = HumanMessagePromptTemplate(prompt=human_message_prompt)
118
+
119
+ short_memory_template = HumanMessagePromptTemplate.from_template("{history} Please select an action based on the current game state:")
120
+
121
+ chat_prompt = ChatPromptTemplate.from_messages(
122
+ [system_message_prompt, human_message_prompt, short_memory_template])
123
+
124
+
125
+ if logfile:
126
+ # logger.remove()
127
+ if self.first_call:
128
+ logger.add(logfile, colorize=True, enqueue=True, filter=lambda x: '[Reflexion Memory]' not in x['message'])
129
+ self.first_call = False
130
+ handler = FileCallbackHandler(logfile)
131
+ total_tokens, total_cost = 0, 0
132
+ max_think_times = 1
133
+ # TODO: ADD REACT Support
134
+ # print(str(self.env_history))
135
+ if self.use_short_mem:
136
+ my_history = str(self.env_history)
137
+ else:
138
+ my_history = ""
139
+ for i_think in range(max_think_times):
140
+ chain = LLMChain(llm=chat, prompt=chat_prompt, callbacks=[handler], verbose=False)
141
+ with get_openai_callback() as cb:
142
+ response = run_chain(
143
+ chain,
144
+ game_description=game_description,
145
+ goal_description=goal_description,
146
+ action_description=action_description,
147
+ history=str(self.env_history),
148
+ format_instructions=self.parser.get_format_instructions(),
149
+ reply_format_description=reply_format_description,
150
+ max_token = 3000
151
+ )
152
+
153
+ total_tokens += cb.total_tokens
154
+ total_cost += cb.total_cost
155
+ action = self.parser.parse(response).action
156
+
157
+ text_prompt = chat_prompt.format_messages(
158
+ game_description=game_description,
159
+ goal_description=goal_description,
160
+ action_description=action_description,
161
+ history=str(self.env_history),
162
+ format_instructions=self.parser.get_format_instructions(),
163
+ reply_format_description=reply_format_description,
164
+ )
165
+ texts = ""
166
+ for text in text_prompt:
167
+ texts += text.content + "\n"
168
+
169
+ self._add_history_after_action(action)
170
+ logger.info(f'The GPT response is: {response}.')
171
+ logger.info(f'The optimal action is: {action}.')
172
+ if self.pre_memory:
173
+ logger.info(f'The suggestion is: {self.pre_memory[-1]}.')
174
+ if self.post_memory:
175
+ logger.info(f'The summary is: {self.post_memory[-1]}.')
176
+ if env_info.get('history'):
177
+ logger.info(f'History: {history_to_str(env_info["history"])}')
178
+
179
+ return action, texts, response, logger, total_tokens, total_cost
deciders/jarvis_without_shortmem.py ADDED
@@ -0,0 +1,182 @@
1
+ import openai
2
+ from .misc import history_to_str
3
+ from langchain.chat_models import AzureChatOpenAI
4
+ from langchain.prompts.chat import (
5
+ PromptTemplate,
6
+ ChatPromptTemplate,
7
+ SystemMessagePromptTemplate,
8
+ HumanMessagePromptTemplate,
9
+ )
10
+ from langchain.prompts.few_shot import FewShotPromptTemplate
11
+ from langchain import LLMChain
12
+ from loguru import logger
13
+ from langchain.callbacks import FileCallbackHandler
14
+ from langchain.callbacks import get_openai_callback
15
+ from .act import NaiveAct
16
+ from memory.env_history import EnvironmentHistory
17
+ import tiktoken
18
+ from .utils import run_chain
19
+
20
+
21
+ class JarvisWithoutShortMem(NaiveAct):
22
+ def __init__(self, action_space, args, prompts, distiller, temperature=0.1, max_tokens=None):
23
+ super().__init__(action_space, args, prompts, distiller, temperature, max_tokens)
24
+ self.pre_memory = []
25
+ self.post_memory = []
26
+ self.is_first = True
27
+ self.num_trails = args.num_trails
28
+ self.game_description = args.game_description
29
+ self.goal_description = args.goal_description
30
+ self.action_description = args.action_description
31
+ self._update_mem(None)
32
+
33
+ def update_mem(self,):
34
+ traj = self.game_description
35
+ traj += self.goal_description
36
+ max_step_num = min(14000 // self.num_tokens_from_string(self.env_history.get_one_history()),200)
37
+ traj += self.env_history.get_histories(max_step_num)
38
+ self._update_mem(traj)
39
+
40
+ def _update_mem(self, traj):
41
+ if not self.is_first:
42
+ summary = self.distiller.generate_summary(traj, self.post_memory)
43
+ self.post_memory.append(summary)
44
+ self.insight = self.distiller.generate_insight(self.post_memory)
45
+ else:
46
+ self.is_first = False
47
+ suggestion = self.distiller.generate_suggestion(self.game_description, self.goal_description, self.action_description, self.pre_memory, self.post_memory, self.num_trails)
48
+ self.pre_memory.append(suggestion)
49
+ self.env_history.reset()
50
+
51
+ def _read_mem(self, ):
52
+ insight_str = ""
53
+ if len(self.post_memory) > 0:
54
+ insight_str += "The insights of the game are listed below: "
55
+ insight_str += f"{self.insight}\n"
56
+ suggestion_str = "The suggestions are listed below:" + self.pre_memory[-1]
57
+ return insight_str + suggestion_str
58
+
59
+ def act(
60
+ self,
61
+ state_description,
62
+ action_description,
63
+ env_info,
64
+ game_description,
65
+ goal_description,
66
+ logfile=None,
67
+ ):
68
+ self.game_description = game_description
69
+ self.goal_description = goal_description
70
+ self.env_history.add("observation", state_description)
71
+ chat = AzureChatOpenAI(
72
+ openai_api_type=openai.api_type,
73
+ openai_api_version=openai.api_version,
74
+ openai_api_base=openai.api_base,
75
+ openai_api_key=openai.api_key,
76
+ deployment_name=self.args.gpt_version,
77
+ temperature=self.temperature,
78
+ max_tokens=self.max_tokens,
79
+ )
80
+ reply_format_description = \
81
+ "Your response should choose an optimal action from valid action list, and terminated with following format: "
82
+ # only task relevant examplesA
83
+ template = "Now you are completing a task. "
84
+ template += "You need to carefully understand the description of the game. "
85
+ # TODO: few shot example handle
86
+ if self.irr_few_shot_examples:
87
+ template += "Here are some examples of how you should completing a task."
88
+ for examples in self.irr_few_shot_examples:
89
+ template += "\nQuestion: \n" + examples['question'] + "Answer: \n" + examples['answer']
90
+
91
+ if self.fewshot_example:
92
+ if self.expert_knowledge:
93
+ template += "Here, I will provide you with some expert knowledge to help you better understand the rules of the task."
94
+ template += self.expert_knowledge + '\n'
95
+ template += "Next are some examples: "
96
+ system_message_prompt = SystemMessagePromptTemplate.from_template(template)
97
+
98
+ human_template = ""
99
+ human_template += "\n\nNow you are in the task.\n"
100
+ human_template += "{game_description}\n{action_description}\n{goal_description}\n"
101
+ human_template += "You are observing something and " \
102
+ "you need to choose the optimal action acoordingly. "
103
+ human_template += 'Response and interact using the format: {reply_format_description}{format_instructions}\n'
104
+ human_template += self._read_mem()
105
+ human_template += "\n\nHere are some history states listed below:\n"
106
+
107
+ fewshot_example_prompt = PromptTemplate(
108
+ input_variables=["question", "answer"],
109
+ template="Question: \n{question}\n{answer}"
110
+ )
111
+ human_message_prompt = FewShotPromptTemplate(
112
+ examples=self.fewshot_example,
113
+ example_prompt=fewshot_example_prompt,
114
+ suffix=human_template,
115
+ input_variables=[
116
+ 'game_description', 'goal_description',
117
+ 'action_description', 'reply_format_description'],
118
+ partial_variables={'format_instructions': self.parser.get_format_instructions()}
119
+ )
120
+ human_message_prompt = HumanMessagePromptTemplate(prompt=human_message_prompt)
121
+
122
+ short_memory_template = HumanMessagePromptTemplate.from_template("{history} Please select an action based on the current game state:")
123
+
124
+ chat_prompt = ChatPromptTemplate.from_messages(
125
+ [system_message_prompt, human_message_prompt, short_memory_template])
126
+
127
+
128
+ if logfile:
129
+ # logger.remove()
130
+ if self.first_call:
131
+ logger.add(logfile, colorize=True, enqueue=True, filter=lambda x: '[Reflexion Memory]' not in x['message'])
132
+ self.first_call = False
133
+ handler = FileCallbackHandler(logfile)
134
+ total_tokens, total_cost = 0, 0
135
+ max_think_times = 1
136
+ # TODO: ADD REACT Support
137
+ # print(str(self.env_history))
138
+ if self.use_short_mem:
139
+ my_history = str(self.env_history)
140
+ else:
141
+ my_history = ""
142
+ for i_think in range(max_think_times):
143
+ chain = LLMChain(llm=chat, prompt=chat_prompt, callbacks=[handler], verbose=False)
144
+ with get_openai_callback() as cb:
145
+ response = run_chain(
146
+ chain,
147
+ game_description=game_description,
148
+ goal_description=goal_description,
149
+ action_description=action_description,
150
+ history=self.env_history.get_last_history(),
151
+ format_instructions=self.parser.get_format_instructions(),
152
+ reply_format_description=reply_format_description,
153
+ max_token = 3000
154
+ )
155
+
156
+ total_tokens += cb.total_tokens
157
+ total_cost += cb.total_cost
158
+ action = self.parser.parse(response).action
159
+
160
+ text_prompt = chat_prompt.format_messages(
161
+ game_description=game_description,
162
+ goal_description=goal_description,
163
+ action_description=action_description,
164
+ history=self.env_history.get_last_history(),
165
+ format_instructions=self.parser.get_format_instructions(),
166
+ reply_format_description=reply_format_description,
167
+ )
168
+ texts = ""
169
+ for text in text_prompt:
170
+ texts += text.content + "\n"
171
+
172
+ self._add_history_after_action(action)
173
+ logger.info(f'The GPT response is: {response}.')
174
+ logger.info(f'The optimal action is: {action}.')
175
+ if self.pre_memory:
176
+ logger.info(f'The suggestion is: {self.pre_memory[-1]}.')
177
+ if self.post_memory:
178
+ logger.info(f'The summary is: {self.post_memory[-1]}.')
179
+ if env_info.get('history'):
180
+ logger.info(f'History: {history_to_str(env_info["history"])}')
181
+
182
+ return action, texts, response, logger, total_tokens, total_cost
deciders/jarvis_without_suggestions.py ADDED
@@ -0,0 +1,180 @@
1
+ import openai
2
+ from .misc import history_to_str
3
+ from langchain.chat_models import AzureChatOpenAI
4
+ from langchain.prompts.chat import (
5
+ PromptTemplate,
6
+ ChatPromptTemplate,
7
+ SystemMessagePromptTemplate,
8
+ HumanMessagePromptTemplate,
9
+ )
10
+ from langchain.prompts.few_shot import FewShotPromptTemplate
11
+ from langchain import LLMChain
12
+ from loguru import logger
13
+ from langchain.callbacks import FileCallbackHandler
14
+ from langchain.callbacks import get_openai_callback
15
+ from .act import NaiveAct
16
+ from memory.env_history import EnvironmentHistory
17
+ import tiktoken
18
+ from .utils import run_chain
19
+
20
+
21
+ class JarvisWithoutSuggestions(NaiveAct):
22
+ def __init__(self, action_space, args, prompts, distiller, temperature=0.1, max_tokens=None):
23
+ super().__init__(action_space, args, prompts, distiller, temperature, max_tokens)
24
+ self.pre_memory = []
25
+ self.post_memory = []
26
+ self.is_first = True
27
+ self.num_trails = args.num_trails
28
+ self.game_description = args.game_description
29
+ self.goal_description = args.goal_description
30
+ self.action_description = args.action_description
31
+ self._update_mem(None)
32
+
33
+ def update_mem(self,):
34
+ traj = self.game_description
35
+ traj += self.goal_description
36
+ max_step_num = min(14000 // self.num_tokens_from_string(self.env_history.get_one_history()),200)
37
+ traj += self.env_history.get_histories(max_step_num)
38
+ self._update_mem(traj)
39
+
40
+ def _update_mem(self, traj):
41
+ if not self.is_first:
42
+ summary = self.distiller.generate_summary(traj, self.post_memory)
43
+ self.post_memory.append(summary)
44
+ self.insight = self.distiller.generate_insight(self.post_memory)
45
+ else:
46
+ self.is_first = False
47
+ suggestion = self.distiller.generate_suggestion(self.game_description, self.goal_description, self.action_description, self.pre_memory, self.post_memory, self.num_trails)
48
+ self.pre_memory.append(suggestion)
49
+ self.env_history.reset()
50
+
51
+ def _read_mem(self, ):
52
+ insight_str = ""
53
+ if len(self.post_memory) > 0:
54
+ insight_str += "The insights of the game are listed below: "
55
+ insight_str += f"{self.insight}\n"
56
+ suggestion_str = "\n"
57
+ return insight_str + suggestion_str
58
+
59
+ def act(
60
+ self,
61
+ state_description,
62
+ action_description,
63
+ env_info,
64
+ game_description,
65
+ goal_description,
66
+ logfile=None,
67
+ ):
68
+ self.game_description = game_description
69
+ self.goal_description = goal_description
70
+ self.env_history.add("observation", state_description)
71
+ chat = AzureChatOpenAI(
72
+ openai_api_type=openai.api_type,
73
+ openai_api_version=openai.api_version,
74
+ openai_api_base=openai.api_base,
75
+ openai_api_key=openai.api_key,
76
+ deployment_name=self.args.gpt_version,
77
+ temperature=self.temperature,
78
+ max_tokens=self.max_tokens,
79
+ )
80
+ reply_format_description = \
81
+ "Your response should choose an optimal action from valid action list, and terminated with following format: "
82
+ # only task relevant examplesA
83
+ template = "Now you are completing a task. "
84
+ template += "You need to carefully understand the description of the game. "
85
+ # TODO: few shot example handle
86
+ if self.irr_few_shot_examples:
87
+ template += "Here are some examples of how you should completing a task."
88
+ for examples in self.irr_few_shot_examples:
89
+ template += "\nQuestion: \n" + examples['question'] + "Answer: \n" + examples['answer']
90
+
91
+ if self.fewshot_example:
92
+ if self.expert_knowledge:
93
+ template += "Here, I will provide you with some expert knowledge to help you better understand the rules of the task."
94
+ template += self.expert_knowledge + '\n'
95
+ template += "Next are some examples: "
96
+ system_message_prompt = SystemMessagePromptTemplate.from_template(template)
97
+
98
+ human_template = ""
99
+ human_template += "\n\nNow you are in the task.\n"
100
+ human_template += "{game_description}\n{action_description}\n{goal_description}\n"
101
+ human_template += "You are observing something and " \
102
+ "you need to choose the optimal action acoordingly. "
103
+ human_template += 'Response and interact using the format: {reply_format_description}{format_instructions}\n'
104
+ human_template += self._read_mem()
105
+ human_template += "\n\nHere are some history states listed below:\n"
106
+
107
+ fewshot_example_prompt = PromptTemplate(
108
+ input_variables=["question", "answer"],
109
+ template="Question: \n{question}\n{answer}"
110
+ )
111
+ human_message_prompt = FewShotPromptTemplate(
112
+ examples=self.fewshot_example,
113
+ example_prompt=fewshot_example_prompt,
114
+ suffix=human_template,
115
+ input_variables=[
116
+ 'game_description', 'goal_description',
117
+ 'action_description', 'reply_format_description'],
118
+ partial_variables={'format_instructions': self.parser.get_format_instructions()}
119
+ )
120
+ human_message_prompt = HumanMessagePromptTemplate(prompt=human_message_prompt)
121
+
122
+ short_memory_template = HumanMessagePromptTemplate.from_template("{history} Please select an action based on the current game state:")
123
+
124
+ chat_prompt = ChatPromptTemplate.from_messages(
125
+ [system_message_prompt, human_message_prompt, short_memory_template])
126
+
127
+
128
+ if logfile:
129
+ # logger.remove()
130
+ if self.first_call:
131
+ logger.add(logfile, colorize=True, enqueue=True, filter=lambda x: '[Reflexion Memory]' not in x['message'])
132
+ self.first_call = False
133
+ handler = FileCallbackHandler(logfile)
134
+ total_tokens, total_cost = 0, 0
135
+ max_think_times = 1
136
+ # TODO: ADD REACT Support
137
+ # print(str(self.env_history))
138
+ if self.use_short_mem:
139
+ my_history = str(self.env_history)
140
+ else:
141
+ my_history = ""
142
+ for i_think in range(max_think_times):
143
+ chain = LLMChain(llm=chat, prompt=chat_prompt, callbacks=[handler], verbose=False)
144
+ with get_openai_callback() as cb:
145
+ response = run_chain(
146
+ chain,
147
+ game_description=game_description,
148
+ goal_description=goal_description,
149
+ action_description=action_description,
150
+ history=str(self.env_history),
151
+ format_instructions=self.parser.get_format_instructions(),
152
+ reply_format_description=reply_format_description,
153
+ max_token = 3000
154
+ )
155
+
156
+ total_tokens += cb.total_tokens
157
+ total_cost += cb.total_cost
158
+ action = self.parser.parse(response).action
159
+
160
+ text_prompt = chat_prompt.format_messages(
161
+ game_description=game_description,
162
+ goal_description=goal_description,
163
+ action_description=action_description,
164
+ history=str(self.env_history),
165
+ format_instructions=self.parser.get_format_instructions(),
166
+ reply_format_description=reply_format_description,
167
+ )
168
+ texts = ""
169
+ for text in text_prompt:
170
+ texts += text.content + "\n"
171
+
172
+ self._add_history_after_action(action)
173
+ logger.info(f'The GPT response is: {response}.')
174
+ logger.info(f'The optimal action is: {action}.')
175
+ if self.post_memory:
176
+ logger.info(f'The summary is: {self.post_memory[-1]}.')
177
+ if env_info.get('history'):
178
+ logger.info(f'History: {history_to_str(env_info["history"])}')
179
+
180
+ return action, texts, response, logger, total_tokens, total_cost
deciders/jarvis_without_summary.py ADDED
@@ -0,0 +1,179 @@
1
+ import openai
2
+ from .misc import history_to_str
3
+ from langchain.chat_models import AzureChatOpenAI
4
+ from langchain.prompts.chat import (
5
+ PromptTemplate,
6
+ ChatPromptTemplate,
7
+ SystemMessagePromptTemplate,
8
+ HumanMessagePromptTemplate,
9
+ )
10
+ from langchain.prompts.few_shot import FewShotPromptTemplate
11
+ from langchain import LLMChain
12
+ from loguru import logger
13
+ from langchain.callbacks import FileCallbackHandler
14
+ from langchain.callbacks import get_openai_callback
15
+ from .act import NaiveAct
16
+ from memory.env_history import EnvironmentHistory
17
+ import tiktoken
18
+
19
+
20
+ class Jarvis(NaiveAct):
21
+ def __init__(self, action_space, args, prompts, distiller, temperature=0.1, max_tokens=None):
22
+ super().__init__(action_space, args, prompts, distiller, temperature, max_tokens)
23
+ self.pre_memory = []
24
+ self.post_memory = []
25
+ self.is_first = True
26
+ self.num_trails = args.num_trails
27
+ self.game_description = args.game_description
28
+ self.goal_description = args.goal_description
29
+ self.action_description = args.action_description
30
+ self._update_mem(None)
31
+
32
+ def update_mem(self,):
33
+ traj = self.game_description
34
+ traj += self.goal_description
35
+ max_step_num = min(14000 // self.num_tokens_from_string(self.env_history.get_one_history()),200)
36
+ traj += self.env_history.get_histories(max_step_num)
37
+ self._update_mem(traj)
38
+
39
+ def _update_mem(self, traj):
40
+ if not self.is_first:
41
+ summary = self.distiller.generate_summary(traj, self.post_memory)
42
+ self.post_memory.append(summary)
43
+ self.insight = self.distiller.generate_insight(self.post_memory)
44
+ else:
45
+ self.is_first = False
46
+ suggestion = self.distiller.generate_suggestion(self.game_description, self.goal_description, self.action_description, self.pre_memory, self.post_memory, self.num_trails)
47
+ self.pre_memory.append(suggestion)
48
+ self.env_history.reset()
49
+
50
+ def _read_mem(self, ):
51
+ insight_str = ""
52
+ if len(self.post_memory) > 0:
53
+ insight_str += "The insights of the game are listed below: "
54
+ insight_str += f"{self.insight}\n"
55
+ suggestion_str = "The suggestions are listed below:" + self.pre_memory[-1]
56
+ return insight_str + suggestion_str
57
+
58
+ def act(
59
+ self,
60
+ state_description,
61
+ action_description,
62
+ env_info,
63
+ game_description,
64
+ goal_description,
65
+ logfile=None,
66
+ ):
67
+ self.game_description = game_description
68
+ self.goal_description = goal_description
69
+ self.env_history.add("observation", state_description)
70
+ chat = AzureChatOpenAI(
71
+ openai_api_type=openai.api_type,
72
+ openai_api_version=openai.api_version,
73
+ openai_api_base=openai.api_base,
74
+ openai_api_key=openai.api_key,
75
+ deployment_name=self.args.gpt_version,
76
+ temperature=self.temperature,
77
+ max_tokens=self.max_tokens,
78
+ )
79
+ reply_format_description = \
80
+ "Your response should choose an optimal action from valid action list, and terminated with following format: "
81
+ # only task relevant examplesA
82
+ template = "Now you are completing a task. "
83
+ template += "You need to carefully understand the description of the game. "
84
+ # TODO: few shot example handle
85
+ if self.irr_few_shot_examples:
86
+ template += "Here are some examples of how you should completing a task."
87
+ for examples in self.irr_few_shot_examples:
88
+ template += "\nQuestion: \n" + examples['question'] + "Answer: \n" + examples['answer']
89
+
90
+ if self.fewshot_example:
91
+ if self.expert_knowledge:
92
+ template += "Here, I will provide you with some expert knowledge to help you better understand the rules of the task."
93
+ template += self.expert_knowledge + '\n'
94
+ template += "Next are some examples: "
95
+ system_message_prompt = SystemMessagePromptTemplate.from_template(template)
96
+
97
+ human_template = ""
98
+ human_template += "\n"
99
+ human_template += "{game_description}\n{action_description}\n{goal_description}\n"
100
+ human_template += "You are observing something and " \
101
+ "you need to choose the optimal action acoordingly. "
102
+ human_template += 'Response and interact using the format: {reply_format_description}{format_instructions}\n'
103
+ human_template += self._read_mem()
104
+ human_template += "\n\nHere are some history states listed below:\n"
105
+
106
+ fewshot_example_prompt = PromptTemplate(
107
+ input_variables=["question", "answer"],
108
+ template="Question: \n{question}\n{answer}"
109
+ )
110
+ human_message_prompt = FewShotPromptTemplate(
111
+ examples=self.fewshot_example,
112
+ example_prompt=fewshot_example_prompt,
113
+ suffix=human_template,
114
+ input_variables=[
115
+ 'game_description', 'goal_description',
116
+ 'action_description', 'reply_format_description'],
117
+ partial_variables={'format_instructions': self.parser.get_format_instructions()}
118
+ )
119
+ human_message_prompt = HumanMessagePromptTemplate(prompt=human_message_prompt)
120
+
121
+ short_memory_template = HumanMessagePromptTemplate.from_template("{history} Please select an action based on the current game state. You must select the appropriate action from the given action descriptions and cannot refrain from taking action or perform any prohibited actions. Here's the action description below: \n {action_description}\n")
122
+
123
+ chat_prompt = ChatPromptTemplate.from_messages(
124
+ [system_message_prompt, human_message_prompt, short_memory_template])
125
+
126
+ if logfile:
127
+ # logger.remove()
128
+ if self.first_call:
129
+ logger.add(logfile, colorize=True, enqueue=True, filter=lambda x: '[Reflexion Memory]' not in x['message'])
130
+ self.first_call = False
131
+ handler = FileCallbackHandler(logfile)
132
+ total_tokens, total_cost = 0, 0
133
+ max_think_times = 1
134
+ # TODO: ADD REACT Support
135
+ # print(str(self.env_history))
136
+ if self.use_short_mem:
137
+ my_history = str(self.env_history)
138
+ else:
139
+ my_history = ""
140
+ for i_think in range(max_think_times):
141
+ chain = LLMChain(llm=chat, prompt=chat_prompt, callbacks=[handler], verbose=False)
142
+ with get_openai_callback() as cb:
143
+ response = chain.run(
144
+ game_description=game_description,
145
+ goal_description=goal_description,
146
+ action_description=action_description,
147
+ history=self.env_history.get_histories(11),
148
+ format_instructions=self.parser.get_format_instructions(),
149
+ reply_format_description=reply_format_description,
150
+ max_token = 3000
151
+ )
152
+
153
+ total_tokens += cb.total_tokens
154
+ total_cost += cb.total_cost
155
+ action = self.parser.parse(response).action
156
+
157
+ text_prompt = chat_prompt.format_messages(
158
+ game_description=game_description,
159
+ goal_description=goal_description,
160
+ action_description=action_description,
161
+ history=self.env_history.get_histories(11),
162
+ format_instructions=self.parser.get_format_instructions(),
163
+ reply_format_description=reply_format_description,
164
+ )
165
+ texts = ""
166
+ for text in text_prompt:
167
+ texts += text.content + "\n"
168
+
169
+ self._add_history_after_action(action)
170
+ logger.info(f'The GPT response is: {response}.')
171
+ logger.info(f'The optimal action is: {action}.')
172
+ if self.pre_memory:
173
+ logger.info(f'The suggestion is: {self.pre_memory[-1]}.')
174
+ if self.post_memory:
175
+ logger.info(f'The summary is: {self.post_memory[-1]}.')
176
+ if env_info.get('history'):
177
+ logger.info(f'History: {history_to_str(env_info["history"])}')
178
+
179
+ return action, texts, response, logger, total_tokens, total_cost
deciders/misc.py ADDED
@@ -0,0 +1,21 @@
1
+ def history_to_str(history):
2
+ history_str = ""
3
+ for d in history:
4
+ history_str += f"state: {d['state']}, action: {d['action']}, reward: {d['reward']}\n"
5
+ return history_str
6
+
7
+ def get_majority_vote(actions):
8
+ return max(set(actions), key=actions.count)
9
+
10
+ def test_get_majority_vote():
11
+ assert get_majority_vote([1, 1, 1, 2, 2]) == 1
12
+ assert get_majority_vote([1, 1, 2, 2, 2]) == 2
13
+ assert get_majority_vote([1, 1, 2, 2, 3]) == 1
14
+ assert get_majority_vote([1, 2, 3, 4, 5]) == 1
15
+ assert get_majority_vote([1, 2, 3, 4, 5, 1, 1, 1, 1, 1]) == 1
16
+ assert get_majority_vote([1, 2, 3, 4, 5, 1, 1, 1, 1, 2]) == 1
17
+ assert get_majority_vote([1, 2, 3, 4, 5, 1, 1, 1, 2, 2]) == 1
18
+ assert get_majority_vote([1, 2, 3, 4, 5, 1, 1, 2, 2, 2]) == 2
19
+
20
+ if __name__ == "__main__":
21
+ test_get_majority_vote()
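A minimal usage sketch for the two helpers above (illustrative only, not part of this commit; it assumes the repo root is on `PYTHONPATH` so `deciders.misc` is importable):

```python
# Hypothetical history entries in the shape history_to_str expects.
from deciders.misc import history_to_str, get_majority_vote

history = [
    {"state": "cart position 0.1, pole upright", "action": 1, "reward": 1.0},
    {"state": "cart position 0.2, pole tilting right", "action": 2, "reward": 1.0},
]
print(history_to_str(history))              # one "state: ..., action: ..., reward: ..." line per step
print(get_majority_vote([1, 2, 2, 1, 2]))   # -> 2, the most frequent action
```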
deciders/pal.py ADDED
@@ -0,0 +1,149 @@
1
+ import openai
2
+ from .misc import history_to_str
3
+ from langchain.chat_models import AzureChatOpenAI
4
+ from langchain.prompts.chat import (
5
+ PromptTemplate,
6
+ ChatPromptTemplate,
7
+ SystemMessagePromptTemplate,
8
+ HumanMessagePromptTemplate,
9
+ )
10
+ from langchain.prompts.few_shot import FewShotPromptTemplate
11
+ from langchain import LLMChain
12
+ from loguru import logger
13
+ from langchain.callbacks import FileCallbackHandler
14
+ from langchain.callbacks import get_openai_callback
15
+ from .act import NaiveAct
16
+ from .utils import run_chain
17
+
18
+ def get_last_n_lines(text, n):
19
+ lines = text.splitlines()
20
+ return '\n'.join(lines[-n:])
21
+
22
+ class PAL(NaiveAct):
23
+ def __init__(self, action_space, args, prompts, distiller, temperature=0.1, max_tokens=None, logger=None):
24
+ super().__init__(action_space, args, prompts, distiller, temperature, max_tokens, logger)
25
+
26
+ def act(
27
+ self,
28
+ state_description,
29
+ action_description,
30
+ env_info,
31
+ game_description,
32
+ goal_description,
33
+ logfile=None,
34
+ ):
35
+ self._add_history_before_action(game_description, goal_description, state_description)
36
+ chat = AzureChatOpenAI(
37
+ openai_api_type=openai.api_type,
38
+ openai_api_version=openai.api_version,
39
+ openai_api_base=openai.api_base,
40
+ openai_api_key=openai.api_key,
41
+ deployment_name=self.args.gpt_version,
42
+ temperature=self.temperature,
43
+ max_tokens=self.max_tokens
44
+ )
45
+
46
+ suffix_flag = False
47
+ reply_format_description = \
48
+ "Your response should choose an optimal action from a valid action list and terminate with the following format: "
49
+
50
+ # System Message
51
+ human_template = "Now, you are completing a challenging task. You must carefully understand the Program-aided Language method you will use and apply it to the following task.\n"
52
+
53
+ # task-irrelevant SystemMessage
54
+ if self.irr_few_shot_examples:
55
+ human_template += 'In the following example, I shall present a set of question and answer with the Program-aided Language method. Please adhere to the format and reasoning of the provided response when addressing the subsequent task.\n'
56
+ for i, examples in enumerate(self.irr_few_shot_examples):
57
+ human_template += f"\nExample {i+1}:\n"
58
+ human_template += "Question: \n" + examples['question'] + "\nAnswer: \n" + examples['answer']
59
+
60
+ # task-irrelevant few shot if have
61
+ if self.irr_few_shot_examples:
62
+ human_template += "\nMoving forward, I will describe the task, the goal, and the actions you may execute. Please pay close attention to comprehend the information presented below.\n"
63
+
64
+ if self.fewshot_example:
65
+ human_template += "I will describe the task, the goal, and the actions you may execute. Please pay close attention to comprehend the information presented below."
66
+ # print(fewshot_example_prompt.format(**fewshot_examples[0]))
67
+ human_template += '\nTask Description: {game_description} \n'
68
+ human_template += 'Goal Description: {goal_description}\n'
69
+ human_template += 'Actions Description: {action_description}\n'
70
+
71
+ if self.fewshot_example:
72
+ human_template += "Here, I will provide you with some guidance to help you better understand the rules of the task. Next are some examples: "
73
+ for i, examples in enumerate(self.fewshot_example):
74
+ human_template += f"\nExample {i+1}:\n"
75
+ human_template += "Question: \n" + examples['question'] + "\nAnswer: \n" + examples['answer']
76
+
77
+ if self.prompt_level in [2, 3, 4]:
78
+ if self.memory:
79
+ human_template += '\nSubsequently, I will offer pertinent guidance or information about the task. Please utilize this instruction to accomplish the given task effectively.\n'
80
+ suffix_flag = True
81
+ if self.prompt_level == 2:
82
+ human_template += 'I have collected a few trajectories from a random policy, and the summaries are listed below.'
83
+ elif self.prompt_level == 3:
84
+ human_template += 'I have collected a few trajectories before, and the summaries are listed below.'
85
+ elif self.prompt_level == 4:
86
+ human_template += 'I have collected a few trajectories from an expert policy, and the summaries are listed below.'
87
+ human_template += self._read_mem() + "\n"
88
+
89
+ if self.use_short_mem:
90
+ if len(self.env_history) > 1:
91
+ if not suffix_flag:
92
+ human_template += '\nSubsequently, I will offer pertinent guidance or information about the task. Please utilize this instruction to accomplish the given task effectively.'
93
+ human_template += f"\nBelow are the latest {min(self.args.short_mem_num,len(self.env_history)//2)} historical data entries:\n"
94
+ human_template += f"{self.env_history.get_histories(self.mem_num)}"
95
+ human_template += '\nNext is the observation that the agent gets:\nCurrent {state_description}\n'
96
+ human_template += 'Please select an action based on the current game state and the information you get. You must select the appropriate action from the given action descriptions and cannot refrain from taking action or performing any prohibited actions. Here is the action description below:\n{action_description}\n'
97
+ human_template += 'Please generate Python program as answers to given questions, similar to the provided examples.\n'
98
+ human_template += 'You should calculate the final result based on the program, not just give a code script alone!\n'
99
+
100
+ human_message_prompt = PromptTemplate(
101
+ template=human_template,
102
+ input_variables=[
103
+ 'state_description', 'goal_description', 'game_description',
104
+ 'action_description'],
105
+ )
106
+
107
+ human_message_prompt = HumanMessagePromptTemplate(prompt=human_message_prompt)
108
+
109
+ chat_prompt = ChatPromptTemplate.from_messages([human_message_prompt])
110
+
111
+ if not self.logger:
112
+ logger.remove()
113
+ self.logger = logger.add(logfile, colorize=True, enqueue=True)
114
+ handler = FileCallbackHandler(logfile)
115
+
116
+ chain = LLMChain(llm=chat, prompt=chat_prompt, callbacks=[handler], verbose=False)
117
+
118
+ with get_openai_callback() as cb:
119
+ response = run_chain(
120
+ chain,
121
+ game_description=game_description,
122
+ state_description=state_description,
123
+ goal_description=goal_description,
124
+ action_description=action_description,
125
+ )
126
+ total_tokens = cb.total_tokens
127
+ total_cost = cb.total_cost
128
+ _response = get_last_n_lines(response, 2)
129
+
130
+
131
+ action = self.parser.parse(_response).action
132
+
133
+ text_prompt = chat_prompt.format_messages(
134
+ game_description=game_description,
135
+ state_description=state_description,
136
+ goal_description=goal_description,
137
+ action_description=action_description,
138
+ )
139
+ texts = ""
140
+ for text in text_prompt:
141
+ texts += text.content + "\n"
142
+
143
+ self._add_history_after_action(action)
144
+ self.logger.info(f'The GPT response is: {response}.')
145
+ self.logger.info(f'The optimal action is: {action}.')
146
+ if env_info.get('history'):
147
+ self.logger.info(f'History: {history_to_str(env_info["history"])}')
148
+
149
+ return action, texts, response, total_tokens, total_cost
deciders/parser.py ADDED
@@ -0,0 +1,53 @@
1
+ from pydantic import BaseModel, Field, validator
2
+
3
+ # Define your desired data structure.
4
+ class TwoAction(BaseModel):
5
+ action: int = Field(description="the choosed action to perform")
6
+
7
+ # You can add custom validation logic easily with Pydantic.
8
+ @validator('action')
9
+ def action_is_valid(cls, field):
10
+ if field not in [1, 2]:
11
+ raise ValueError("Action is not valid ([1, 2])!")
12
+ return field
13
+
14
+ class ThreeAction(BaseModel):
15
+ action: int = Field(description="the choosed action to perform")
16
+
17
+ # You can add custom validation logic easily with Pydantic.
18
+ @validator('action')
19
+ def action_is_valid(cls, field):
20
+ if field not in [1, 2, 3]:
21
+ raise ValueError("Action is not valid ([1, 2, 3])!")
22
+ return field
23
+
24
+ class FourAction(BaseModel):
25
+ action: int = Field(description="the choosed action to perform")
26
+
27
+ # You can add custom validation logic easily with Pydantic.
28
+ @validator('action')
29
+ def action_is_valid(cls, field):
30
+ if field not in [1, 2, 3, 4]:
31
+ raise ValueError("Action is not valid ([1, 2, 3, 4])!")
32
+ return field
33
+
34
+ class SixAction(BaseModel):
35
+ action: int = Field(description="the choosed action to perform")
36
+
37
+ # You can add custom validation logic easily with Pydantic.
38
+ @validator('action')
39
+ def action_is_valid(cls, field):
40
+ if field not in [1, 2, 3, 4, 5, 6]:
41
+ raise ValueError("Action is not valid ([1, 2, 3, 4, 5, 6])!")
42
+ return field
43
+
44
+ class ContinuousAction(BaseModel):
45
+ action: float = Field(description="the choosed action to perform")
46
+ # You can add custom validation logic easily with Pydantic.
47
+ @validator('action')
48
+ def action_is_valid(cls, field):
49
+ if not (field >= -1 and field <= 1):
50
+ raise ValueError("Action is not valid ([-1,1])!")
51
+ return field
52
+
53
+ PARSERS = {1:ContinuousAction, 2: TwoAction, 3: ThreeAction, 4: FourAction, 6: SixAction}
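The deciders call `self.parser.get_format_instructions()` and `self.parser.parse(response).action`, so these Pydantic models are presumably wrapped in LangChain's `PydanticOutputParser`; the wiring below is an assumption for illustration, not code from this commit:

```python
from langchain.output_parsers import PydanticOutputParser
from deciders.parser import PARSERS

action_space_n = 2                                  # e.g. a two-action environment such as CartPole
parser = PydanticOutputParser(pydantic_object=PARSERS[action_space_n])

print(parser.get_format_instructions())            # schema text appended to the prompt
parsed = parser.parse('{"action": 1}')              # validated by TwoAction.action_is_valid
print(parsed.action)                                # -> 1
```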
deciders/reflexion.py ADDED
@@ -0,0 +1,179 @@
1
+ import openai
2
+ from .misc import history_to_str
3
+ from langchain.chat_models import AzureChatOpenAI
4
+ from langchain.prompts.chat import (
5
+ PromptTemplate,
6
+ ChatPromptTemplate,
7
+ SystemMessagePromptTemplate,
8
+ HumanMessagePromptTemplate,
9
+ )
10
+ from langchain.prompts.few_shot import FewShotPromptTemplate
11
+ from langchain import LLMChain
12
+ from loguru import logger
13
+ from langchain.callbacks import FileCallbackHandler
14
+ from langchain.callbacks import get_openai_callback
15
+ from .act import NaiveAct
16
+ from memory.env_history import EnvironmentHistory
17
+ import tiktoken
18
+ from .utils import run_chain
19
+
20
+
21
+ class Reflexion(NaiveAct):
22
+ def __init__(self, action_space, args, prompts, distiller, temperature=0.1, max_tokens=None, logger=None):
23
+ super().__init__(action_space, args, prompts, distiller, temperature, max_tokens, logger)
24
+
25
+ def num_tokens_from_string(self,string: str) -> int:
26
+ """Returns the number of tokens in a text string."""
27
+ num_tokens = len(self.encoding.encode(string))
28
+ return num_tokens
29
+
30
+ def update_mem(self,):
31
+ traj = self.game_description
32
+ traj += self.goal_description
33
+ one_history_token = self.num_tokens_from_string(self.env_history.get_one_history())
34
+ history_num = 4000 // one_history_token
35
+ traj += self.env_history.get_histories_with_last(history_num)
36
+ self._update_mem(traj)
37
+
38
+ def _update_mem(self, traj):
39
+ my_reflection = self.distiller.generate(traj, self.memory)
40
+ self.memory.append(my_reflection)
41
+ self.env_history.reset()
42
+
43
+ def act(
44
+ self,
45
+ state_description,
46
+ action_description,
47
+ env_info,
48
+ game_description,
49
+ goal_description,
50
+ logfile=None,
51
+ ):
52
+ self.action_description = action_description
53
+ self.game_description = game_description
54
+ self.goal_description = goal_description
55
+ self.env_history.add("observation", state_description)
56
+ chat = AzureChatOpenAI(
57
+ openai_api_type=openai.api_type,
58
+ openai_api_version=openai.api_version,
59
+ openai_api_base=openai.api_base,
60
+ openai_api_key=openai.api_key,
61
+ deployment_name=self.args.gpt_version,
62
+ temperature=self.temperature,
63
+ max_tokens=self.max_tokens,
64
+ )
65
+ suffix_flag = False
66
+ reply_format_description = \
67
+ "Your response should choose an optimal action from a valid action list and terminate with the following format: "
68
+
69
+ # System Message
70
+ human_template = "Now, you are completing a challenging task. You must carefully understand the Reflexion method you will use and apply it to the following task.\n"
71
+
72
+ # task-irrelevant SystemMessage
73
+ if self.irr_few_shot_examples:
74
+ human_template += 'In the following example, I shall present a set of question and answer about the Reflexion method. Please adhere to the format and reasoning of the provided response when addressing the subsequent task.\n'
75
+ for i, examples in enumerate(self.irr_few_shot_examples):
76
+ human_template += f"\nExample {i+1}:\n"
77
+ human_template += "Question: \n" + examples['question'] + "\nAnswer: \n" + examples['answer']
78
+
79
+ # task-irrelevant few shot if have
80
+ if self.irr_few_shot_examples:
81
+ human_template += "\nMoving forward, I will describe the task, the goal, and the actions you may execute. Please pay close attention to comprehend the information presented below.\n"
82
+
83
+ if self.fewshot_example:
84
+ human_template += "I will describe the task, the goal, and the actions you may execute. Please pay close attention to comprehend the information presented below."
85
+ # print(fewshot_example_prompt.format(**fewshot_examples[0]))
86
+ human_template += '\nTask Description: {game_description} \n'
87
+ human_template += 'Goal Description: {goal_description}\n'
88
+ human_template += 'Actions Description: {action_description}\n'
89
+
90
+ if self.fewshot_example:
91
+ human_template += "Here, I will provide you with some guidance to help you better understand the rules of the task. Next are some examples: "
92
+ for i, examples in enumerate(self.fewshot_example):
93
+ human_template += f"\nExample {i+1}:\n"
94
+ human_template += "Question: \n" + examples['question'] + "\nAnswer: \n" + examples['answer']
95
+
96
+ if self.prompt_level in [2, 3, 4]:
97
+ if self.memory:
98
+ human_template += '\nSubsequently, I will offer pertinent guidance or information about the task. Please utilize this instruction to accomplish the given task effectively.\n'
99
+ suffix_flag = True
100
+ if self.prompt_level == 2:
101
+ human_template += 'I have collected a few trajectories from a random policy, and the summaries are listed below.'
102
+ elif self.prompt_level == 3:
103
+ human_template += 'I have collected a few trajectories before, and the summaries are listed below.'
104
+ elif self.prompt_level == 4:
105
+ human_template += 'I have collected a few trajectories from an expert policy, and the summaries are listed below.'
106
+ human_template += self._read_mem() + "\n"
107
+
108
+ if self.use_short_mem:
109
+ if len(self.env_history) > 1:
110
+ if not suffix_flag:
111
+ human_template += '\nSubsequently, I will offer pertinent guidance or information about the task. Please utilize this instruction to accomplish the given task effectively.'
112
+ human_template += f"\nBelow are the latest {self.mem_num} historical data entries:\n"
113
+ human_template += f"{self.env_history.get_histories(self.mem_num)}"
114
+ human_template += '\nNext is the observation that the agent gets:\nCurrent {state_description}\n'
115
+ human_template += 'Please select an action based on the current game state and the information you get. You must select the appropriate action from the given action descriptions and cannot refrain from taking action or performing any prohibited actions. Here is the action description below:\n{action_description}\n'
116
+ human_template += 'Also, please keep in mind not to answer with any redundant and irrelevant content.\n'
117
+ human_template += "Finally, you also need to normalize your output according to the reply format description.\n"
118
+ human_template += 'Reply format description: {reply_format_description}{format_instructions}\n'
119
+
120
+ human_message_prompt = PromptTemplate(
121
+ template=human_template,
122
+ input_variables=[
123
+ 'state_description', 'goal_description', 'game_description',
124
+ 'action_description', 'reply_format_description'],
125
+ partial_variables={'format_instructions': self.parser.get_format_instructions()}
126
+ )
127
+
128
+ human_message_prompt = HumanMessagePromptTemplate(prompt=human_message_prompt)
129
+
130
+ chat_prompt = ChatPromptTemplate.from_messages([human_message_prompt])
131
+ if not self.logger:
132
+ # logger.remove()
133
+ if self.first_call:
134
+ self.logger = logger.add(logfile, colorize=True, enqueue=True, filter=lambda x: '[Reflexion Memory]' not in x['message'])
135
+ self.first_call = False
136
+ handler = FileCallbackHandler(logfile)
137
+ total_tokens, total_cost = 0, 0
138
+ max_think_times = 1
139
+ # TODO: ADD REACT Support
140
+ # print(str(self.env_history))
141
+
142
+ for i_think in range(max_think_times):
143
+ chain = LLMChain(llm=chat, prompt=chat_prompt, callbacks=[handler], verbose=False)
144
+ with get_openai_callback() as cb:
145
+ response = run_chain(
146
+ chain,
147
+ state_description=self.env_history.get_last_history(),
148
+ game_description=game_description,
149
+ goal_description=goal_description,
150
+ action_description=action_description,
151
+ format_instructions=self.parser.get_format_instructions(),
152
+ reply_format_description=reply_format_description,
153
+ max_token = 3000
154
+ )
155
+
156
+ total_tokens += cb.total_tokens
157
+ total_cost += cb.total_cost
158
+ action = self.parser.parse(response).action
159
+ text_prompt = chat_prompt.format_messages(
160
+ state_description=self.env_history.get_last_history(),
161
+ game_description=game_description,
162
+ goal_description=goal_description,
163
+ action_description=action_description,
164
+ format_instructions=self.parser.get_format_instructions(),
165
+ reply_format_description=reply_format_description,
166
+ )
167
+ texts = ""
168
+ for text in text_prompt:
169
+ texts += text.content + "\n"
170
+
171
+ self._add_history_after_action(action)
172
+ self.logger.info(f'The GPT response is: {response}.')
173
+ self.logger.info(f'The optimal action is: {action}.')
174
+ if self.memory:
175
+ self.logger.info(f'The memory is: {self.memory[-1]}.')
176
+ if env_info.get('history'):
177
+ self.logger.info(f'History: {history_to_str(env_info["history"])}')
178
+
179
+ return action, texts, response, total_tokens, total_cost
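A small sketch of the token budgeting used in `update_mem` above: one transition is measured with `tiktoken`, and the number of recent transitions kept is capped so the reflection prompt stays within roughly 4000 tokens. The encoding name and the example string are assumptions for illustration, not values taken from this commit:

```python
import tiktoken

encoding = tiktoken.get_encoding("cl100k_base")      # the encoding NaiveAct presumably sets as self.encoding
one_history = "Current game state: ...\nAction: 1\nReward: 1.0\n"
one_history_token = len(encoding.encode(one_history))
history_num = 4000 // max(one_history_token, 1)      # how many recent transitions fit the budget
print(one_history_token, history_num)
```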
deciders/self_consistency.py ADDED
@@ -0,0 +1,170 @@
1
+ import openai
2
+ from .misc import history_to_str
3
+ from langchain.chat_models import AzureChatOpenAI
4
+ from langchain.prompts.chat import (
5
+ PromptTemplate,
6
+ ChatPromptTemplate,
7
+ SystemMessagePromptTemplate,
8
+ HumanMessagePromptTemplate,
9
+ )
10
+ from langchain.prompts.few_shot import FewShotPromptTemplate
11
+ from langchain import LLMChain
12
+ from loguru import logger
13
+ from langchain.callbacks import FileCallbackHandler
14
+ from langchain.callbacks import get_openai_callback
15
+ from .act import NaiveAct
16
+ from .utils import run_chain
17
+
18
+
19
+ class SelfConsistency(NaiveAct):
20
+ def __init__(self, action_space, args, prompts, distiller, temperature=0.1, max_tokens=None, logger=None):
21
+ temperature = 0.7
22
+ super().__init__(action_space, args, prompts, distiller, temperature, max_tokens, logger)
23
+ self.temperature = temperature
24
+
25
+ def act(
26
+ self,
27
+ state_description,
28
+ action_description,
29
+ env_info,
30
+ game_description,
31
+ goal_description,
32
+ logfile=None,
33
+ ):
34
+ # print(self.temperature)
35
+ self.action_description = action_description
36
+ self._add_history_before_action(game_description, goal_description, state_description)
37
+ chat = AzureChatOpenAI(
38
+ openai_api_type=openai.api_type,
39
+ openai_api_version=openai.api_version,
40
+ openai_api_base=openai.api_base,
41
+ openai_api_key=openai.api_key,
42
+ deployment_name=self.args.gpt_version,
43
+ temperature=self.temperature,
44
+ max_tokens=self.max_tokens
45
+ )
46
+
47
+ suffix_flag = False
48
+ reply_format_description = \
49
+ "Your response should choose an optimal action from a valid action list and terminate with the following format: "
50
+
51
+ # System Message
52
+ human_template = "Now, you are completing a challenging task. You must carefully understand the Self-Consistency method you will use and apply it to the following task.\n"
53
+
54
+ # task-irrelevant SystemMessage
55
+ if self.irr_few_shot_examples:
56
+ human_template += 'In the following example, I shall present a set of question and answer with the Self-Consistency method. Please adhere to the format and reasoning of the provided response when addressing the subsequent task.\n'
57
+ for i, examples in enumerate(self.irr_few_shot_examples):
58
+ human_template += f"\nExample {i+1}:\n"
59
+ human_template += "Question: \n" + examples['question'] + "\nAnswer: \n" + examples['answer']
60
+
61
+ # task-irrelevant few shot if have
62
+ if self.irr_few_shot_examples:
63
+ human_template += "\nMoving forward, I will describe the task, the goal, and the actions you may execute. Please pay close attention to comprehend the information presented below.\n"
64
+
65
+ if self.fewshot_example:
66
+ human_template += "I will describe the task, the goal, and the actions you may execute. Please pay close attention to comprehend the information presented below."
67
+ # print(fewshot_example_prompt.format(**fewshot_examples[0]))
68
+ human_template += '\nTask Description: {game_description} \n'
69
+ human_template += 'Goal Description: {goal_description}\n'
70
+ human_template += 'Actions Description: {action_description}\n'
71
+
72
+ if self.fewshot_example:
73
+ human_template += "Here, I will provide you with some guidance to help you better understand the rules of the task. Next are some examples: "
74
+ for i, examples in enumerate(self.fewshot_example):
75
+ human_template += f"\nExample {i+1}:\n"
76
+ human_template += "Question: \n" + examples['question'] + "\nAnswer: \n" + examples['answer']
77
+
78
+ if self.prompt_level in [2, 3, 4]:
79
+ if self.memory:
80
+ human_template += '\nSubsequently, I will offer pertinent guidance or information about the task. Please utilize this instruction to accomplish the given task effectively.\n'
81
+ suffix_flag = True
82
+ if self.prompt_level == 2:
83
+ human_template += 'I have collected a few trajectories from a random policy, and the summaries are listed below.'
84
+ elif self.prompt_level == 3:
85
+ human_template += 'I have collected a few trajectories before, and the summaries are listed below.'
86
+ elif self.prompt_level == 4:
87
+ human_template += 'I have collected a few trajectories from an expert policy, and the summaries are listed below.'
88
+ human_template += self._read_mem() + "\n"
89
+
90
+ if self.use_short_mem:
91
+ if len(self.env_history) > 1:
92
+ if not suffix_flag:
93
+ human_template += '\nSubsequently, I will offer pertinent guidance or information about the task. Please utilize this instruction to accomplish the given task effectively.'
94
+ human_template += f"\nBelow are the latest {self.args.short_mem_num} historical data entries:\n"
95
+ human_template += f"{self.env_history.get_histories(self.mem_num)}"
96
+ human_template += '\nNext is the observation that the agent gets:\nCurrent {state_description}\n'
97
+ human_template += 'Please select an action based on the current game state and the information you get. You must select the appropriate action from the given action descriptions and cannot refrain from taking action or performing any prohibited actions. Here is the action description below:\n{action_description}\n'
98
+ human_template += 'Please note that you need to carefully lay out your thought process on the question, not just give an answer. You need to write the corresponding logic of your thinking following the example above. Also, please keep in mind not to answer with any redundant and irrelevant content.\n'
99
+ human_template += "Finally, you also need to normalize your output according to the reply format description.\n"
100
+ human_template += 'Reply format description: {reply_format_description}{format_instructions}\n'
101
+
102
+ human_message_prompt = PromptTemplate(
103
+ template=human_template,
104
+ input_variables=[
105
+ 'state_description', 'goal_description', 'game_description',
106
+ 'action_description', 'reply_format_description'],
107
+ partial_variables={'format_instructions': self.parser.get_format_instructions()}
108
+ )
109
+
110
+ human_message_prompt = HumanMessagePromptTemplate(prompt=human_message_prompt)
111
+
112
+ chat_prompt = ChatPromptTemplate.from_messages([human_message_prompt])
113
+
114
+ if not self.logger:
115
+ logger.remove()
116
+ self.logger = logger.add(logfile, colorize=True, enqueue=True)
117
+ handler = FileCallbackHandler(logfile)
118
+
119
+ chain = LLMChain(llm=chat, prompt=chat_prompt, callbacks=[handler], verbose=False)
120
+
121
+ text_prompt = chat_prompt.format_messages(
122
+ game_description=game_description,
123
+ state_description=state_description,
124
+ goal_description=goal_description,
125
+ action_description=action_description,
126
+ reply_format_description=reply_format_description
127
+ )
128
+ texts = ""
129
+ for text in text_prompt:
130
+ texts += text.content + "\n"
131
+
132
+ actions = []
133
+ response_dict = {}
134
+ error_flag = True
135
+ for i in range(5):
136
+ try:
137
+ with get_openai_callback() as cb:
138
+ response = run_chain(
139
+ chain,
140
+ game_description=game_description,
141
+ state_description=state_description,
142
+ goal_description=goal_description,
143
+ action_description=action_description,
144
+ reply_format_description=reply_format_description
145
+ )
146
+ total_tokens = cb.total_tokens
147
+ total_cost = cb.total_cost
148
+ action = self.parser.parse(response).action
149
+ actions.append(action)
150
+ response_dict[action] = response
151
+
152
+ self.logger.info(f'The GPT response is: {response}.')
153
+ self.logger.info(f'The optimal action is: {action}.\n')
154
+ except Exception:
155
+ continue
156
+
157
+ action = max(set(actions), key=actions.count)
158
+ # print(actions)
159
+ # print(action)
160
+ if actions:
161
+ self._add_history_after_action(action)
162
+ self.logger.info(f'The action list is: {actions}.')
163
+ self.logger.info(f'The GPT response is: {response_dict[action]}.')
164
+ self.logger.info(f'The optimal action is: {action}.')
165
+ if env_info.get('history'):
166
+ self.logger.info(f'History: {history_to_str(env_info["history"])}')
167
+ else:
168
+ raise Exception("No valid Actions!")
169
+
170
+ return action, texts, response, total_tokens, total_cost
deciders/selfask.py ADDED
@@ -0,0 +1,150 @@
1
+ import openai
2
+ from .misc import history_to_str
3
+ from langchain.chat_models import AzureChatOpenAI
4
+ from langchain.prompts.chat import (
5
+ PromptTemplate,
6
+ ChatPromptTemplate,
7
+ SystemMessagePromptTemplate,
8
+ HumanMessagePromptTemplate,
9
+ )
10
+ from langchain.prompts.few_shot import FewShotPromptTemplate
11
+ from langchain import LLMChain
12
+ from loguru import logger
13
+ from langchain.callbacks import FileCallbackHandler
14
+ from langchain.callbacks import get_openai_callback
15
+ from .act import NaiveAct
16
+ from .utils import run_chain
17
+
18
+
19
+ class SelfAskAct(NaiveAct):
20
+ def __init__(self, action_space, args, prompts, distiller, temperature=0.1, max_tokens=None, logger=None):
21
+ super().__init__(action_space, args, prompts, distiller, temperature, max_tokens,logger)
22
+
23
+ def act(
24
+ self,
25
+ state_description,
26
+ action_description,
27
+ env_info,
28
+ game_description,
29
+ goal_description,
30
+ logfile=None,
31
+ ):
32
+ self.action_description = action_description
33
+ self._add_history_before_action(game_description, goal_description, state_description)
34
+ chat = AzureChatOpenAI(
35
+ openai_api_type=openai.api_type,
36
+ openai_api_version=openai.api_version,
37
+ openai_api_base=openai.api_base,
38
+ openai_api_key=openai.api_key,
39
+ deployment_name=self.args.gpt_version,
40
+ temperature=self.temperature,
41
+ max_tokens=self.max_tokens
42
+ )
43
+
44
+ suffix_flag = False
45
+ reply_format_description = \
46
+ "Your response should choose an optimal action from a valid action list and terminate with the following format: "
47
+
48
+ # System Message
49
+ human_template = "Now, you are completing a challenging task. You must carefully understand the self-ask method you will use and apply it to the following task.\n"
50
+
51
+ # task-irrelevant SystemMessage
52
+ if self.irr_few_shot_examples:
53
+ human_template += 'In the following example, I shall present a set of question and answer with the self-ask method. Please adhere to the format and reasoning of the provided response when addressing the subsequent task.\n'
54
+ for i, examples in enumerate(self.irr_few_shot_examples):
55
+ human_template += f"\nExample {i+1}:\n"
56
+ human_template += "Question: \n" + examples['question'] + "\nAnswer: \n" + examples['answer']
57
+
58
+ # task-irrelevant few shot if have
59
+ if self.irr_few_shot_examples:
60
+ human_template += "\nMoving forward, I will describe the task, the goal, and the actions you may execute. Please pay close attention to comprehend the information presented below.\n"
61
+
62
+ if self.fewshot_example:
63
+ human_template += "I will describe the task, the goal, and the actions you may execute. Please pay close attention to comprehend the information presented below."
64
+ # print(fewshot_example_prompt.format(**fewshot_examples[0]))
65
+ human_template += '\nTask Description: {game_description} \n'
66
+ human_template += 'Goal Description: {goal_description}\n'
67
+ human_template += 'Actions Description: {action_description}\n'
68
+
69
+ if self.fewshot_example:
70
+ human_template += "Here, I will provide you with some guidance to help you better understand the rules of the task. Next are some examples: "
71
+ for i, examples in enumerate(self.fewshot_example):
72
+ human_template += f"\nExample {i+1}:\n"
73
+ human_template += "Question: \n" + examples['question'] + "\nAnswer: \n" + examples['answer']
74
+
75
+ if self.prompt_level in [2, 3, 4]:
76
+ if self.memory:
77
+ human_template += '\nSubsequently, I will offer pertinent guidance or information about the task. Please utilize this instruction to accomplish the given task effectively.\n'
78
+ suffix_flag = True
79
+ if self.prompt_level == 2:
80
+ human_template += 'I have collected a few trajectories from a random policy, and the summaries are listed below.'
81
+ elif self.prompt_level == 3:
82
+ human_template += 'I have collected a few trajectories before, and the summaries are listed below.'
83
+ elif self.prompt_level == 4:
84
+ human_template += 'I have collected a few trajectories from an expert policy, and the summaries are listed below.'
85
+ human_template += self._read_mem() + "\n"
86
+
87
+ if self.use_short_mem:
88
+ if len(self.env_history) > 1:
89
+ if not suffix_flag:
90
+ human_template += '\nSubsequently, I will offer pertinent guidance or information about the task. Please utilize this instruction to accomplish the given task effectively.'
91
+ human_template += f"\nBelow are the latest {self.args.short_mem_num} historical data entries:\n"
92
+ human_template += f"{self.env_history.get_histories(self.mem_num)}"
93
+ human_template += '\nNext is the observation that the agent gets:\nCurrent {state_description}\n'
94
+ human_template += 'Please select an action based on the current game state and the information you get. You must select the appropriate action from the given action descriptions and cannot refrain from taking action or performing any prohibited actions. Here is the action description below:\n{action_description}\n'
95
+ human_template += 'You must utilize a multi-turn dialogue approach, just as the format illustrated in the example above(like "Follow up" and "Intermediate answer"). And you need to write down the thought process during the self-ask process. Also, please keep in mind not to answer with any redundant and irrelevant content.\n'
96
+ human_template += "Finally, you also need to normalize your output according to the reply format description.\n"
97
+ human_template += 'Reply format description: {reply_format_description}{format_instructions}\n'
98
+
99
+ human_message_prompt = PromptTemplate(
100
+ template=human_template,
101
+ input_variables=[
102
+ 'state_description', 'goal_description', 'game_description',
103
+ 'action_description', 'reply_format_description'],
104
+ partial_variables={'format_instructions': self.parser.get_format_instructions()}
105
+ )
106
+
107
+ human_message_prompt = HumanMessagePromptTemplate(prompt=human_message_prompt)
108
+
109
+ chat_prompt = ChatPromptTemplate.from_messages([human_message_prompt])
110
+
111
+ if not self.logger:
112
+ logger.remove()
113
+ self.logger = logger.add(logfile, colorize=True, enqueue=True)
114
+ handler = FileCallbackHandler(logfile)
115
+
116
+ chain = LLMChain(
117
+ llm=chat, prompt=chat_prompt, callbacks=[handler], verbose=False)
118
+
119
+ with get_openai_callback() as cb:
120
+ response = run_chain(
121
+ chain,
122
+ game_description=game_description,
123
+ state_description=state_description,
124
+ goal_description=goal_description,
125
+ action_description=action_description,
126
+ reply_format_description=reply_format_description
127
+ )
128
+ total_tokens = cb.total_tokens
129
+ total_cost = cb.total_cost
130
+ action = self.parser.parse(response).action
131
+
132
+ text_prompt = chat_prompt.format_messages(
133
+ game_description=game_description,
134
+ state_description=state_description,
135
+ goal_description=goal_description,
136
+ action_description=action_description,
137
+ reply_format_description=reply_format_description
138
+ )
139
+ texts = ""
140
+ for text in text_prompt:
141
+ texts += text.content + "\n"
142
+
143
+ self._add_history_after_action(action)
144
+
145
+ self.logger.info(f'The GPT response is: {response}.')
146
+ self.logger.info(f'The optimal action is: {action}.')
147
+ if env_info.get('history'):
148
+ self.logger.info(f'History: {history_to_str(env_info["history"])}')
149
+
150
+ return action, texts, response, total_tokens, total_cost
deciders/spp.py ADDED
@@ -0,0 +1,142 @@
1
+ import openai
2
+ from .misc import history_to_str
3
+ from langchain.chat_models import AzureChatOpenAI
4
+ from langchain.prompts.chat import (
5
+ PromptTemplate,
6
+ ChatPromptTemplate,
7
+ SystemMessagePromptTemplate,
8
+ HumanMessagePromptTemplate,
9
+ )
10
+ from langchain.prompts.few_shot import FewShotPromptTemplate
11
+ from langchain import LLMChain
12
+ from loguru import logger
13
+ from langchain.callbacks import FileCallbackHandler
14
+ from langchain.callbacks import get_openai_callback
15
+ from .act import NaiveAct
16
+ from .utils import run_chain
17
+
18
+ class SPP(NaiveAct):
19
+ def __init__(self, action_space, args, prompts, distiller, temperature=0.1, max_tokens=None, logger=None):
20
+ super().__init__(action_space, args, prompts, distiller, temperature, max_tokens, logger)
21
+
22
+ def act(
23
+ self,
24
+ state_description,
25
+ action_description,
26
+ env_info,
27
+ game_description,
28
+ goal_description,
29
+ logfile=None,
30
+ ):
31
+ self.action_description = action_description
32
+ self._add_history_before_action(game_description, goal_description, state_description)
33
+ chat = AzureChatOpenAI(
34
+ openai_api_type=openai.api_type,
35
+ openai_api_version=openai.api_version,
36
+ openai_api_base=openai.api_base,
37
+ openai_api_key=openai.api_key,
38
+ deployment_name=self.args.gpt_version,
39
+ temperature=self.temperature,
40
+ max_tokens=self.max_tokens
41
+ )
42
+
43
+ self.fewshot_example = self.irr_few_shot_examples if not self.fewshot_example else self.fewshot_example
44
+ self.irr_few_shot_examples = self.irr_few_shot_examples if not self.fewshot_example else self.fewshot_example
45
+ suffix_flag = False
46
+ reply_format_description = \
47
+ "Your response should choose an optimal action from a valid action list and terminate with the following format: "
48
+
49
+ # System Message
50
+ human_template = "When faced with a task, begin by identifying the participants who will contribute to solving the task. Then, initiate a multi-round collaboration process until a final solution is reached. The participants will give critical comments and detailed suggestions whenever necessary.\n"
51
+ human_template += "Now, you are completing a challenging task. You must carefully understand the Solo-Performance-Prompting method you will use and apply it to the following task.\n"
52
+
53
+ # task-irrelevant SystemMessage
54
+ if self.irr_few_shot_examples:
55
+ human_template += 'In the following example, I shall present a set of question and answer with the Solo-Performance-Prompting method. Please adhere to the format and reasoning of the provided response when addressing the subsequent task.\n'
56
+ for i, examples in enumerate(self.irr_few_shot_examples):
57
+ human_template += f"\nExample {i+1}:\n"
58
+ human_template += "Question: \n" + examples['question'] + "\nAnswer: \n" + examples['answer']
59
+
60
+ # task-irrelevant few shot if have
61
+ if self.irr_few_shot_examples:
62
+ human_template += "\nMoving forward, I will describe the task, the goal, and the actions you may execute. Please pay close attention to comprehend the information presented below.\n"
63
+
64
+ human_template += '\nTask Description: {game_description} \n'
65
+ human_template += 'Goal Description: {goal_description}\n'
66
+ human_template += 'Actions Description: {action_description}\n'
67
+
68
+ if self.prompt_level in [2, 3, 4]:
69
+ if self.memory:
70
+ human_template += '\nSubsequently, I will offer pertinent guidance or information about the task. Please utilize this instruction to accomplish the given task effectively.\n'
71
+ suffix_flag = True
72
+ if self.prompt_level == 2:
73
+ human_template += 'I have collected a few trajectories from a random policy, and the summaries are listed below.'
74
+ elif self.prompt_level == 3:
75
+ human_template += 'I have collected a few trajectories before, and the summaries are listed below.'
76
+ elif self.prompt_level == 4:
77
+ human_template += 'I have collected a few trajectories from an expert policy, and the summaries are listed below.'
78
+ human_template += self._read_mem() + "\n"
79
+
80
+ if self.use_short_mem:
81
+ if len(self.env_history) > 1:
82
+ if not suffix_flag:
83
+ human_template += '\nSubsequently, I will offer pertinent guidance or information about the task. Please utilize this instruction to accomplish the given task effectively.'
84
+ human_template += f"\nBelow are the latest {self.args.short_mem_num} historical data entries:\n"
85
+ human_template += f"{self.env_history.get_histories(self.mem_num)}"
86
+ human_template += '\nNext is the observation that the agent gets:\nCurrent {state_description}\n'
87
+ human_template += 'Please select an action based on the current game state and the information you get. You must select the appropriate action from the given action descriptions and cannot refrain from taking action or performing any prohibited actions. Here is the action description below:\n{action_description}\n'
88
+ human_template += 'Please note that you need to carefully lay out the participants who will contribute to solving the task and initiate a multi-round collaboration process until a final solution is reached. Now, identify the participants and collaboratively solve the following task step by step. Also, please keep in mind not to answer with any redundant and irrelevant content.\n'
89
+ human_template += "Finally, you also need to normalize your output according to the reply format description.\n"
90
+ human_template += 'Reply format description: {reply_format_description}{format_instructions}\n'
91
+
92
+ human_message_prompt = PromptTemplate(
93
+ template=human_template,
94
+ input_variables=[
95
+ 'state_description', 'goal_description', 'game_description',
96
+ 'action_description', 'reply_format_description'],
97
+ partial_variables={'format_instructions': self.parser.get_format_instructions()}
98
+ )
99
+
100
+ human_message_prompt = HumanMessagePromptTemplate(prompt=human_message_prompt)
101
+
102
+ chat_prompt = ChatPromptTemplate.from_messages([human_message_prompt])
103
+
104
+ if not self.logger:
105
+ logger.remove()
106
+ self.logger = logger.add(logfile, colorize=True, enqueue=True)
107
+ handler = FileCallbackHandler(logfile)
108
+
109
+ chain = LLMChain(llm=chat, prompt=chat_prompt, callbacks=[handler], verbose=False)
110
+
111
+ with get_openai_callback() as cb:
112
+ response = run_chain(
113
+ chain,
114
+ game_description=game_description,
115
+ state_description=state_description,
116
+ goal_description=goal_description,
117
+ action_description=action_description,
118
+ reply_format_description=reply_format_description
119
+ )
120
+ total_tokens = cb.total_tokens
121
+ total_cost = cb.total_cost
122
+ action = self.parser.parse(response).action
123
+
124
+ text_prompt = chat_prompt.format_messages(
125
+ game_description=game_description,
126
+ state_description=state_description,
127
+ goal_description=goal_description,
128
+ action_description=action_description,
129
+ reply_format_description=reply_format_description
130
+ )
131
+ texts = ""
132
+ for text in text_prompt:
133
+ texts += text.content + "\n"
134
+
135
+ self._add_history_after_action(action)
136
+
137
+ self.logger.info(f'The GPT response is: {response}.')
138
+ self.logger.info(f'The optimal action is: {action}.')
139
+ if env_info.get('history'):
140
+ self.logger.info(f'History: {history_to_str(env_info["history"])}')
141
+
142
+ return action, texts, response, total_tokens, total_cost
deciders/utils.py ADDED
@@ -0,0 +1,65 @@
1
+ import os
2
+ import sys
3
+ import openai
4
+ from openai import OpenAI
5
+ from tenacity import (
6
+ retry,
7
+ stop_after_attempt, # type: ignore
8
+ wait_random_exponential, # type: ignore
9
+ )
10
+
11
+ from typing import Optional, List
12
+ if sys.version_info >= (3, 8):
13
+ from typing import Literal
14
+ else:
15
+ from typing_extensions import Literal
16
+
17
+
18
+ Model = Literal["gpt-4", "gpt-35-turbo", "text-davinci-003"]
19
+
20
+ from .gpt import gpt
21
+ gpt().__init__()
22
+
23
+ import timeout_decorator
24
+ @timeout_decorator.timeout(30)
25
+ def run_chain(chain, *args, **kwargs):
26
+ return chain.run(*args, **kwargs)
27
+
28
+ @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
29
+ def get_completion(prompt: str, engine: str = "gpt-35-turbo", temperature: float = 0.0, max_tokens: int = 256, stop_strs: Optional[List[str]] = None) -> str:
30
+
31
+ client = OpenAI(api_key=openai.api_key)
32
+     # This call builds a plain-text completion (a single prompt string, read back
+     # via response.choices[0].text), so it must use the completions endpoint
+     # rather than chat.completions.
+     response = client.completions.create(
+         model=engine,
+         prompt=prompt,
+         temperature=temperature,
+         max_tokens=max_tokens,
+         top_p=1,
+         frequency_penalty=0.0,
+         presence_penalty=0.0,
+         stop=stop_strs,
+         # request_timeout = 1
+     )
+     return response.choices[0].text
44
+
45
+ # @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
46
+ def get_chat(prompt: str, model: str = "gpt-35-turbo", engine: str = "gpt-35-turbo", temperature: float = 0.0, max_tokens: int = 256, stop_strs: Optional[List[str]] = None, is_batched: bool = False) -> str:
47
+ assert model != "text-davinci-003"
48
+ messages = [
49
+ {
50
+ "role": "user",
51
+ "content": prompt
52
+ }
53
+ ]
54
+ # import pdb;pdb.set_trace()
55
+ client = OpenAI(api_key=openai.api_key)
56
+
57
+ response = client.chat.completions.create(
58
+ model=model,
59
+ messages=messages,
60
+ max_tokens=max_tokens,
61
+ stop=stop_strs,
62
+ temperature=temperature,
63
+ # request_timeout = 1
64
+ )
65
+ return response.choices[0].message.content
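`run_chain` above raises `timeout_decorator.TimeoutError` after 30 seconds, so callers need a guard around it. The retry-then-fallback policy sketched below is an assumption for illustration, not behavior defined in this commit:

```python
import random
import timeout_decorator
from deciders.utils import run_chain

def safe_run_chain(chain, action_space_n, n_retries=2, **kwargs):
    """Run the chain, retrying on timeout and falling back to a random action."""
    for _ in range(n_retries):
        try:
            return run_chain(chain, **kwargs)
        except timeout_decorator.TimeoutError:
            continue
    # Degraded fallback: a reply the Pydantic action parsers can still parse.
    return f'{{"action": {random.randint(1, action_space_n)}}}'
```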
distillers/__init__.py ADDED
@@ -0,0 +1,10 @@
1
+ from .raw_prompt_generator import RawPromptGenerator
2
+ from .self_reflection import RefletionGenerator
3
+ from .traj_prompt_summarizer import TrajPromptSummarizer
4
+ from .guider import Guidance_Generator
5
+
6
+ REGISTRY = {}
7
+ REGISTRY['raw_distiller'] = RawPromptGenerator
8
+ REGISTRY['reflect_distiller'] = RefletionGenerator
9
+ REGISTRY['traj_distiller'] = TrajPromptSummarizer
10
+ REGISTRY['guide_generator'] = Guidance_Generator
distillers/guidance_summary_few_shot_examples.txt ADDED
@@ -0,0 +1,85 @@
1
+ You are in the middle of a room. Looking quickly around you, you see a cabinet 10, a cabinet 9, a cabinet 8, a cabinet 7, a cabinet 6, a cabinet 5, a cabinet 4, a cabinet 3, a cabinet 2, a cabinet 1, a coffeemachine 1, a countertop 2, a countertop 1, a diningtable 2, a diningtable 1, a drawer 3, a drawer 2, a drawer 1, a fridge 1, a garbagecan 1, a microwave 1, a sinkbasin 1, a stoveburner 4, a stoveburner 3, a stoveburner 2, a stoveburner 1, and a toaster 1.
2
+ Your task is to: heat some mug and put it in coffeemachine.
3
+ Action: think: To solve the task, I need to find and take a mug, then heat it with stoveburner, then put it in coffeemachine.
4
+ Ok.
5
+ Action: look
6
+ You are in the middle of a room. Looking quickly around you, you see nothing.
7
+ Action: think: First I need to find a mug. A mug is more likely to appear in countertop (1-2), diningtable (1-2), drawer (1-3), fridge (1), sinkbasin (1), stoveburner (1-4), toaster (1). I can check one by one, starting with countertop 1.
8
+ Ok.
9
+ Action: look
10
+ You are in the middle of a room. Looking quickly around you, you see nothing.
11
+ Action: go to countertop 1
12
+ On the countertop 1, you see a butterknife 1, a cup 1, a fork 2, and a mug 1.
13
+ Action: think: Now I find a mug (1). Next, I need to take it.
14
+ Ok.
15
+ Action: take mug 1 from countertop 1
16
+ You pick up the mug 1 from the countertop 1.
17
+ Action: think: Now I take a mug (1). Next, I need to heat it with stoveburner (1-4).
18
+ Ok.
19
+ Action: go to stoveburner 1
20
+ On the stoveburner 1, you see a pan 2.
21
+ Action: examine stoveburner 1
22
+ On the stoveburner 1, you see a pan 2.
23
+ Action: examine stoveburner 1
24
+ On the stoveburner 1, you see a pan 2.
25
+ Action: examine stoveburner 1
26
+ On the stoveburner 1, you see a pan 2.
27
+ Action: examine stoveburner 1
28
+ On the stoveburner 1, you see a pan 2.
29
+ STATUS: FAIL
30
+ Summary and Key Findings: Summary of Trajectory and Outcomes:
31
+ The task was to heat a mug and put it in the coffeemachine. The initial plan was to find and take a mug, heat it with a stoveburner, and then put it in the coffeemachine. The user successfully found and took a mug from countertop 1. However, upon reaching stoveburner 1, the user got stuck in a loop of examining the stoveburner without making any progress in heating the mug. The task ultimately ended in failure.
32
+ Key Findings:
33
+ 1. The user was able to identify potential locations for finding a mug and successfully located one on countertop 1.
34
+ 2. The user got stuck in a loop of examining stoveburner 1 without taking any action to heat the mug.
35
+ 3. The task was not completed successfully due to the user's inability to progress beyond examining the stoveburner.
36
+
37
+ You are in the middle of a room. Looking quickly around you, you see a bed 1, a desk 2, a desk 1, a drawer 6, a drawer 5, a drawer 4, a drawer 3, a drawer 2, a drawer 1, a garbagecan 1, a laundryhamper 1, a safe 1, a shelf 6, a shelf 5, a shelf 4, a shelf 3, a shelf 2, and a shelf 1.
38
+ Your task is to: look at bowl under the desklamp.
39
+ Action: think: To solve the task, I need to find and take a bowl, then find and use a desklamp.
40
+ OK.
41
+ Action: think: First I need to find a bowl. A bowl is more likely to appear in drawer (1-6), desk (1-2), shelf (1-6), garbagecan (1), laundryhamper (1). I can check one by one, starting with drawer 1.
42
+ OK.
43
+ Action: go to drawer 1
44
+ The drawer 1 is closed.
45
+ Action: open drawer 1
46
+ You open the drawer 1. The drawer 1 is open. In it, you see a cellphone 2.
47
+ Action: go to drawer 2
48
+ On the drawer 2, you see a keychain 1.
49
+ Action: go to drawer 3
50
+ Nothing happens.
51
+ Action: go to drawer 4
52
+ The drawer 4 is closed.
53
+ Action: open drawer 4
54
+ You open the drawer 4. The drawer 4 is open. In it, you see nothing.
55
+ Action: go to drawer 5
56
+ On the drawer 5, you see nothing.
57
+ Action: go to drawer 6
58
+ The drawer 6 is closed.
59
+ Action: open drawer 6
60
+ You open the drawer 6. The drawer 6 is open. In it, you see a book 2, and a pen 2.
61
+ Action: go to desk 1
62
+ On the desk 1, you see a alarmclock 1, a desklamp 1, a mug 1, and a pencil 1.
63
+ Action: think: Now I find a desklamp (1). Next, I need to use it.
64
+ OK.
65
+ Action: use desklamp 1
66
+ You turn on the desklamp 1.
67
+ Action:
68
+ Nothing happens.
69
+ Action: go to desk 2
70
+ On the desk 2, you see a bowl 2, a bowl 1, a cd 2, a cd 1, and a creditcard 3.
71
+ Action: think: Now I find a bowl (2). Next, I need to take it.
72
+ OK.
73
+ Action: take bowl 2 from desk 2
74
+ You pick up the bowl 2 from the desk 2.
75
+ Action:
76
+ Nothing happens.
77
+ Action:
78
+ Nothing happens.
79
+ STATUS: FAIL
80
+ Summary: Summary of Trajectory and Outcomes:
81
+ The task was to look at a bowl under the desklamp. The initial plan was to find and take a bowl, then find and use a desklamp. The user successfully found a desklamp on desk 1 and turned it on. They also found a bowl on desk 2 and picked it up. However, the user did not complete the task of looking at the bowl under the desklamp and instead got stuck in a state of inaction. The task ultimately ended in failure.
82
+ Key Findings:
83
+ 1. The user was able to identify potential locations for finding a bowl and a desklamp and successfully located them on desk 2 and desk 1, respectively.
84
+ 2. The user turned on the desklamp but did not proceed to look at the bowl under it.
85
+ 3. The task was not completed successfully due to the user's inability to progress beyond taking the bowl and turning on the desklamp.
distillers/guider.py ADDED
@@ -0,0 +1,144 @@
1
+ from deciders.utils import get_completion, get_chat
2
+
3
+ from typing import List, Dict, Any
4
+ from loguru import logger
5
+ import random
6
+ import json
7
+ class Guidance_Generator():
8
+ def __init__(self,logfile="",args=None):
9
+ self.args = args
10
+ with open("./distillers/guidance_summary_few_shot_examples.txt", 'r') as f:
11
+ self.SUMMARY_FEW_SHOT_EXAMPLES = f.read()
12
+ # with open("./distillers/exploration_few_shot_examples.txt", 'r') as f:
13
+ # self.SUGGEST_FEW_SHOT_EXAMPLES = f.read()
14
+ self.insight = ""
15
+ self.suggestion = ""
16
+ if logfile:
17
+ # logger.remove()
18
+ logger.add(logfile, colorize=True, enqueue=True, filter=lambda x: '[Reflexion Memory]' in x['message'])
19
+
20
+ def generate_from_file(self, file_path,max_step_num=200):
21
+ mem = []
22
+ with open(file_path, 'r') as infile:
23
+ data = json.load(infile)
24
+ for traj in data:
25
+ traj_text = traj[0]['game_description']
26
+ traj_text += traj[0]['goal_description']
27
+ for transition in traj[-max_step_num:]:
28
+ traj_text += transition['observation']
29
+ traj_text += f"Action: {transition['action']}"
30
+ summary = self.generate_summary(traj_text, mem)
31
+ mem.append(summary)
32
+ return mem
33
+
34
+ def _generate_summary_query(self, traj, post_memory):
35
+ """
36
+ Generates an exploration guidance query for GPT-3.5 based on given trajectory and memory.
37
+
38
+ Parameters:
39
+ - traj: Trajectory of the new experience.
40
+ - post_memory: List of memory items to summarize.
41
+
42
+ Returns:
43
+ - query: Formulated query string for GPT-3.5.
44
+ """
45
+ segments = []
46
+
47
+ # Summarization memory
48
+ # if post_memory:
49
+ # segments.append('Your summarization memory is as below:')
50
+ # segments.extend([f'Episode #{i}: {m}' for i, m in enumerate(post_memory)])
51
+
52
+ # Trajectory
53
+ segments.append(f"Your new collected trajectory is as below:\n {traj}")
54
+ segments.append(f"The suggestion to guide the trajectory is:\n{self.suggestion}")
55
+ # Questions
56
+ questions = """
57
+ Please answer the following questions directly, without additional explanation:
58
+ 1. Based on the collected trajectory, infer the specific values of game-relevant knowledge proposed in the suggestion with json format.
59
+ 2. Summarize the policy behavior and its performance.
60
+ Provide concise responses.
61
+ """
62
+ segments.append(questions)
63
+
64
+ # Construct the final query
65
+ query = '\n'.join(segments)
66
+ return query
67
+
68
+ # def _generate_summary_query(self, traj, post_memory):
69
+ # """Allows the Agent to generate exploration guidance."""
70
+ # query = ""
71
+ # if len(post_memory) > 0:
72
+ # query += '\Your summarization memory is as below:\n'
73
+ # for i, m in enumerate(post_memory):
74
+ # query += f'Episode #{i}: {m}\n'
75
+ # query += f"""
76
+ # {traj}
77
+ # Above is the trajectory of the new experience.
78
+ # """
79
+ # query += '\n Anwser the following questions.\n 1. What is the performance of this policy and does it improve the performance compared to before? 2. Summarize the main reason that makes the policy improve or reduce the performance; 3. What new information of the task can be inferred compared to the memory?'
80
+ # return query
81
+
82
+ def generate_summary(self, traj, post_memory):
83
+ query = self._generate_summary_query(traj, post_memory)
84
+ summary = get_chat(query,model=self.args.gpt_version, engine=self.args.gpt_version)
85
+ logger.info(f'[Reflexion Memory]The summary prompt is: {query}.')
86
+ logger.info(f'[Reflexion Memory]The summary response is: {summary}.')
87
+ return summary
88
+
89
+ def generate_insight(self, post_memory):
90
+ query: str = f"""As an AI assistant, you are helping a six-year-old player who has never played this game before. The experiences you have are as follows:"""
91
+ if len(post_memory) > 0:
92
+ for i, m in enumerate(post_memory):
93
+ query += f'Episode #{i}: {m}\n'
94
+ query += '\n Identify and summarize the key information that can be exploited to improve performance of the player.'
95
+ insight = get_chat(query,model=self.args.gpt_version, engine=self.args.gpt_version)
96
+ logger.info(f'[Reflexion Memory]The insight prompt is: {query}.')
97
+ logger.info(f'[Reflexion Memory]The insight response is: {insight}.')
98
+ return insight
99
+
100
+ def generate_suggestion(self, game_description, goal_description, action_description, pre_memory, post_memory, insight, max_num_trials):
101
+ query: str = f"""You are an AI assistant that helps a human player win the following game.
102
+ The game is \n"{game_description}" \n, the action space is described as {action_description},\n the player's goal is \n "{goal_description}".\n
103
+ The player can play for {max_num_trials} episodes. The main aim for you is to help the player win the game in the last episode. """
104
+ if len(post_memory) > 0:
105
+ query += f"""You have obtained the following experience: """
106
+ for i, m in enumerate(post_memory):
107
+ query += f'Episode #{i}: {m}\n'
108
+ # if max_num_trials - len(post_memory) == 1:
109
+ # query = (f"\n The main goal is to aid the human player in winning the game in the next episode. "
110
+ # f"This is his {len(post_memory) + 1} try out of {max(max_num_trials, 1)} episodes. "
111
+ # "Your suggestions should be simple, executable with heuristic policy, and suitable for an LLM agent. "
112
+ # "Reply in an item list format. Specifically, focus on:"
113
+ # "\n1. How to achieve optimal performance (exploitation) using the obtained knowledge?"
114
+ # "\nNote: Stress the importance of prioritizing performance without exploration.")
115
+ # suggestion = get_chat(query) + "\n Remember, in this attempt, aim solely for high performance without exploration."
116
+ # else:
117
+ # if max_num_trials-len(post_memory) == 1:
118
+ # query += f"\n The main aim for you is to help the human player win the game in the last episode. The next episode is the last episode. You can give suggestions before each episode. Then what is your suggestion for his next episode? Note that this is the last try and he should not explore which may decrease his performance. The suggestions should be simple to follow, executable with heuristic policy, easy to use for an llm agent,and reply in item list format. The answer should instruct him to exploit all the knowlegde to gain the highest performance (exploitation) in the next episode. "
119
+ # else:
120
+ query += f"\n The main aim for you is to help the human player win the game in the last episode. He has only {max(max_num_trials-len(post_memory), 1)} episodes left to try. You can give suggestions before each episode. Then what is your suggestion for his next episode? Please provide simple, concise answers suitable for a six-year-old child, focusing on the following in item list format: 1. What game-relevant knowledge is critical to determining the optimal policy? Note that the knowledge should be obtainable by interacting with the environment and helpful for the decisions.\n 2. How should the player conduct exploration in the next episode to acquire this information?\n3. How can the player exploit the information obtained to achieve higher performance in subsequent episodes?\n 4. How should exploration and exploitation be balanced to improve performance in the next episode?\n"
121
+ # query += (f"\n The primary goal is to assist the human player in winning the game in the final episode. "
122
+ # f"This is his {len(post_memory) + 1} try out of {max(max_num_trials, 1)} episodes. "
123
+ # "Provide suggestions for the next episode that balance both exploration and exploitation. "
124
+ # "The suggestions should be in item list format, easy to follow, aligned with heuristic policy, and usable for an LLM agent. Address:"
125
+ # "\n1. Which information the player should gather via exploration and the best ways to explore?"
126
+ # "\n2. Strategies to refine the policy for enhanced performance (exploitation)?"
127
+ # "\n3. How should exploration and exploitation be weighted in the next episode?")
128
+
129
+ # TODO: consider the inconsistency between past suggestion and past memory.
130
+ suggestion = get_chat(query,model=self.args.gpt_version, engine=self.args.gpt_version)
131
+ self.suggestion = suggestion
132
+ logger.info(f'[Reflexion Memory]The suggestion prompt is: {query}.')
133
+ logger.info(f'[Reflexion Memory]The suggestion response is: {suggestion}.')
134
+ return suggestion
135
+
136
+ def generate(self, traj, memory, max_len_mem=5):
137
+ if len(memory)> max_len_mem:
138
+ reflection_query = self._generate_summary_query(traj, memory[-max_len_mem:])
139
+ else:
140
+ reflection_query = self._generate_summary_query(traj, memory)
141
+ reflection = get_completion(reflection_query,engine=self.args.gpt_version)
142
+ logger.info(f'[Reflexion Memory]The reflexion prompt is: {reflection_query}.')
143
+ logger.info(f'[Reflexion Memory]The reflexion response is: {reflection}.')
144
+ return reflection
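A rough usage sketch of the guider above (hedged: the instance name `guider`, the `env` description fields, and the `run_episode` helper are illustrative assumptions, not part of this commit; the real wiring lives in the decider/experiment code):

# Hypothetical sketch of the summarize -> insight -> suggestion loop implemented above.
mem, insight, suggestion = [], "", ""
for episode in range(5):
    suggestion = guider.generate_suggestion(env.game_description, env.goal_description,
                                            env.action_description, pre_memory=[],
                                            post_memory=mem, insight=insight,
                                            max_num_trials=5)
    traj_text = run_episode(env, suggestion)              # assumed rollout helper returning the trajectory text
    mem.append(guider.generate_summary(traj_text, mem))   # per-episode summary appended to memory
    insight = guider.generate_insight(mem)                # key information distilled from all summaries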
distillers/raw_prompt_generator.py ADDED
@@ -0,0 +1,16 @@
1
+ import random
2
+ import json
3
+ class RawPromptGenerator():
4
+ def __init__(self,args=None):
5
+ self.args = args
6
+ pass
7
+
8
+ def generate_from_file(self, file_path, choice_num=1):
9
+ with open(file_path, 'r') as infile:
10
+ data = json.load(infile)
11
+ result = []
12
+ for my_data in data[0]:
13
+ result.append({'question': my_data['question'], 'answer': my_data['answer']})
14
+ selected_index = random.sample(range(len(result)), choice_num)
15
+ selected_result = [result[index] for index in selected_index]
16
+ return selected_result
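A minimal usage sketch for RawPromptGenerator (the file path is a placeholder; the expected layout, inferred from the loop above, is a JSON list whose first element is a list of question/answer dicts):

gen = RawPromptGenerator()
examples = gen.generate_from_file("path/to/few_shot_examples.json", choice_num=2)
for ex in examples:
    print(ex["question"], "->", ex["answer"])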
distillers/reflexion_few_shot_examples.txt ADDED
@@ -0,0 +1,75 @@
1
+ You are in the middle of a room. Looking quickly around you, you see a cabinet 10, a cabinet 9, a cabinet 8, a cabinet 7, a cabinet 6, a cabinet 5, a cabinet 4, a cabinet 3, a cabinet 2, a cabinet 1, a coffeemachine 1, a countertop 2, a countertop 1, a diningtable 2, a diningtable 1, a drawer 3, a drawer 2, a drawer 1, a fridge 1, a garbagecan 1, a microwave 1, a sinkbasin 1, a stoveburner 4, a stoveburner 3, a stoveburner 2, a stoveburner 1, and a toaster 1.
2
+ Your task is to: heat some mug and put it in coffeemachine.
3
+ > think: To solve the task, I need to find and take a mug, then heat it with stoveburner, then put it in coffeemachine.
4
+ Ok.
5
+ > look
6
+ You are in the middle of a room. Looking quickly around you, you see nothing.
7
+ > think: First I need to find a mug. A mug is more likely to appear in countertop (1-2), diningtable (1-2), drawer (1-3), fridge (1), sinkbasin (1), stoveburner (1-4), toaster (1). I can check one by one, starting with countertop 1.
8
+ Ok.
9
+ > look
10
+ You are in the middle of a room. Looking quickly around you, you see nothing.
11
+ > go to countertop 1
12
+ On the countertop 1, you see a butterknife 1, a cup 1, a fork 2, and a mug 1.
13
+ > think: Now I find a mug (1). Next, I need to take it.
14
+ Ok.
15
+ > take mug 1 from countertop 1
16
+ You pick up the mug 1 from the countertop 1.
17
+ > think: Now I take a mug (1). Next, I need to heat it with stoveburner (1-4).
18
+ Ok.
19
+ > go to stoveburner 1
20
+ On the stoveburner 1, you see a pan 2.
21
+ > examine stoveburner 1
22
+ On the stoveburner 1, you see a pan 2.
23
+ > examine stoveburner 1
24
+ On the stoveburner 1, you see a pan 2.
25
+ > examine stoveburner 1
26
+ On the stoveburner 1, you see a pan 2.
27
+ > examine stoveburner 1
28
+ On the stoveburner 1, you see a pan 2.
29
+ STATUS: FAIL
30
+ New plan: I was stuck in a loop in which I continually examined stoveburner 1 instead of heating mug 1 with stoveburner 1. I should have taken mug 1 from countertop 1, then heated it with stoveburner 1, then put it in coffeemachine 1. It did not help to execute two identical actions in a row. I will try to execute a different action if I am stuck in a loop again.
31
+
32
+ You are in the middle of a room. Looking quickly around you, you see a bed 1, a desk 2, a desk 1, a drawer 6, a drawer 5, a drawer 4, a drawer 3, a drawer 2, a drawer 1, a garbagecan 1, a laundryhamper 1, a safe 1, a shelf 6, a shelf 5, a shelf 4, a shelf 3, a shelf 2, and a shelf 1.
33
+ Your task is to: look at bowl under the desklamp.
34
+ > think: To solve the task, I need to find and take a bowl, then find and use a desklamp.
35
+ OK.
36
+ > think: First I need to find a bowl. A bowl is more likely to appear in drawer (1-6), desk (1-2), shelf (1-6), garbagecan (1), laundryhamper (1). I can check one by one, starting with drawer 1.
37
+ OK.
38
+ > go to drawer 1
39
+ The drawer 1 is closed.
40
+ > open drawer 1
41
+ You open the drawer 1. The drawer 1 is open. In it, you see a cellphone 2.
42
+ > go to drawer 2
43
+ On the drawer 2, you see a keychain 1.
44
+ > go to drawer 3
45
+ Nothing happens.
46
+ > go to drawer 4
47
+ The drawer 4 is closed.
48
+ > open drawer 4
49
+ You open the drawer 4. The drawer 4 is open. In it, you see nothing.
50
+ > go to drawer 5
51
+ On the drawer 5, you see nothing.
52
+ > go to drawer 6
53
+ The drawer 6 is closed.
54
+ > open drawer 6
55
+ You open the drawer 6. The drawer 6 is open. In it, you see a book 2, and a pen 2.
56
+ > go to desk 1
57
+ On the desk 1, you see a alarmclock 1, a desklamp 1, a mug 1, and a pencil 1.
58
+ > think: Now I find a desklamp (1). Next, I need to use it.
59
+ OK.
60
+ > use desklamp 1
61
+ You turn on the desklamp 1.
62
+ >
63
+ Nothing happens.
64
+ > go to desk 2
65
+ On the desk 2, you see a bowl 2, a bowl 1, a cd 2, a cd 1, and a creditcard 3.
66
+ > think: Now I find a bowl (2). Next, I need to take it.
67
+ OK.
68
+ > take bowl 2 from desk 2
69
+ You pick up the bowl 2 from the desk 2.
70
+ >
71
+ Nothing happens.
72
+ >
73
+ Nothing happens.
74
+ STATUS: FAIL
75
+ New plan: In this environment, my plan was to find a bowl then find and use a desklamp. However, the task says to look at bowl under the desklamp. I should have looked for the desklamp then looked for the bowl. I noticed that the desklamp was found on desk 1. In the next trial, I will go to desk 1, find the lamp, then look for the bowl under the desklamp.
distillers/self_reflection.py ADDED
@@ -0,0 +1,53 @@
1
+ from deciders.utils import get_completion
2
+
3
+ from typing import List, Dict, Any
4
+ from loguru import logger
5
+ import random
6
+ import json
7
+ class RefletionGenerator():
8
+ def __init__(self,logfile="",args=None):
9
+ self.args = args
10
+ with open("./distillers/reflexion_few_shot_examples.txt", 'r') as f:
11
+ self.FEW_SHOT_EXAMPLES = f.read()
12
+ if logfile:
13
+ # logger.remove()
14
+ logger.add(logfile, colorize=True, enqueue=True, filter=lambda x: '[Reflexion Memory]' in x['message'])
15
+
16
+ def generate_from_file(self, file_path,max_step_num=200):
17
+ mem = []
18
+ with open(file_path, 'r') as infile:
19
+ data = json.load(infile)
20
+ for traj in data:
21
+ traj_text = traj[0]['game_description']
22
+ traj_text += traj[0]['goal_description']
23
+ for transition in traj[-max_step_num:]:
24
+ traj_text += transition['observation']
25
+ traj_text += f"Action: {transition['action']}"
26
+ reflection = self.generate(traj_text, mem, max_len_mem=5)
27
+ mem.append(reflection)
28
+ return mem
29
+
30
+ def _generate_reflection_query(self, traj, memory):
31
+ """Allows the Agent to reflect upon a past experience."""
32
+ query: str = f"""You will be given the history of a past experience in which you were placed in an environment and given a task to complete. You were unsuccessful in completing the task. Do not summarize your environment, but rather think about the strategy and path you took to attempt to complete the task. Think step by step about the mistakes you made that led to the failure. Then devise a concise new plan of action that accounts for your mistakes, with reference to the specific actions you should have taken. For example, if you tried A and B but forgot C, you should reason that forgetting C was the key mistake. After that, devise a plan to achieve C with environment-specific actions. Remind yourself of the plan you will take in the next trial and give your plan after "Plan". Here are two examples:
33
+
34
+ {self.FEW_SHOT_EXAMPLES}
35
+
36
+ {traj}"""
37
+ if len(memory) > 0:
38
+ query += '\n\nPlans from past attempts:\n'
39
+ for i, m in enumerate(memory):
40
+ query += f'Trial #{i}: {m}\n'
41
+
42
+ query += '\n\nNew plan:'
43
+ return query
44
+
45
+ def generate(self, traj, memory, max_len_mem=5):
46
+ if len(memory)> max_len_mem:
47
+ reflection_query = self._generate_reflection_query(traj, memory[-max_len_mem:])
48
+ else:
49
+ reflection_query = self._generate_reflection_query(traj, memory)
50
+ reflection = get_completion(reflection_query, engine=self.args.gpt_version)
51
+ logger.info(f'[Reflexion Memory]The reflexion prompt is: {reflection_query}.')
52
+ logger.info(f'[Reflexion Memory]The reflexion response is: {reflection}.')
53
+ return reflection
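A minimal usage sketch for RefletionGenerator (the gpt_version value, log path, and trajectory file are placeholders; the trajectory JSON layout is inferred from generate_from_file above):

from argparse import Namespace

args = Namespace(gpt_version="gpt-3.5-turbo")        # assumed run configuration
reflector = RefletionGenerator(logfile="logs/reflexion.log", args=args)
plans = reflector.generate_from_file("trajs/cartpole_trajs.json", max_step_num=100)
print(plans[-1])                                     # the most recent reflection/plan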
distillers/traj_prompt_summarizer.py ADDED
@@ -0,0 +1,46 @@
1
+ import random
2
+ from deciders.utils import get_completion
3
+ import json
4
+ class TrajPromptSummarizer():
5
+ def __init__(self,args=None):
6
+ self.args = args
7
+ with open("./distillers/traj_summary_few_shot_examples.txt", 'r') as f:
8
+ self.FEW_SHOT_EXAMPLES = f.read()
9
+
10
+ def generate_from_file(self, file_path,max_step_num=200):
11
+ mem = []
12
+ with open(file_path, 'r') as infile:
13
+ data = json.load(infile)
14
+ for traj in data:
15
+ traj_text = traj[0]['game_description']
16
+ traj_text += traj[0]['goal_description']
17
+ for transition in traj[-max_step_num:]:
18
+ traj_text += transition['observation']
19
+ traj_text += f"> {transition['action']}"
20
+ traj_text += f"Your performance is: {transition['cum_reward']}"
21
+ reflection = self.generate(traj_text, mem, max_len_mem=5)
22
+ mem.append(reflection)
23
+ return mem
24
+
25
+ def _generate_summary_query(self, traj, memory):
26
+ """Allows the Agent to reflect upon a past experience."""
27
+ query: str = f"""You will be given the history of a past experience in which you were placed in an environment and given a task to complete. Summarize your trajectory and reason about the relation between your policy and the obtained result. Here are two examples:
28
+
29
+ {self.FEW_SHOT_EXAMPLES}
30
+
31
+ {traj}"""
32
+ if len(memory) > 0:
33
+ query += '\n\nPlans from past attempts:\n'
34
+ for i, m in enumerate(memory):
35
+ query += f'Trial #{i}: {m}\n'
36
+
37
+ query += '\n\nSummary:'
38
+ return query
39
+
40
+ def generate(self, traj, memory, max_len_mem=5):
41
+ if len(memory)> max_len_mem:
42
+ reflection_query = self._generate_summary_query(traj, memory[-max_len_mem:])
43
+ else:
44
+ reflection_query = self._generate_summary_query(traj, memory)
45
+ reflection = get_completion(reflection_query, engine=self.args.gpt_version)
46
+ return reflection
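For reference, generate_from_file above expects a JSON file shaped roughly like the Python literal below (field names are taken from the code; the concrete values are invented):

example_trajs = [
    [   # one trajectory: a list of transitions, oldest first
        {"game_description": "In the CartPole game, ...",
         "goal_description": "The goal is to keep the pole balanced upright ...",
         "observation": "The cart is positioned at 0.012, ...",
         "action": 1,
         "cum_reward": 1.0},
        # ... more transitions ...
    ]
]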
distillers/traj_summary_few_shot_examples.txt ADDED
@@ -0,0 +1,76 @@
1
+ You are in the middle of a room. Looking quickly around you, you see a cabinet 10, a cabinet 9, a cabinet 8, a cabinet 7, a cabinet 6, a cabinet 5, a cabinet 4, a cabinet 3, a cabinet 2, a cabinet 1, a coffeemachine 1, a countertop 2, a countertop 1, a diningtable 2, a diningtable 1, a drawer 3, a drawer 2, a drawer 1, a fridge 1, a garbagecan 1, a microwave 1, a sinkbasin 1, a stoveburner 4, a stoveburner 3, a stoveburner 2, a stoveburner 1, and a toaster 1.
2
+ Your task is to: heat some mug and put it in coffeemachine.
3
+ > think: To solve the task, I need to find and take a mug, then heat it with stoveburner, then put it in coffeemachine.
4
+ Ok.
5
+ > look
6
+ You are in the middle of a room. Looking quickly around you, you see nothing.
7
+ > think: First I need to find a mug. A mug is more likely to appear in countertop (1-2), diningtable (1-2), drawer (1-3), fridge (1), sinkbasin (1), stoveburner (1-4), toaster (1). I can check one by one, starting with countertop 1.
8
+ Ok.
9
+ > look
10
+ You are in the middle of a room. Looking quickly around you, you see nothing.
11
+ > go to countertop 1
12
+ On the countertop 1, you see a butterknife 1, a cup 1, a fork 2, and a mug 1.
13
+ > think: Now I find a mug (1). Next, I need to take it.
14
+ Ok.
15
+ > take mug 1 from countertop 1
16
+ You pick up the mug 1 from the countertop 1.
17
+ > think: Now I take a mug (1). Next, I need to heat it with stoveburner (1-4).
18
+ Ok.
19
+ > go to stoveburner 1
20
+ On the stoveburner 1, you see a pan 2.
21
+ > examine stoveburner 1
22
+ On the stoveburner 1, you see a pan 2.
23
+ > examine stoveburner 1
24
+ On the stoveburner 1, you see a pan 2.
25
+ > examine stoveburner 1
26
+ On the stoveburner 1, you see a pan 2.
27
+ > examine stoveburner 1
28
+ On the stoveburner 1, you see a pan 2.
29
+ STATUS: FAIL
30
+ Summary: I continually examined stoveburner 1 instead of heating mug 1 with stoveburner 1. This got me stuck in a loop and caused the failure.
31
+
32
+ You are in the middle of a room. Looking quickly around you, you see a bed 1, a desk 2, a desk 1, a drawer 6, a drawer 5, a drawer 4, a drawer 3, a drawer 2, a drawer 1, a garbagecan 1, a laundryhamper 1, a safe 1, a shelf 6, a shelf 5, a shelf 4, a shelf 3, a shelf 2, and a shelf 1.
33
+ Your task is to: look at bowl under the desklamp.
34
+ > think: To solve the task, I need to find and take a bowl, then find and use a desklamp.
35
+ OK.
36
+ > think: First I need to find a bowl. A bowl is more likely to appear in drawer (1-6), desk (1-2), shelf (1-6), garbagecan (1), laundryhamper (1). I can check one by one, starting with drawer 1.
37
+ OK.
38
+ > go to drawer 1
39
+ The drawer 1 is closed.
40
+ > open drawer 1
41
+ You open the drawer 1. The drawer 1 is open. In it, you see a cellphone 2.
42
+ > go to drawer 2
43
+ On the drawer 2, you see a keychain 1.
44
+ > go to drawer 3
45
+ Nothing happens.
46
+ > go to drawer 4
47
+ The drawer 4 is closed.
48
+ > open drawer 4
49
+ You open the drawer 4. The drawer 4 is open. In it, you see nothing.
50
+ > go to drawer 5
51
+ On the drawer 5, you see nothing.
52
+ > go to drawer 6
53
+ The drawer 6 is closed.
54
+ > open drawer 6
55
+ You open the drawer 6. The drawer 6 is open. In it, you see a book 2, and a pen 2.
56
+ > go to desk 1
57
+ On the desk 1, you see a alarmclock 1, a desklamp 1, a mug 1, and a pencil 1.
58
+ > think: Now I find a desklamp (1). Next, I need to use it.
59
+ OK.
60
+ > use desklamp 1
61
+ You turn on the desklamp 1.
62
+ >
63
+ Nothing happens.
64
+ > go to desk 2
65
+ On the desk 2, you see a bowl 2, a bowl 1, a cd 2, a cd 1, and a creditcard 3.
66
+ > think: Now I find a bowl (2). Next, I need to take it.
67
+ OK.
68
+ > take bowl 2 from desk 2
69
+ You pick up the bowl 2 from the desk 2.
70
+ >
71
+ Nothing happens.
72
+ >
73
+ Nothing happens.
74
+ STATUS: FAIL
75
+ Summary: I tried to find a bowl and then find and use a desklamp. This is inconsistent with the task, which requires looking at the
76
+ bowl under the desklamp. Thus I failed.
draw_overall_performance.py ADDED
@@ -0,0 +1,59 @@
1
+ import pandas as pd
2
+ import matplotlib.pyplot as plt
3
+
4
+ # Load the CSV data
5
+ data = pd.read_csv("performance_data.csv")
6
+
7
+ # Group games by type
8
+ game_types = {
9
+ "Classic Control": ["Acrobot-v1", "CartPole-v0", "MountainCar-v0"],
10
+ "Box 2D": ["LunarLander-v2"],
11
+ "Toy Text": ["Taxi-v3", "CliffWalking-v0", "Blackjack-v1"]
12
+ }
13
+
14
+ for game_type, games in game_types.items():
15
+ fig, axs = plt.subplots(1, len(games), figsize=(12 * len(games), 6))
16
+ fig.suptitle(f"Performance Plot: {game_type}", fontsize=28, fontname="Times New Roman")
17
+
18
+ if len(games) == 1:
19
+ axs = [axs]
20
+
21
+ handles, labels = [], []
22
+
23
+ for idx, game in enumerate(games):
24
+ # Filter data to get information for the current game (in the loop)
25
+ game_data = data[data["game"] == game]
26
+
27
+ axs[idx].set_title(game, fontsize=20, fontname="Times New Roman")
28
+ axs[idx].set_xlabel("Levels", fontsize=16, fontname="Times New Roman")
29
+ if idx == 0:
30
+ axs[idx].set_ylabel("Scores", fontsize=16, fontname="Times New Roman")
31
+
32
+ for index, row in game_data.iterrows():
33
+ decider_name = row["decider_name"]
34
+ levels = ["l1", "l2", "l3", "l4", "l5"]
35
+ scores = row[levels].values.tolist()
36
+ lines = axs[idx].plot(levels, scores, "-o", label=decider_name)
37
+ # Grab the handle and label for creating a global legend
38
+ handles.append(lines[0])
39
+ labels.append(decider_name)
40
+
41
+ # Eliminate duplicate labels and handles
42
+ unique_labels = []
43
+ unique_handles = []
44
+ for handle, label in zip(handles, labels):
45
+ if label not in unique_labels:
46
+ unique_labels.append(label)
47
+ unique_handles.append(handle)
48
+
49
+ # Add a legend at the bottom middle of the figure
50
+ fig.legend(
51
+ unique_handles,
52
+ unique_labels,
53
+ loc="lower center",
54
+ ncol=4, prop={'size': 18}
55
+ )
56
+
57
+ # Adjust layout to accommodate the legend and prevent cropping
58
+
59
+ plt.savefig("./vis/" + game_type + ".png", dpi=300)
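The comment above announces a layout adjustment, but no call follows; one way to reserve room for the suptitle and the bottom legend (a suggestion, not part of the original script) is:

fig.tight_layout(rect=[0, 0.08, 1, 0.92])   # keep the axes clear of the bottom legend and the suptitle
# or, more coarsely: fig.subplots_adjust(bottom=0.25, top=0.85)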
environment.yml ADDED
@@ -0,0 +1,193 @@
1
+ name: llm-gym
2
+ channels:
3
+ - conda-forge
4
+ - defaults
5
+ dependencies:
6
+ - _libgcc_mutex=0.1=main
7
+ - _openmp_mutex=5.1=1_gnu
8
+ - aiosignal=1.2.0=pyhd3eb1b0_0
9
+ - asttokens=2.0.5=pyhd3eb1b0_0
10
+ - async-timeout=4.0.2=py38h06a4308_0
11
+ - attrs=22.1.0=py38h06a4308_0
12
+ - backcall=0.2.0=pyhd3eb1b0_0
13
+ - blas=1.0=mkl
14
+ - brotlipy=0.7.0=py38h27cfd23_1003
15
+ - ca-certificates=2023.08.22=h06a4308_0
16
+ - cached-property=1.5.2=py_0
17
+ - certifi=2023.7.22=py38h06a4308_0
18
+ - cffi=1.15.1=py38h5eee18b_3
19
+ - chardet=4.0.0=py38h06a4308_1003
20
+ - comm=0.1.2=py38h06a4308_0
21
+ - cryptography=39.0.1=py38h9ce1e76_2
22
+ - cudatoolkit=11.3.1=h2bc3f7f_2
23
+ - debugpy=1.5.1=py38h295c915_0
24
+ - executing=0.8.3=pyhd3eb1b0_0
25
+ - frozenlist=1.3.3=py38h5eee18b_0
26
+ - hdf5=1.10.6=h3ffc7dd_1
27
+ - idna=3.4=py38h06a4308_0
28
+ - importlib_metadata=6.0.0=hd3eb1b0_0
29
+ - intel-openmp=2023.1.0=hdb19cb5_46305
30
+ - ipykernel=6.19.2=py38hb070fc8_0
31
+ - ipython=8.12.0=py38h06a4308_0
32
+ - jedi=0.18.1=py38h06a4308_1
33
+ - jupyter_client=8.1.0=py38h06a4308_0
34
+ - jupyter_core=5.3.0=py38h06a4308_0
35
+ - ld_impl_linux-64=2.38=h1181459_1
36
+ - libffi=3.4.4=h6a678d5_0
37
+ - libgcc-ng=11.2.0=h1234567_1
38
+ - libgfortran-ng=11.2.0=h00389a5_1
39
+ - libgfortran5=11.2.0=h1234567_1
40
+ - libgomp=11.2.0=h1234567_1
41
+ - libllvm14=14.0.6=hdb19cb5_3
42
+ - libprotobuf=3.20.3=he621ea3_0
43
+ - libsodium=1.0.18=h7b6447c_0
44
+ - libstdcxx-ng=11.2.0=h1234567_1
45
+ - loguru=0.7.1=py38h578d9bd_0
46
+ - matplotlib-inline=0.1.6=py38h06a4308_0
47
+ - mkl=2023.1.0=h6d00ec8_46342
48
+ - mkl-service=2.4.0=py38h5eee18b_1
49
+ - mkl_fft=1.3.6=py38h417a72b_1
50
+ - mkl_random=1.2.2=py38h417a72b_1
51
+ - ncurses=6.4=h6a678d5_0
52
+ - nest-asyncio=1.5.6=py38h06a4308_0
53
+ - numpy-base=1.24.3=py38h060ed82_1
54
+ - openssl=3.0.10=h7f8727e_2
55
+ - packaging=23.0=py38h06a4308_0
56
+ - parso=0.8.3=pyhd3eb1b0_0
57
+ - pcre=8.45=h295c915_0
58
+ - pexpect=4.8.0=pyhd3eb1b0_3
59
+ - pickleshare=0.7.5=pyhd3eb1b0_1003
60
+ - pip=23.2.1=py38h06a4308_0
61
+ - platformdirs=2.5.2=py38h06a4308_0
62
+ - prompt-toolkit=3.0.36=py38h06a4308_0
63
+ - psutil=5.9.0=py38h5eee18b_0
64
+ - ptyprocess=0.7.0=pyhd3eb1b0_2
65
+ - pure_eval=0.2.2=pyhd3eb1b0_0
66
+ - pycparser=2.21=pyhd3eb1b0_0
67
+ - pygments=2.15.1=py38h06a4308_1
68
+ - pyopenssl=23.0.0=py38h06a4308_0
69
+ - pysocks=1.7.1=py38h06a4308_0
70
+ - python=3.8.16=h955ad1f_4
71
+ - python-dateutil=2.8.2=pyhd3eb1b0_0
72
+ - python_abi=3.8=2_cp38
73
+ - pyyaml=6.0=py38h0a891b7_4
74
+ - pyzmq=25.1.0=py38h6a678d5_0
75
+ - readline=8.2=h5eee18b_0
76
+ - setuptools=67.8.0=py38h06a4308_0
77
+ - six=1.16.0=pyhd3eb1b0_1
78
+ - sqlite=3.41.2=h5eee18b_0
79
+ - stack_data=0.2.0=pyhd3eb1b0_0
80
+ - tbb=2021.8.0=hdb19cb5_0
81
+ - tk=8.6.12=h1ccaba5_0
82
+ - tornado=6.2=py38h5eee18b_0
83
+ - traitlets=5.7.1=py38h06a4308_0
84
+ - typing_extensions=4.6.3=py38h06a4308_0
85
+ - wcwidth=0.2.5=pyhd3eb1b0_0
86
+ - wheel=0.38.4=py38h06a4308_0
87
+ - xz=5.4.2=h5eee18b_0
88
+ - yaml=0.2.5=h7b6447c_0
89
+ - zeromq=4.3.4=h2531618_0
90
+ - zlib=1.2.13=h5eee18b_0
91
+ - pip:
92
+ - absl-py==1.4.0
93
+ - aiohttp==3.8.4
94
+ - ale-py==0.8.1
95
+ - annotated-types==0.5.0
96
+ - appdirs==1.4.4
97
+ - beautifulsoup4==4.12.2
98
+ - box2d-py==2.3.5
99
+ - cachetools==5.3.1
100
+ - cchardet==2.1.7
101
+ - charset-normalizer==3.1.0
102
+ - click==8.1.3
103
+ - cloudpickle==2.2.1
104
+ - contourpy==1.1.0
105
+ - cycler==0.11.0
106
+ - cython==3.0.1
107
+ - dataclasses-json==0.5.14
108
+ - decorator==4.4.2
109
+ - docker-pycreds==0.4.0
110
+ - fasteners==0.18
111
+ - filelock==3.12.2
112
+ - fonttools==4.40.0
113
+ - fsspec==2023.6.0
114
+ - gitdb==4.0.10
115
+ - gitpython==3.1.31
116
+ - glfw==2.6.2
117
+ - google-auth==2.21.0
118
+ - google-auth-oauthlib==1.0.0
119
+ - greenlet==2.0.2
120
+ - grpcio==1.56.0
121
+ - gym==0.26.2
122
+ - gym-notices==0.0.8
123
+ - h5py==3.9.0
124
+ - huggingface-hub==0.15.1
125
+ - imageio==2.31.2
126
+ - imageio-ffmpeg==0.4.8
127
+ - importlib-metadata==6.6.0
128
+ - importlib-resources==5.12.0
129
+ - iniconfig==2.0.0
130
+ - kiwisolver==1.4.4
131
+ - langchain==0.0.284
132
+ - langsmith==0.0.33
133
+ - llvmlite==0.40.1
134
+ - lz4==4.3.2
135
+ - markdown==3.4.3
136
+ - markupsafe==2.1.1
137
+ - marshmallow==3.20.1
138
+ - matplotlib==3.7.1
139
+ - moviepy==1.0.3
140
+ - mujoco==2.2.0
141
+ - mujoco-py==2.1.2.14
142
+ - multidict==6.0.4
143
+ - numba==0.57.1
144
+ - numexpr==2.8.5
145
+ - numpy==1.24.4
146
+ - oauthlib==3.2.2
147
+ - openai==0.27.8
148
+ - opencv-python==4.8.0.76
149
+ - pathtools==0.1.2
150
+ - pillow==9.5.0
151
+ - pluggy==1.2.0
152
+ - proglog==0.1.10
153
+ - protobuf==3.19.6
154
+ - py==1.11.0
155
+ - pyasn1==0.5.0
156
+ - pyasn1-modules==0.3.0
157
+ - pydantic==2.3.0
158
+ - pydantic-core==2.6.3
159
+ - pygame==2.1.0
160
+ - pyopengl==3.1.7
161
+ - pyparsing==3.0.9
162
+ - pytest==7.0.1
163
+ - regex==2023.6.3
164
+ - requests==2.31.0
165
+ - requests-oauthlib==1.3.1
166
+ - rsa==4.9
167
+ - safetensors==0.3.1
168
+ - sentry-sdk==1.26.0
169
+ - setproctitle==1.3.2
170
+ - smmap==5.0.0
171
+ - soupsieve==2.4.1
172
+ - sqlalchemy==2.0.20
173
+ - swig==4.1.1
174
+ - tenacity==8.2.3
175
+ - tensorboard==2.14.0
176
+ - tensorboard-data-server==0.7.1
177
+ - tianshou==0.4.10
178
+ - tokenizers==0.13.3
179
+ # - torch==1.12.0+cu113
180
+ # - torchaudio==0.12.0+cu113
181
+ # - torchvision==0.13.0+cu113
182
+ - tqdm==4.65.0
183
+ - transformers==4.30.2
184
+ - typing==3.7.4.3
185
+ - typing-extensions==4.7.1
186
+ - typing-inspect==0.9.0
187
+ - urllib3
188
+ - v==1
189
+ - wandb==0.15.4
190
+ - werkzeug==2.3.6
191
+ - yarl==1.9.2
192
+ - zipp==3.15.0
193
+ - aquarel==0.0.5
envs/__init__.py ADDED
@@ -0,0 +1,51 @@
1
+ from .base_env import BaseEnv, SettableStateEnv
2
+ from .classic_control import cartpole_translator, cartpole_policies
3
+ from .classic_control import acrobot_translator, acrobot_policies
4
+ from .classic_control import mountaincar_translator, mountaincar_policies
5
+ from .classic_control import mountaincarContinuous_translator,mountaincarContinuous_policies
6
+
7
+ from .box2d import LunarLander_translator, LunarLander_policies
8
+
9
+ from .toy_text import blackjack_translator, blackjack_policies
10
+ from .toy_text import taxi_translator, taxi_policies
11
+ from .toy_text import cliffwalking_translator, cliffwalking_policies
12
+ from .toy_text import frozenlake_translator, frozenlake_policies
13
+
14
+ REGISTRY = {}
15
+ REGISTRY["sampling_wrapper"] = SettableStateEnv
16
+ REGISTRY["base_env"] = BaseEnv
17
+ REGISTRY["cart_init_translator"] = cartpole_translator.GameDescriber
18
+ REGISTRY["cart_basic_translator"] = cartpole_translator.BasicStateSequenceTranslator
19
+ REGISTRY["acrobot_init_translator"] = acrobot_translator.GameDescriber
20
+ REGISTRY["acrobot_basic_translator"] = acrobot_translator.BasicStateSequenceTranslator
21
+ REGISTRY["mountaincar_init_translator"] = mountaincar_translator.GameDescriber
22
+ REGISTRY["mountaincar_basic_translator"] = mountaincar_translator.BasicStateSequenceTranslator
23
+
24
+ REGISTRY["cart_policies"] = [cartpole_policies.dedicated_1_policy, cartpole_policies.dedicated_2_policy, cartpole_policies.pseudo_random_policy, cartpole_policies.real_random_policy]
25
+ REGISTRY["acrobot_policies"] = [acrobot_policies.dedicated_1_policy, acrobot_policies.dedicated_2_policy, acrobot_policies.dedicated_3_policy, acrobot_policies.pseudo_random_policy, acrobot_policies.real_random_policy]
26
+ REGISTRY["mountaincar_policies"] = [mountaincar_policies.dedicated_1_policy, mountaincar_policies.dedicated_2_policy, mountaincar_policies.dedicated_3_policy, mountaincar_policies.pseudo_random_policy, mountaincar_policies.real_random_policy]
27
+
28
+ REGISTRY["lunarLander_init_translator"] = LunarLander_translator.GameDescriber
29
+ REGISTRY["lunarLander_basic_translator"] = LunarLander_translator.BasicStateSequenceTranslator
30
+ REGISTRY["lunarLander_policies"] = [LunarLander_policies.dedicated_1_policy, LunarLander_policies.dedicated_2_policy, LunarLander_policies.dedicated_3_policy,LunarLander_policies.dedicated_4_policy, LunarLander_policies.pseudo_random_policy, LunarLander_policies.real_random_policy]
31
+
32
+ REGISTRY["blackjack_init_translator"] = blackjack_translator.GameDescriber
33
+ REGISTRY["blackjack_basic_translator"] = blackjack_translator.BasicStateSequenceTranslator
34
+ REGISTRY["blackjack_policies"] = [blackjack_policies.dedicated_1_policy, blackjack_policies.dedicated_2_policy, blackjack_policies.pseudo_random_policy, blackjack_policies.real_random_policy]
35
+
36
+ REGISTRY["taxi_init_translator"] = taxi_translator.GameDescriber
37
+ REGISTRY["taxi_basic_translator"] = taxi_translator.BasicStateSequenceTranslator
38
+ REGISTRY["taxi_policies"] = [taxi_policies.dedicated_1_policy, taxi_policies.dedicated_2_policy, taxi_policies.dedicated_3_policy, taxi_policies.dedicated_4_policy, taxi_policies.dedicated_5_policy, taxi_policies.dedicated_6_policy, taxi_policies.pseudo_random_policy, taxi_policies.real_random_policy]
39
+
40
+ REGISTRY["cliffwalking_init_translator"] = cliffwalking_translator.GameDescriber
41
+ REGISTRY["cliffwalking_basic_translator"] = cliffwalking_translator.BasicStateSequenceTranslator
42
+ REGISTRY["cliffwalking_policies"] = [cliffwalking_policies.dedicated_1_policy, cliffwalking_policies.dedicated_2_policy, cliffwalking_policies.dedicated_3_policy, cliffwalking_policies.dedicated_4_policy, cliffwalking_policies.pseudo_random_policy, cliffwalking_policies.real_random_policy]
43
+
44
+ REGISTRY["frozenlake_init_translator"] = frozenlake_translator.GameDescriber
45
+ REGISTRY["frozenlake_basic_translator"] = frozenlake_translator.BasicStateSequenceTranslator
46
+ REGISTRY["frozenlake_policies"] = [frozenlake_policies.dedicated_1_policy, frozenlake_policies.dedicated_2_policy, frozenlake_policies.dedicated_3_policy, frozenlake_policies.dedicated_4_policy, frozenlake_policies.pseudo_random_policy, frozenlake_policies.real_random_policy]
47
+
48
+
49
+ REGISTRY["mountaincarContinuous_init_translator"] = mountaincarContinuous_translator.GameDescriber
50
+ REGISTRY["mountaincarContinuous_basic_translator"] = mountaincarContinuous_translator.BasicStateSequenceTranslator
51
+ REGISTRY["mountaincarContinuous_policies"] = [mountaincarContinuous_policies.pseudo_random_policy, mountaincarContinuous_policies.real_random_policy]
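A minimal sketch of pulling pieces out of the registry (the args fields are inferred from the translator constructors added in this commit; combining the init/basic translators with REGISTRY["base_env"] happens in the experiment entry point, which is not shown here):

from argparse import Namespace
from envs import REGISTRY

args = Namespace(is_only_local_obs=1, max_episode_len=200)   # assumed run configuration
describer = REGISTRY["cart_init_translator"](args)
print(describer.describe_game())
print(describer.describe_action())
random_policy = REGISTRY["cart_policies"][-1]                # real_random_policy
print(random_policy(None, pre_action=1))                     # prints 1 or 2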
envs/base_env.py ADDED
@@ -0,0 +1,97 @@
1
+ # This file contains wrapper classes for interacting with gym environments through text translators
2
+
3
+ import gym
4
+
5
+ class SettableStateEnv(gym.Wrapper):
6
+ def __init__(self, env):
7
+ super().__init__(env)
8
+ self.env = env
9
+
10
+ def set_state(self, state):
11
+ self.env.state = state
12
+ self.env.steps_beyond_terminated = None
13
+
14
+ class BaseEnv(gym.Wrapper):
15
+ def __init__(self, env, translator):
16
+ super().__init__(env)
17
+ self.translator = translator
18
+ self.env_name = super().spec.id
19
+ self.transition_data = {}
20
+ self.game_description = self.get_game_description()
21
+ self.goal_description = self.get_goal_description()
22
+ self.action_description = self.get_action_description()
23
+ self.action_desc_dict = self.get_action_desc_dict()
24
+ self.reward_desc_dict = self.get_reward_desc_dict()
25
+
26
+ def reset(self, **kwargs):
27
+ state, _ = super().reset(**kwargs)
28
+ self.transition_data['state'] = state
29
+ self.translator.obtain(self.transition_data)
30
+ summary, future_summary = self.translator.translate()
31
+ info = {
32
+ 'future_summary': future_summary
33
+ }
34
+ self.state = state
35
+ return summary, info
36
+
37
+ def step(self, action):
38
+ potential_next_state = self.get_potential_next_state(action)
39
+ state, reward, terminated, _, info = super().step(action)
40
+ self.transition_data['action'] = action
41
+ self.transition_data['next_state'] = state
42
+ self.transition_data['reward'] = reward
43
+ self.transition_data['terminated'] = terminated
44
+ self.translator.update(self.transition_data)
45
+ self.transition_data = {}
46
+ self.transition_data['state'] = state
47
+ self.translator.obtain(self.transition_data)
48
+ summary, future_summary = self.translator.translate()
49
+ info = {
50
+ 'future_summary': future_summary,
51
+ 'potential_state': potential_next_state
52
+ }
53
+ return summary, reward, terminated, _, info
54
+
55
+
56
+ def step_llm(self, action):
57
+ potential_next_state = self.get_potential_next_state(action)
58
+ if "Continuous" in self.env_name:
59
+ state, reward, terminated, _, info = super().step(action)
60
+ else:
61
+ state, reward, terminated, _, info = super().step(action-1)
62
+ self.transition_data['action'] = action
63
+ self.transition_data['next_state'] = state
64
+ self.transition_data['reward'] = reward
65
+ self.transition_data['terminated'] = terminated
66
+ self.translator.update(self.transition_data)
67
+ self.transition_data = {}
68
+ self.transition_data['state'] = state
69
+ self.translator.obtain(self.transition_data)
70
+ self.state = state
71
+ summary, future_summary = self.translator.translate()
72
+ info = {
73
+ 'future_summary': future_summary,
74
+ 'potential_state': potential_next_state,
75
+ }
76
+ return summary, reward, terminated, _, info
77
+
78
+ def get_terminate_state(self, episode_len, max_episode_len):
79
+ return self.translator.translate_terminate_state(self.state, episode_len, max_episode_len)
80
+
81
+ def get_game_description(self,):
82
+ return self.translator.describe_game()
83
+
84
+ def get_goal_description(self,):
85
+ return self.translator.describe_goal()
86
+
87
+ def get_action_description(self,):
88
+ return self.translator.describe_action()
89
+
90
+ def get_action_desc_dict(self,):
91
+ return self.translator.get_action_desc_dict()
92
+
93
+ def get_reward_desc_dict(self,):
94
+ return self.translator.get_reward_desc_dict()
95
+
96
+ def get_potential_next_state(self, action):
97
+ return self.translator.translate_potential_next_state(self.state, action)
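A note on the action convention visible in step_llm above: LLM-facing actions are numbered from 1 (matching the policy functions added in this commit), so for discrete environments step_llm shifts the action down by one before calling the wrapped gym env. A tiny sketch, assuming env is a BaseEnv around a discrete-action task:

summary, info = env.reset()                                     # summary is the translated state text
summary, reward, terminated, truncated, info = env.step_llm(2)  # LLM action '2' becomes gym action 1
print(summary)
print(info["future_summary"], info["potential_state"])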
envs/box2d/LunarLander_policies.py ADDED
@@ -0,0 +1,36 @@
1
+ import numpy as np
2
+ def dedicated_1_policy(state, pre_action=1):
3
+ def get_description():
4
+ return "Always select action 1, which does nothing"
5
+ dedicated_1_policy.description = get_description()
6
+ return 1
7
+
8
+ def dedicated_2_policy(state, pre_action=1):
9
+ def get_description():
10
+ return "Always select action 2, which fires the left engine"
11
+ dedicated_2_policy.description = get_description()
12
+ return 2
13
+
14
+ def dedicated_3_policy(state, pre_action=1):
15
+ def get_description():
16
+ return "Always select action 3, which fires the main engine"
17
+ dedicated_3_policy.description = get_description()
18
+ return 3
19
+
20
+ def dedicated_4_policy(state, pre_action=1):
21
+ def get_description():
22
+ return "Always select action 4, which fires the right engine"
23
+ dedicated_4_policy.description = get_description()
24
+ return 4
25
+
26
+ def pseudo_random_policy(state, pre_action):
27
+ def get_description():
28
+ return "Select actions 1, 2, 3, and 4 in turn, which do nothing, fire the left engine, fire the main engine, and fire the right engine, respectively"
29
+ pseudo_random_policy.description = get_description()
30
+ return pre_action%4+1
31
+
32
+ def real_random_policy(state,pre_action=0):
33
+ def get_description():
34
+ return "Select action with a random policy"
35
+ real_random_policy.description = get_description()
36
+ return np.random.choice([1, 2, 3, 4])
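A small illustration of the description-attribute pattern shared by the policy functions above (the state argument is ignored by these scripted policies, so None is passed purely for illustration):

action = pseudo_random_policy(None, pre_action=1)   # cycles 1 -> 2 -> 3 -> 4 -> 1 ...
print(action)                                       # 2
print(pseudo_random_policy.description)             # set as a side effect of the call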
envs/box2d/LunarLander_translator.py ADDED
@@ -0,0 +1,67 @@
1
+ # [Translator classes and functions for Lunar Lander environment]
2
+
3
+ class BasicLevelTranslator:
4
+ def __init__(self,):
5
+ pass
6
+
7
+ def translate(self, state):
8
+ x, y, x_dot, y_dot, angle, angular_velocity, left_leg_contact, right_leg_contact = state
9
+ left_contact_info = "in contact" if left_leg_contact else "not in contact"
10
+ right_contact_info = "in contact" if right_leg_contact else "not in contact"
11
+ return f"The lander is at position ({x:.2f}, {y:.2f}), the horizontal speed of movement is {x_dot:.2f}, " \
12
+ f"the vertical speed of movement is {y_dot:.2f}. The angle is {angle:.2f} radians, and it's rotating at {angular_velocity:.2f} radians per second. The left leg is {left_contact_info} with the ground. The right leg is {right_contact_info} with the ground."
13
+
14
+ class GameDescriber:
15
+ def __init__(self, args):
16
+ self.is_only_local_obs = args.is_only_local_obs == 1
17
+ self.max_episode_len = args.max_episode_len
18
+ self.action_desc_dict = {
19
+ }
20
+ self.reward_desc_dict = {
21
+ }
22
+
23
+ def describe_goal(self):
24
+ return "The goal is to successfully land the lander on the landing pad, which is at position (0, 0), while avoiding a crash."
25
+
26
+ def translate_terminate_state(self, state, episode_len, max_episode_len):
27
+ return ""
28
+
29
+ def translate_potential_next_state(self, state, action):
30
+ return ""
31
+
32
+ def describe_game(self):
33
+ return "In the Lunar Lander game, you control a lander that is descending towards " \
34
+ "the landing pad. The goal is to successfully land the lander on the landing pad " \
35
+ "while avoiding a crash. Please note that the lander is affected by gravity, and the lander starts at the " \
36
+ "top center of the viewport with a random initial force applied to its center of mass. " \
37
+ "Be careful to balance the engine to slow down your descent " \
38
+ "and land gently. If you land too quickly or crash into the landing pad, the game will " \
39
+ "end, and you will be punished."
40
+
41
+ def describe_action(self):
42
+ return "Your Next Move: \n Please choose an action. Type '1' to do nothing, '2' to fire the left engine and push the lander to the right, '3' to fire the main engine and push the lander up, " \
43
+ "or '4' to fire the right engine and push the lander to the left. Ensure you only provide the action number from the valid action list, i.e., [1, 2, 3, 4]."
44
+
45
+
46
+ class BasicStateSequenceTranslator(BasicLevelTranslator):
47
+ def translate(self, infos, is_current=False):
48
+ descriptions = []
49
+ if is_current:
50
+ state_desc = BasicLevelTranslator().translate(infos[-1]['state'])
51
+ return state_desc
52
+ for i, info in enumerate(infos):
53
+ assert 'state' in info, "info should contain state information"
54
+
55
+ state_desc = BasicLevelTranslator().translate(info['state'])
56
+ if info['action'] == 1:
57
+ action_desc = f"Take Action: 'Do Nothing'"
58
+ elif info['action'] == 2:
59
+ action_desc = f"Take Action: 'Fire left engine'"
60
+ elif info['action'] == 3:
61
+ action_desc = f"Take Action: 'Fire main engine'"
62
+ else:
63
+ action_desc = f"Take Action: 'Fire right engine'"
64
+ reward_desc = f"Result: Reward of {info['reward']}, "
65
+ next_state_desc = BasicLevelTranslator().translate(info['next_state'])
66
+ descriptions.append(f"{state_desc}.\n {action_desc} \n {reward_desc} \n Transit to {next_state_desc}")
67
+ return descriptions
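A quick illustration of BasicLevelTranslator on a hand-made eight-dimensional state (the numbers are invented; the field order follows the unpacking in translate above):

state = (0.10, 1.20, -0.05, -0.60, 0.02, 0.10, 0, 0)
print(BasicLevelTranslator().translate(state))
# -> "The lander is at position (0.10, 1.20), the horizontal speed of movement is -0.05, ..."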
envs/box2d/__init__.py ADDED
File without changes
envs/box2d/few_shot_examples/lunarlander_l2.json ADDED
The diff for this file is too large to render. See raw diff
 
envs/box2d/few_shot_examples/lunarlander_l4.json ADDED
The diff for this file is too large to render. See raw diff
 
envs/classic_control/__init__.py ADDED
File without changes
envs/classic_control/acrobot_policies.py ADDED
@@ -0,0 +1,36 @@
1
+ import numpy as np
2
+
3
+ # https://colab.research.google.com/drive/1DdWsGi10232orUv-reY4wsTmT0VMoHaX?usp=sharing#scrollTo=4OfVmDKk7XvG
4
+ # LLMs bias on 0 so make the actions 1, 2 and 3 instead.
5
+
6
+ def dedicated_1_policy(state, pre_action=1):
7
+ def get_description():
8
+ return "Always select action 1"
9
+ dedicated_1_policy.description = get_description()
10
+ return 1
11
+
12
+ def dedicated_2_policy(state, pre_action=1):
13
+ def get_description():
14
+ return "Always select action 2"
15
+ dedicated_2_policy.description = get_description()
16
+ return 2
17
+
18
+ def dedicated_3_policy(state, pre_action=1):
19
+ def get_description():
20
+ return "Always select action 3"
21
+ dedicated_3_policy.description = get_description()
22
+ return 3
23
+
24
+ def pseudo_random_policy(state, pre_action):
25
+ def get_description():
26
+ return "Select action 1, 2, and 3 alternatively"
27
+ pseudo_random_policy.description = get_description()
28
+ return pre_action % 3 + 1
29
+
30
+
31
+ def real_random_policy(state, pre_action=1):
32
+ def get_description():
33
+ return "Select action with a random policy"
34
+ real_random_policy.description = get_description()
35
+ return np.random.choice([1, 2, 3])
36
+
envs/classic_control/acrobot_translator.py ADDED
@@ -0,0 +1,58 @@
1
+ import math
2
+
3
+ class BasicLevelTranslator:
4
+ def __init__(self):
5
+ pass
6
+
7
+ def translate(self, state):
8
+ cos_theta1, sin_theta1, cos_theta2, sin_theta2, theta1_dot, theta2_dot = state
9
+ theta1_direction = "clockwise" if theta1_dot > 0 else "counterclockwise"
10
+ theta2_direction = "clockwise" if theta2_dot > 0 else "counterclockwise"
11
+ theta1 = math.atan(sin_theta1 / (cos_theta1+1e-6))
12
+ theta2 = math.atan(sin_theta2 / (cos_theta2+1e-6))
13
+ res = (f"Link1: angle theta1 {theta1:.2f} radians, rotating {abs(theta1_dot):.2f} radians per second {theta1_direction}. "
14
+ f"Link2: angle theta2 {theta2:.2f} radians relative to Link1, rotating {abs(theta2_dot):.2f} radians per second {theta2_direction}.")
15
+ return res
16
+
17
+ class GameDescriber:
18
+ def __init__(self, args):
19
+ self.is_only_local_obs = args.is_only_local_obs == 1
20
+ self.max_episode_len = args.max_episode_len
21
+ self.action_desc_dict = {
22
+ }
23
+ self.reward_desc_dict = {
24
+ }
25
+
26
+ def describe_goal(self):
27
+ return "The goal is to apply torque on the actuator to swing the free end of the linear chain above the target height, which is constructed as: -cos(theta1) - cos(theta2 + theta1) > 1.0."
28
+
29
+ def translate_terminate_state(self, state, episode_len, max_episode_len):
30
+ return ""
31
+
32
+ def translate_potential_next_state(self, state, action):
33
+ return ""
34
+
35
+ def describe_game(self):
36
+ return ('''In the Acrobot game, there are two links connected by two joints. The first link is connected to a base, and your goal is to swing the free end of the second link above the target height by applying torques on the actuated joint. The task ends if one of the following occurs: 1. The free end reaches the target height, which is constructed as: -cos(theta1) - cos(theta2 + theta1) > 1.0; or 2. The number of decision steps exceeds 200.''')
37
+
38
+ # https://colab.research.google.com/drive/1DdWsGi10232orUv-reY4wsTmT0VMoHaX?usp=sharing#scrollTo=4OfVmDKk7XvG
39
+ # LLMs bias on 0 so make the actions 1, 2 and 3 instead.
40
+ def describe_action(self):
41
+ return ("Your Next Move: \n Please choose an action. Type '1' to apply -1 torque, '2' to apply 0 torque, or '3' to apply 1 torque. "
42
+ "Ensure you provide the action number from the valid action list, i.e., [1, 2, 3].")
43
+
44
+ class BasicStateSequenceTranslator(BasicLevelTranslator):
45
+ def translate(self, infos, is_current=False):
46
+ descriptions = []
47
+ if is_current:
48
+ state_desc = BasicLevelTranslator().translate(infos[-1]['state'])
49
+ return state_desc
50
+ for i, info in enumerate(infos):
51
+ assert 'state' in info, "info should contain state information"
52
+
53
+ state_desc = BasicLevelTranslator().translate(info['state'])
54
+ action_desc = f"Take Action: Apply {info['action'] - 2} torque on the actuated joint."
55
+ reward_desc = f"Result: Reward of {info['reward']}."
56
+ next_state_desc = BasicLevelTranslator().translate(info['next_state'])
57
+ descriptions.append(f"{state_desc}.\n {action_desc} \n {reward_desc} \n Transit to {next_state_desc}")
58
+ return descriptions
envs/classic_control/cartpole_policies.py ADDED
@@ -0,0 +1,25 @@
1
+ import numpy as np
2
+ def dedicated_1_policy(state, pre_action=1):
3
+ def get_description():
4
+ return "Always select action 1"
5
+ dedicated_1_policy.description = get_description()
6
+ return 1
7
+
8
+ def dedicated_2_policy(state, pre_action=1):
9
+ def get_description():
10
+ return "Always select action 2"
11
+ dedicated_2_policy.description = get_description()
12
+ return 2
13
+
14
+ def pseudo_random_policy(state, pre_action):
15
+ def get_description():
16
+ return "Select action 1 and 2 alternatively"
17
+ pseudo_random_policy.description = get_description()
18
+ return pre_action%2 +1
19
+
20
+ def real_random_policy(state,pre_action=1):
21
+ def get_description():
22
+ return "Select action with a random policy"
23
+ real_random_policy.description = get_description()
24
+ return np.random.choice([1, 2])
25
+
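A short rollout sketch using the 1-based policies above (gym's CartPole expects actions 0/1, hence the "- 1"; the env id and horizon are placeholders):

import gym

env = gym.make("CartPole-v0")
state, _ = env.reset()
pre_action = 1
for _ in range(20):
    pre_action = pseudo_random_policy(state, pre_action)               # alternates between 1 and 2
    state, reward, terminated, truncated, _ = env.step(pre_action - 1)
    if terminated or truncated:
        break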
envs/classic_control/cartpole_translator.py ADDED
@@ -0,0 +1,57 @@
1
+
2
+ class BasicLevelTranslator:
3
+ def __init__(self,):
4
+ pass
5
+
6
+ def translate(self, state):
7
+ cart_position, cart_velocity, pole_angle, pole_angular_velocity = state
8
+ cart_direction = "right" if cart_velocity > 0 else "left"
9
+ pole_direction = "right" if pole_angular_velocity > 0 else "left"
10
+ res = (f"The cart is positioned at {cart_position:.3f}, with a velocity of {abs(cart_velocity):.2f} towards the {cart_direction}. "
11
+ f"The pole is tilted at {abs(pole_angle):.2f} radians, rotating at {abs(pole_angular_velocity):.2f} radians per second towards the {pole_direction}.")
12
+ return res
13
+
14
+ class GameDescriber:
15
+ def __init__(self, args):
16
+ self.is_only_local_obs = args.is_only_local_obs == 1
17
+ self.max_episode_len = args.max_episode_len
18
+ self.action_desc_dict = {
19
+ }
20
+ self.reward_desc_dict = {
21
+ }
22
+
23
+ def describe_goal(self):
24
+ return "The goal is to keep the pole balanced upright for as long as possible."
25
+
26
+ def translate_terminate_state(self, state, episode_len, max_episode_len):
27
+ return ""
28
+
29
+ def translate_potential_next_state(self, state, action):
30
+ return ""
31
+
32
+ def describe_game(self):
33
+ return "In the CartPole game, you control a cart that moves along a horizontal track. There is a pole " \
34
+ "standing upright on the cart. The goal of the game is to keep the pole balanced upright by moving the " \
35
+ "cart left or right. The game ends if the pole tilts too far from the vertical position or if the cart " \
36
+ "moves too far from the center of the track. The longer you can keep the pole balanced, the higher " \
37
+ "your score. Note that when the Cart Position is out of the (-2.4, 2.4) zone or the Pole Angle is out " \
38
+ "of the zone (-.2095, .2095), the round ends and the game is lost. "
39
+
40
+ def describe_action(self):
41
+ return "Your Next Move: \n Please choose an action. Type '1' to push the cart to the left or '2' to push the cart to the right. Ensure you only provide the action number from the valid action list, i.e., [1, 2]."
42
+
43
+ class BasicStateSequenceTranslator(BasicLevelTranslator):
44
+ def translate(self, infos, is_current=False):
45
+ descriptions = []
46
+ if is_current:
47
+ state_desc = BasicLevelTranslator().translate(infos[-1]['state'])
48
+ return state_desc
49
+ for i, info in enumerate(infos):
50
+ assert 'state' in info, "info should contain state information"
51
+
52
+ state_desc = BasicLevelTranslator().translate(info['state'])
53
+ action_desc = f"Take Action: Push {'right' if info['action'] == 2 else 'left'} ({info['action']})."
54
+ reward_desc = f"Result: Reward of {info['reward']}, "
55
+ next_state_desc = BasicLevelTranslator().translate(info['next_state'])
56
+ descriptions.append(f"{state_desc}.\n {action_desc} \n {reward_desc} \n Transit to {next_state_desc}")
57
+ return descriptions
envs/classic_control/few_shot_examples/acrobot_l2.json ADDED
The diff for this file is too large to render. See raw diff
 
envs/classic_control/few_shot_examples/acrobot_l4.json ADDED
The diff for this file is too large to render. See raw diff
 
envs/classic_control/few_shot_examples/cartpole_l2.json ADDED
The diff for this file is too large to render. See raw diff
 
envs/classic_control/few_shot_examples/cartpole_l4.json ADDED
The diff for this file is too large to render. See raw diff
 
envs/classic_control/few_shot_examples/mountaincarContinuous_l2.json ADDED
The diff for this file is too large to render. See raw diff
 
envs/classic_control/few_shot_examples/mountaincarContinuous_l4.json ADDED
The diff for this file is too large to render. See raw diff
 
envs/classic_control/few_shot_examples/mountaincar_l2.json ADDED
The diff for this file is too large to render. See raw diff