Spaces:
Runtime error
Runtime error
Jarvis-K
commited on
Commit
·
2a33798
1
Parent(s):
64eba11
init
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- README.md +97 -0
- RL_based/test_RL.sh +39 -0
- RL_based/train_PPO.py +251 -0
- RL_based/train_RL.sh +39 -0
- RL_based/utils.py +621 -0
- deciders/__init__.py +26 -0
- deciders/act.py +248 -0
- deciders/cot.py +147 -0
- deciders/jarvis.py +177 -0
- deciders/jarvis_without_insights.py +179 -0
- deciders/jarvis_without_shortmem.py +182 -0
- deciders/jarvis_without_suggestions.py +180 -0
- deciders/jarvis_without_summary.py +179 -0
- deciders/misc.py +21 -0
- deciders/pal.py +149 -0
- deciders/parser.py +53 -0
- deciders/reflexion.py +179 -0
- deciders/self_consistency.py +170 -0
- deciders/selfask.py +150 -0
- deciders/spp.py +142 -0
- deciders/utils.py +65 -0
- distillers/__init__.py +10 -0
- distillers/guidance_summary_few_shot_examples.txt +85 -0
- distillers/guider.py +144 -0
- distillers/raw_prompt_generator.py +16 -0
- distillers/reflexion_few_shot_examples.txt +75 -0
- distillers/self_reflection.py +53 -0
- distillers/traj_prompt_summarizer.py +46 -0
- distillers/traj_summary_few_shot_examples.txt +76 -0
- draw_overall_performance.py +59 -0
- environment.yml +193 -0
- envs/__init__.py +51 -0
- envs/base_env.py +97 -0
- envs/box2d/LunarLander_policies.py +36 -0
- envs/box2d/LunarLander_translator.py +67 -0
- envs/box2d/__init__.py +0 -0
- envs/box2d/few_shot_examples/lunarlander_l2.json +0 -0
- envs/box2d/few_shot_examples/lunarlander_l4.json +0 -0
- envs/classic_control/__init__.py +0 -0
- envs/classic_control/acrobot_policies.py +36 -0
- envs/classic_control/acrobot_translator.py +58 -0
- envs/classic_control/cartpole_policies.py +25 -0
- envs/classic_control/cartpole_translator.py +57 -0
- envs/classic_control/few_shot_examples/acrobot_l2.json +0 -0
- envs/classic_control/few_shot_examples/acrobot_l4.json +0 -0
- envs/classic_control/few_shot_examples/cartpole_l2.json +0 -0
- envs/classic_control/few_shot_examples/cartpole_l4.json +0 -0
- envs/classic_control/few_shot_examples/mountaincarContinuous_l2.json +0 -0
- envs/classic_control/few_shot_examples/mountaincarContinuous_l4.json +0 -0
- envs/classic_control/few_shot_examples/mountaincar_l2.json +0 -0
README.md
ADDED
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Bench LLM Deciders with gym translators
|
2 |
+
This project provides a set of translators to convert OpenAI Gym environments into text-based environments. It is designed to investigate the capabilities of large language models in decision-making tasks within these text-based environments.
|
3 |
+
|
4 |
+
## Summarizer Levels
|
5 |
+
We translate the game with basic level descriptions. It provides a simple description of the current state of the game. It's suitable for beginners who are just getting familiar with the game.
|
6 |
+
## Environment Categories
|
7 |
+
The environments are categorized based on the information that revealed to agents. We propose *5 level* scenarios.
|
8 |
+
|
9 |
+
**L1**: No external information is given. Only abstract game description. (zero shot)
|
10 |
+
|
11 |
+
**L2**: Agents can take a sampling traj of the random policy as external knowledge. (few shots, off-policy info)
|
12 |
+
|
13 |
+
**L3**: self sampling and updating w/ feedback. (few shots, on-policy info)
|
14 |
+
|
15 |
+
**L4**: sampling traj of an expert policy (few shots, expert-info)
|
16 |
+
|
17 |
+
**L5**: expert teaching (few shots, expert-info with guidance)
|
18 |
+
|
19 |
+
The five level scenarios are mainly considering making decision with perception. For future world, we leave it to stage 2 investigation.
|
20 |
+
|
21 |
+
**Perception and Future World**: These environments provide a perception of the current state, and also predict future infos. The futrue info is given in the info dict at step and reset.
|
22 |
+
|
23 |
+
It should be noted that the past memory part should be implemented as a component of deciders.
|
24 |
+
|
25 |
+
## Fewshot Examples Generation
|
26 |
+
For `L1` level, the `[]` is given.
|
27 |
+
For `L2` and `L4` level, we use `gen_few_shots_examples.py` to generate corresponding examples in json format and place them in the `envs/*/few_shot_examples/`.
|
28 |
+
For `L3` level, agent should collect the examples on their own and only a few methods support it. Thus we leave it to the agent design.
|
29 |
+
For `L5` level, we handcraft the few shot examples with domain knowledge in `prompts/task_relevant`.
|
30 |
+
|
31 |
+
## Usage
|
32 |
+
|
33 |
+
1. create `./deciders/gpt.py` to provide your gpt agent:
|
34 |
+
```python
|
35 |
+
import openai
|
36 |
+
class gpt:
|
37 |
+
def __init__(self,):
|
38 |
+
openai.api_type = "azure"
|
39 |
+
openai.api_version = "2023-05-15"
|
40 |
+
# Your Azure OpenAI resource's endpoint value.
|
41 |
+
openai.api_base = "https://js-partner.openai.azure.com/"
|
42 |
+
openai.api_key = "your azure openai key"
|
43 |
+
```
|
44 |
+
|
45 |
+
2. Install Requirements
|
46 |
+
|
47 |
+
```
|
48 |
+
conda env create --file environment.yml
|
49 |
+
```
|
50 |
+
|
51 |
+
3. Testing
|
52 |
+
The project can be run using the provided test.sh script. This script runs a series of commands, each of which initiates a Gym environment and applies different translators to it.
|
53 |
+
|
54 |
+
Here is an example of how to run the script:
|
55 |
+
|
56 |
+
```
|
57 |
+
./test.sh
|
58 |
+
```
|
59 |
+
The commands in test.sh are structured as follows:
|
60 |
+
|
61 |
+
```
|
62 |
+
python main.py --env_name ENV_NAME --init_summarizer INIT_SUMMARIZER --curr_summarizer CURR_SUMMARIZER [--future_summarizer FUTURE_SUMMARIZER --future_horizon FUTURE_HORIZON]
|
63 |
+
```
|
64 |
+
Where:
|
65 |
+
|
66 |
+
* ENV_NAME: The name of the Gym environment to be used (e.g., CartPole-v0).
|
67 |
+
* INIT_SUMMARIZER: The initial summarizer to be used (e.g., cart_init_translator).
|
68 |
+
* CURR_SUMMARIZER: The current summarizer to be used (e.g., cart_basic_translator).
|
69 |
+
* FUTURE_SUMMARIZER (optional): The future summarizer to be used (e.g., cart_basic_translator).
|
70 |
+
* FUTURE_HORIZON (optional): The horizon that each policy will look to (e.g., 3).
|
71 |
+
|
72 |
+
## Supported Environment Translators and LLM Deciders
|
73 |
+
|
74 |
+
| | Acrobot | Cart Pole | Mountain Car | Pendulum | Lunar Lander | Blackjack | Taxi | Cliff Walking | Frozen Lake |
|
75 |
+
|------------------------------|:------------------------:|:----------------------------------:|:------------------------:|:------------------------:|:------------------------:|:------------------------:|:------------------------:|:------------------------:|:------------------------:|
|
76 |
+
| Translator | :heavy_multiplication_x: | :white_check_mark: | :heavy_multiplication_x: | :heavy_multiplication_x: | :white_check_mark: | :heavy_multiplication_x: | :heavy_multiplication_x: | :heavy_multiplication_x: | :heavy_multiplication_x: |
|
77 |
+
| Chain-of-Thought | :heavy_minus_sign: | :white_check_mark:(L1)<br>:gift:<sup>[1]</sup>(~30) | :heavy_minus_sign: | :heavy_minus_sign: | :white_check_mark:(L1)<br/>:gift:<sup>[1]</sup>(-367) | :heavy_minus_sign: | :heavy_minus_sign: | :heavy_minus_sign: | :heavy_minus_sign: |
|
78 |
+
| Program-aided Language Model | :heavy_minus_sign: | :white_check_mark:(L1)<br>:gift:(168) | :heavy_minus_sign: | :heavy_minus_sign: | :white_check_mark:(L1)<br/>:gift:(-68) | :heavy_minus_sign: | :heavy_minus_sign: | :heavy_minus_sign: | :heavy_minus_sign: |
|
79 |
+
| Self-ask Prompting | :heavy_minus_sign: | :white_check_mark:(L1)<br>:gift:(~10) | :heavy_minus_sign: | :heavy_minus_sign: | :heavy_multiplication_x: | :heavy_minus_sign: | :heavy_minus_sign: | :heavy_minus_sign: | :heavy_minus_sign: |
|
80 |
+
| Self-consistency Prompting | :heavy_minus_sign: | :white_check_mark:(L1)<br>:gift:(~30) | :heavy_minus_sign: | :heavy_minus_sign: | :heavy_multiplication_x: | :heavy_minus_sign: | :heavy_minus_sign: | :heavy_minus_sign: | :heavy_minus_sign: |
|
81 |
+
| Reflexion | :heavy_minus_sign: | :heavy_multiplication_x: | :heavy_minus_sign: | :heavy_minus_sign: | :heavy_multiplication_x: | :heavy_minus_sign: | :heavy_minus_sign: | :heavy_minus_sign: | :heavy_minus_sign: |
|
82 |
+
| Solo Performance Prompting | :heavy_minus_sign: | :white_check_mark:(L1)<br/>:gift:(43) | :heavy_minus_sign: | :heavy_minus_sign: | :white_check_mark:(L1)<br/>:gift:(-583) | :heavy_minus_sign: | :heavy_minus_sign: | :heavy_minus_sign: | :heavy_minus_sign: |
|
83 |
+
|
84 |
+
<sup>[1]: Cumulative reward.</sup>
|
85 |
+
![Image text](https://github.com/mail-ecnu/LLM-Decider-Bench/blob/master/vis/Classic%20Control.png)
|
86 |
+
![Image text](https://github.com/mail-ecnu/LLM-Decider-Bench/blob/master/vis/Box%202D.png)
|
87 |
+
![Image text](https://github.com/mail-ecnu/LLM-Decider-Bench/blob/master/vis/Toy%20Text.png)
|
88 |
+
|
89 |
+
>
|
90 |
+
> 1. Except for the reflexion L3 decider, all other L3 deciders in this task do not have memory.
|
91 |
+
> 2. reflexion L1 and L3 both have memory.
|
92 |
+
> 3. reflexion L1 run 5 trails.
|
93 |
+
> 4. Blackjack、MountainCar、Cliffwalking(PAL)、CartPole(PAL)、Taxi(SPP、PAL)、Frozen Lake use deciders modified at 15:29 09.18
|
94 |
+
> 5. update Frozen Lake translator, add prior knowledge.
|
95 |
+
# Remarks
|
96 |
+
1. how to use future info
|
97 |
+
We provide future info in the env_info part. It is a dict and you can convert it to a text further to make your agent aware the world model.
|
RL_based/test_RL.sh
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# # ppo for cartpole-v0
|
2 |
+
# CUDA_VISIBLE_DEVICES=1 python RL_based/train_PPO.py --env_name CartPole-v0 --init_summarizer cart_init_translator --curr_summarizer cart_basic_translator\
|
3 |
+
# --trans_model_name distilbert-base-uncased --model_name nn_embedding --eval --policy-path RL_based/checkpoints/CartPole-v0/expert/policy.pth --collect_one_episode
|
4 |
+
|
5 |
+
# # ppo for lunarlander: treasured-music-91 score: 164.66
|
6 |
+
# TRANSFORMERS_OFFLINE=1 \
|
7 |
+
# CUDA_VISIBLE_DEVICES=2 python RL_based/train_PPO.py --env_name LunarLander-v2 --init_summarizer lunarLander_init_translator --curr_summarizer lunarLander_basic_translator \
|
8 |
+
# --trans_model_name /home/ubuntu/LLM-Decider-Bench/RL_based/transformer_offline_distilbert --model_name nn_embedding --max_length 128 --eval --collect_one_episode --policy-path /home/ubuntu/LLM-Decider-Bench/RL_based/checkpoints/LunarLander-v2/expert/policy.pth
|
9 |
+
|
10 |
+
# ppo for Acrobot-v1: charmed-salad-93 score: -85.8
|
11 |
+
# TRANSFORMERS_OFFLINE=1 \
|
12 |
+
# CUDA_VISIBLE_DEVICES=0 python RL_based/train_PPO.py --env_name Acrobot-v1 --init_summarizer acrobot_init_translator --curr_summarizer acrobot_basic_translator --decider naive_actor --prompt_level 1\
|
13 |
+
# --trans_model_name /home/ubuntu/LLM-Decider-Bench/RL_based/transformer_offline_distilbert --model_name nn_embedding --max_length 128 --eval --collect_one_episode --policy-path /home/ubuntu/LLM-Decider-Bench/RL_based/checkpoints/Acrobot-v1/expert/policy.pth
|
14 |
+
|
15 |
+
# # # # ppo for MountainCar-v0:
|
16 |
+
# TRANSFORMERS_OFFLINE=1 \
|
17 |
+
# CUDA_VISIBLE_DEVICES=1 python RL_based/train_PPO.py --env_name MountainCar-v0 --init_summarizer mountaincar_init_translator --curr_summarizer mountaincar_basic_translator --decider naive_actor --prompt_level 1\
|
18 |
+
# --trans_model_name /home/ubuntu/LLM-Decider-Bench/RL_based/transformer_offline_distilbert --model_name nn_embedding --max_length 128 --eval --collect_one_episode --policy-path /home/ubuntu/LLM-Decider-Bench/RL_based/checkpoints/MountainCar-v0/expert/policy.pth
|
19 |
+
|
20 |
+
# # ppo for Blackjack-v1
|
21 |
+
# TRANSFORMERS_OFFLINE=1 \
|
22 |
+
# CUDA_VISIBLE_DEVICES=2 python RL_based/train_PPO.py --env_name Blackjack-v1 --init_summarizer blackjack_init_translator --curr_summarizer blackjack_basic_translator --decider naive_actor --prompt_level 1\
|
23 |
+
# --trans_model_name /home/ubuntu/LLM-Decider-Bench/RL_based/transformer_offline_distilbert --model_name nn_embedding --max_length 128 --eval --collect_one_episode --policy-path /home/ubuntu/LLM-Decider-Bench/RL_based/checkpoints/Blackjack-v1/expert/policy.pth
|
24 |
+
|
25 |
+
# # # ppo for Taxi-v3
|
26 |
+
TRANSFORMERS_OFFLINE=1 \
|
27 |
+
CUDA_VISIBLE_DEVICES=6 python RL_based/train_PPO.py --env_name Taxi-v3 --init_summarizer taxi_init_translator --curr_summarizer taxi_basic_translator --decider naive_actor --prompt_level 1\
|
28 |
+
--trans_model_name /home/ubuntu/LLM-Decider-Bench/RL_based/transformer_offline_distilbert --model_name nn_embedding --max_length 128 --eval --collect_one_episode --policy-path /home/ubuntu/LLM-Decider-Bench/RL_based/checkpoints/Taxi-v3/expert/policy.pth
|
29 |
+
|
30 |
+
# # # ppo for CliffWalking-v0
|
31 |
+
# TRANSFORMERS_OFFLINE=1 \
|
32 |
+
# CUDA_VISIBLE_DEVICES=4 python RL_based/train_PPO.py --env_name CliffWalking-v0 --init_summarizer cliffwalking_init_translator --curr_summarizer cliffwalking_basic_translator --decider naive_actor --prompt_level 1\
|
33 |
+
# --trans_model_name /home/ubuntu/LLM-Decider-Bench/RL_based/transformer_offline_distilbert --model_name nn_embedding --max_length 128 --eval --collect_one_episode --policy-path /home/ubuntu/LLM-Decider-Bench/RL_based/checkpoints/CliffWalking-v0/expert/policy.pth
|
34 |
+
|
35 |
+
# # # ppo for FrozenLake-v1
|
36 |
+
# TRANSFORMERS_OFFLINE=1 \
|
37 |
+
# CUDA_VISIBLE_DEVICES=5 python RL_based/train_PPO.py --env_name FrozenLake-v1 --init_summarizer frozenlake_init_translator --curr_summarizer frozenlake_basic_translator --decider naive_actor --prompt_level 1\
|
38 |
+
# --trans_model_name /home/ubuntu/LLM-Decider-Bench/RL_based/transformer_offline_distilbert --model_name nn_embedding --max_length 128 --eval --collect_one_episode --policy-path /home/ubuntu/LLM-Decider-Bench/RL_based/checkpoints/FrozenLake-v1/expert/policy.pth
|
39 |
+
|
RL_based/train_PPO.py
ADDED
@@ -0,0 +1,251 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import sys
|
3 |
+
sys.path.insert(0, sys.path[0]+"/../")
|
4 |
+
import prompts as task_prompts
|
5 |
+
import envs
|
6 |
+
import os
|
7 |
+
from envs.translator import InitSummarizer, CurrSummarizer, FutureSummarizer, Translator
|
8 |
+
import gym
|
9 |
+
from torch.optim.lr_scheduler import LambdaLR
|
10 |
+
import torch
|
11 |
+
from tianshou.data import Collector, VectorReplayBuffer, ReplayBuffer
|
12 |
+
from tianshou.env import DummyVectorEnv, SubprocVectorEnv
|
13 |
+
from tianshou.policy import PPOPolicy, ICMPolicy
|
14 |
+
from tianshou.trainer import onpolicy_trainer
|
15 |
+
from tianshou.utils.net.common import ActorCritic
|
16 |
+
from tianshou.utils.net.discrete import Actor, Critic, IntrinsicCuriosityModule
|
17 |
+
from RL_based.utils import Net_GRU_Bert_tianshou, Net_Bert_CLS_tianshou, Net_Bert_CNN_tianshou, Net_GRU_nn_emb_tianshou
|
18 |
+
from tianshou.utils import WandbLogger
|
19 |
+
from torch.utils.tensorboard import SummaryWriter
|
20 |
+
from tianshou.trainer.utils import test_episode
|
21 |
+
|
22 |
+
import warnings
|
23 |
+
warnings.filterwarnings('ignore')
|
24 |
+
|
25 |
+
class MaxStepLimitWrapper(gym.Wrapper):
|
26 |
+
def __init__(self, env, max_steps=200):
|
27 |
+
super(MaxStepLimitWrapper, self).__init__(env)
|
28 |
+
self.max_steps = max_steps
|
29 |
+
self.current_step = 0
|
30 |
+
|
31 |
+
def reset(self, **kwargs):
|
32 |
+
self.current_step = 0
|
33 |
+
return self.env.reset(**kwargs)
|
34 |
+
|
35 |
+
def step(self, action):
|
36 |
+
observation, reward, terminated, truncated, info = self.env.step(action)
|
37 |
+
self.current_step += 1
|
38 |
+
|
39 |
+
if self.current_step >= self.max_steps:
|
40 |
+
terminated = True
|
41 |
+
info['episode_step_limit'] = self.max_steps
|
42 |
+
|
43 |
+
return observation, reward, terminated, truncated, info
|
44 |
+
|
45 |
+
class SimpleTextWrapper(gym.Wrapper):
|
46 |
+
def __init__(self, env):
|
47 |
+
super(SimpleTextWrapper, self).__init__(env)
|
48 |
+
self.env = env
|
49 |
+
|
50 |
+
def reset(self, **kwargs):
|
51 |
+
observation, _ = self.env.reset(**kwargs)
|
52 |
+
return str(observation), {}
|
53 |
+
|
54 |
+
def step(self, action):
|
55 |
+
observation, reward, terminated, truncated, info = self.env.step(action)
|
56 |
+
return str(observation), reward, terminated, truncated, info
|
57 |
+
|
58 |
+
if __name__ == "__main__":
|
59 |
+
parser = argparse.ArgumentParser(description='Evaluate a translator in a gym environment with a ChatGPT model.')
|
60 |
+
parser.add_argument('--init_summarizer', type=str, required=True, help='The name of the init summarizer to use.')
|
61 |
+
parser.add_argument('--curr_summarizer', type=str, required=True, help='The name of the curr summarizer to use.')
|
62 |
+
parser.add_argument('--future_summarizer', type=str, help='The name of the future summarizer to use.')
|
63 |
+
parser.add_argument('--env', type=str, default='base_env', help='The name of the gym environment to use.')
|
64 |
+
parser.add_argument('--env_name', type=str, default='CartPole-v1', help='The name of the gym environment to use.')
|
65 |
+
parser.add_argument('--decider', type=str, default="naive_actor", help='The actor used to select action')
|
66 |
+
parser.add_argument('--render', type=str, default="rgb_array", help='The render mode')
|
67 |
+
parser.add_argument('--future_horizon', type=int, help='The horizon of looking to future')
|
68 |
+
parser.add_argument(
|
69 |
+
"--prompt_level",
|
70 |
+
type=int,
|
71 |
+
default=1,
|
72 |
+
help="The level of prompts",
|
73 |
+
)
|
74 |
+
parser.add_argument(
|
75 |
+
"--past_horizon", type=int, help="The horizon of looking back"
|
76 |
+
)
|
77 |
+
parser.add_argument(
|
78 |
+
"--max_episode_len", type=int, default=200, help="The max length of an episode"
|
79 |
+
)
|
80 |
+
|
81 |
+
### for RL training
|
82 |
+
parser.add_argument('--max_length', type=int, default=128, help='The token length of the observation')
|
83 |
+
# trans_model_name
|
84 |
+
parser.add_argument('--trans_model_name', type=str, default='bert-base-uncased', help='The name of the pretrained transformer to use.')
|
85 |
+
parser.add_argument('--model_name', type=str, default='bert-embedding', help='The name of the model to use.')
|
86 |
+
parser.add_argument('--vector_env', type=str, default='dummy', help='The name of the vector env to use.')
|
87 |
+
parser.add_argument('--eval', action='store_true', default=False, help='Whether to only eval the model')
|
88 |
+
parser.add_argument('--policy-path', type=str, default=None, help='The path to the policy to be evaluated')
|
89 |
+
parser.add_argument('--collect_one_episode', action='store_true', default=False, help='Whether to only collect one episode')
|
90 |
+
parser.add_argument('--lr', type=float, default=0.0003, help='The learning rate of the model')
|
91 |
+
parser.add_argument('--step_per_epoch', type=int, default=10000, help='The number of steps per epoch')
|
92 |
+
parser.add_argument('--step_per_collect', type=int, default=2000, help='The number of steps per collect')
|
93 |
+
parser.add_argument('--lr_decay', action='store_true', default=False, help='Whether to decay the learning rate')
|
94 |
+
parser.add_argument('--epoch', type=int, default=400, help='The number of epochs to train')
|
95 |
+
parser.add_argument('--resume_path', type=str, default=None, help='The path to the policy to be resumed')
|
96 |
+
parser.add_argument('--taxi_specific_env', action='store_true', default=False, help='Whether to use taxi specific env')
|
97 |
+
args = parser.parse_args()
|
98 |
+
args_dict = vars(args)
|
99 |
+
|
100 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
101 |
+
# Get the specified translator, environment, and ChatGPT model
|
102 |
+
env_class = envs.REGISTRY[args.env]
|
103 |
+
init_summarizer = InitSummarizer(envs.REGISTRY[args.init_summarizer])
|
104 |
+
curr_summarizer = CurrSummarizer(envs.REGISTRY[args.curr_summarizer])
|
105 |
+
if args.future_summarizer:
|
106 |
+
future_summarizer = FutureSummarizer(
|
107 |
+
envs.REGISTRY[args.future_summarizer],
|
108 |
+
envs.REGISTRY["cart_policies"],
|
109 |
+
future_horizon=args.future_horizon,
|
110 |
+
)
|
111 |
+
else:
|
112 |
+
future_summarizer = None
|
113 |
+
|
114 |
+
wandb_log_config = {
|
115 |
+
"env": args.env_name,
|
116 |
+
"init_summarizer": args.init_summarizer,
|
117 |
+
"curr_summarizer": args.curr_summarizer,
|
118 |
+
"future_summarizer": args.future_summarizer,
|
119 |
+
}
|
120 |
+
wandb_log_config.update(args_dict)
|
121 |
+
|
122 |
+
if not args.eval:
|
123 |
+
logger = WandbLogger(
|
124 |
+
project="LLM-decider-bench-RL",
|
125 |
+
entity="llm-bench-team",
|
126 |
+
config=wandb_log_config,
|
127 |
+
)
|
128 |
+
random_name = logger.wandb_run.name
|
129 |
+
log_path = os.path.join('/home/ubuntu/LLM-Decider-Bench/RL_based/results', args.env_name, random_name)
|
130 |
+
writer = SummaryWriter(log_dir=log_path)
|
131 |
+
writer.add_text("args", str(args))
|
132 |
+
logger.load(writer)
|
133 |
+
def save_best_fn(policy):
|
134 |
+
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
135 |
+
|
136 |
+
sampling_env = envs.REGISTRY["sampling_wrapper"](gym.make(args.env_name))
|
137 |
+
if args.prompt_level == 5:
|
138 |
+
prompts_class = task_prompts.REGISTRY[(args.env_name,args.decider)]()
|
139 |
+
else:
|
140 |
+
prompts_class = task_prompts.REGISTRY[(args.decider)]()
|
141 |
+
translator = Translator(
|
142 |
+
init_summarizer, curr_summarizer, future_summarizer, env=sampling_env
|
143 |
+
)
|
144 |
+
if args.taxi_specific_env:
|
145 |
+
environment = gym.make(args.env_name, render_mode=args.render)
|
146 |
+
else:
|
147 |
+
environment = env_class(
|
148 |
+
gym.make(args.env_name, render_mode=args.render), translator
|
149 |
+
)
|
150 |
+
|
151 |
+
# Set the translation level
|
152 |
+
translate_level = 1
|
153 |
+
if args.past_horizon is None and args.future_horizon is None:
|
154 |
+
translate_level = 1
|
155 |
+
if args.past_horizon and args.future_horizon is None:
|
156 |
+
raise NotImplementedError
|
157 |
+
# translate_level = 2
|
158 |
+
if args.past_horizon is None and args.future_horizon:
|
159 |
+
raise NotImplementedError
|
160 |
+
# translate_level = 3
|
161 |
+
if args.past_horizon and args.future_horizon:
|
162 |
+
raise NotImplementedError
|
163 |
+
# translate_level = 3.5
|
164 |
+
|
165 |
+
|
166 |
+
if args.vector_env == 'dummy':
|
167 |
+
ThisEnv = DummyVectorEnv
|
168 |
+
elif args.vector_env == 'subproc':
|
169 |
+
ThisEnv = SubprocVectorEnv
|
170 |
+
def make_env():
|
171 |
+
if args.taxi_specific_env:
|
172 |
+
env = MaxStepLimitWrapper(SimpleTextWrapper(gym.make(args.env_name, render_mode=args.render)), max_steps=200)
|
173 |
+
env._max_episode_steps = args.max_episode_len
|
174 |
+
else:
|
175 |
+
env = env_class(MaxStepLimitWrapper(gym.make(args.env_name, render_mode=args.render), max_steps=200), translator)
|
176 |
+
env._max_episode_steps = args.max_episode_len
|
177 |
+
|
178 |
+
return env
|
179 |
+
train_envs = ThisEnv([make_env for _ in range(20)])
|
180 |
+
test_envs = ThisEnv([make_env for _ in range(10)])
|
181 |
+
# model & optimizer
|
182 |
+
def get_net():
|
183 |
+
if args.model_name == "bert-embedding":
|
184 |
+
net = Net_GRU_Bert_tianshou(state_shape=environment.observation_space.shape, hidden_sizes=[64, 64], device=device, max_length=args.max_length, trans_model_name=args.trans_model_name)
|
185 |
+
elif args.model_name == "bert-CLS-embedding":
|
186 |
+
net = Net_Bert_CLS_tianshou(state_shape=environment.observation_space.shape, hidden_sizes=[256, 128], device=device, max_length=args.max_length, trans_model_name=args.trans_model_name)
|
187 |
+
elif args.model_name == "bert-CNN-embedding":
|
188 |
+
net = Net_Bert_CNN_tianshou(state_shape=environment.observation_space.shape, hidden_sizes=[256, 128], device=device, max_length=args.max_length, trans_model_name=args.trans_model_name)
|
189 |
+
elif args.model_name == "nn_embedding":
|
190 |
+
net = Net_GRU_nn_emb_tianshou(hidden_sizes=[256, 128], device=device, max_length=args.max_length, trans_model_name=args.trans_model_name)
|
191 |
+
return net
|
192 |
+
net = get_net()
|
193 |
+
actor = Actor(net, environment.action_space.n, device=device).to(device)
|
194 |
+
critic = Critic(net, device=device).to(device)
|
195 |
+
actor_critic = ActorCritic(actor, critic)
|
196 |
+
optim = torch.optim.Adam(actor_critic.parameters(), lr=args.lr)
|
197 |
+
|
198 |
+
# PPO policy
|
199 |
+
dist = torch.distributions.Categorical
|
200 |
+
lr_scheduler = None
|
201 |
+
if args.lr_decay:
|
202 |
+
max_update_num = args.step_per_epoch // args.step_per_collect * args.epoch
|
203 |
+
|
204 |
+
lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num)
|
205 |
+
policy = PPOPolicy(actor, critic, optim, dist, action_space=environment.action_space, lr_scheduler=lr_scheduler).to(device)
|
206 |
+
# collector
|
207 |
+
train_collector = Collector(policy, train_envs, VectorReplayBuffer(20000, len(train_envs)), exploration_noise=True)
|
208 |
+
test_collector = Collector(policy, test_envs, exploration_noise=True)
|
209 |
+
|
210 |
+
if not args.eval:
|
211 |
+
# trainer
|
212 |
+
# test train_collector and start filling replay buffer
|
213 |
+
|
214 |
+
if args.resume_path:
|
215 |
+
policy.load_state_dict(torch.load(args.resume_path, map_location='cuda'))
|
216 |
+
print("Loaded agent from: ", args.resume_path)
|
217 |
+
|
218 |
+
train_collector.collect(256 * 20)
|
219 |
+
result = onpolicy_trainer(
|
220 |
+
policy,
|
221 |
+
train_collector,
|
222 |
+
test_collector,
|
223 |
+
max_epoch=args.epoch,
|
224 |
+
step_per_epoch=50000, # the number of transitions collected per epoch
|
225 |
+
repeat_per_collect=4,
|
226 |
+
episode_per_test=10,
|
227 |
+
batch_size=256,
|
228 |
+
logger=logger,
|
229 |
+
step_per_collect=1000, # the number of transitions the collector would collect before the network update
|
230 |
+
save_best_fn=save_best_fn,
|
231 |
+
# stop_fn=lambda mean_reward: mean_reward >= environment.spec.reward_threshold,
|
232 |
+
)
|
233 |
+
print(result)
|
234 |
+
else:
|
235 |
+
assert args.policy_path is not None
|
236 |
+
policy.load_state_dict(torch.load(args.policy_path))
|
237 |
+
test_collector = Collector(policy, test_envs)
|
238 |
+
result = test_episode(policy, test_collector, None, None, n_episode=10)
|
239 |
+
print(result)
|
240 |
+
if args.collect_one_episode:
|
241 |
+
replaybuffer = ReplayBuffer(size=1000)
|
242 |
+
test_collector_1 = Collector(policy, environment, replaybuffer)
|
243 |
+
test_collector_1.reset_env()
|
244 |
+
test_collector_1.reset_buffer()
|
245 |
+
policy.eval()
|
246 |
+
result = test_collector_1.collect(n_episode=1)
|
247 |
+
print('sample results', f"/home/ubuntu/LLM-Decider-Bench/RL_based/checkpoints/{args.env_name}/output.txt")
|
248 |
+
sample_result = replaybuffer.sample(0)
|
249 |
+
f = open(f"/home/ubuntu/LLM-Decider-Bench/RL_based/checkpoints/{args.env_name}/output.txt", "w")
|
250 |
+
print(sample_result, file=f)
|
251 |
+
f.close()
|
RL_based/train_RL.sh
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# # ppo for cartpole
|
2 |
+
# CUDA_VISIBLE_DEVICES=1 python RL_based/train_PPO.py --env_name CartPole-v0 --init_summarizer cart_init_translator --curr_summarizer cart_basic_translator\
|
3 |
+
# --trans_model_name /home/ubuntu/LLM-Decider-Bench/RL_based/transformer_offline_distilbert --model_name nn_embedding
|
4 |
+
|
5 |
+
# # ppo for lunarlander
|
6 |
+
# TRANSFORMERS_OFFLINE=1 \
|
7 |
+
# CUDA_VISIBLE_DEVICES=3 python RL_based/train_PPO.py --env_name LunarLander-v2 --init_summarizer lunarLander_init_translator --curr_summarizer lunarLander_basic_translator \
|
8 |
+
# --trans_model_name /home/ubuntu/LLM-Decider-Bench/RL_based/transformer_offline_distilbert --model_name nn_embedding --max_length 128 --lr 0.0003 --lr_decay --epoch 500
|
9 |
+
|
10 |
+
# ppo for Acrobot-v1
|
11 |
+
# TRANSFORMERS_OFFLINE=1 \
|
12 |
+
# CUDA_VISIBLE_DEVICES=0 python RL_based/train_PPO.py --env_name Acrobot-v1 --init_summarizer acrobot_init_translator --curr_summarizer acrobot_basic_translator --decider naive_actor --prompt_level 1\
|
13 |
+
# --trans_model_name /home/ubuntu/LLM-Decider-Bench/RL_based/transformer_offline_distilbert --model_name nn_embedding --max_length 128 --lr 0.0003 --lr_decay --epoch 500 &
|
14 |
+
|
15 |
+
# # # ppo for MountainCar-v0
|
16 |
+
# TRANSFORMERS_OFFLINE=1 \
|
17 |
+
# CUDA_VISIBLE_DEVICES=1 python RL_based/train_PPO.py --env_name MountainCar-v0 --init_summarizer mountaincar_init_translator --curr_summarizer mountaincar_basic_translator --decider naive_actor --prompt_level 1\
|
18 |
+
# --trans_model_name /home/ubuntu/LLM-Decider-Bench/RL_based/transformer_offline_distilbert --model_name nn_embedding --max_length 300 --lr 0.0003 --lr_decay --epoch 500 &
|
19 |
+
|
20 |
+
# ppo for Blackjack-v1
|
21 |
+
# TRANSFORMERS_OFFLINE=1 \
|
22 |
+
# CUDA_VISIBLE_DEVICES=2 python RL_based/train_PPO.py --env_name Blackjack-v1 --init_summarizer blackjack_init_translator --curr_summarizer blackjack_basic_translator --decider naive_actor --prompt_level 1\
|
23 |
+
# --trans_model_name /home/ubuntu/LLM-Decider-Bench/RL_based/transformer_offline_distilbert --model_name nn_embedding --max_length 300 --lr 0.0003 --lr_decay --epoch 500 &
|
24 |
+
|
25 |
+
# # ppo for Taxi-v3
|
26 |
+
TRANSFORMERS_OFFLINE=1 \
|
27 |
+
CUDA_VISIBLE_DEVICES=6 python RL_based/train_PPO.py --env_name Taxi-v3 --init_summarizer taxi_init_translator --curr_summarizer taxi_basic_translator --decider naive_actor --prompt_level 1\
|
28 |
+
--trans_model_name /home/ubuntu/LLM-Decider-Bench/RL_based/transformer_offline_distilbert --model_name nn_embedding --max_length 300 --lr 0.0003 --lr_decay --epoch 500 --taxi_specific_env
|
29 |
+
|
30 |
+
# # ppo for CliffWalking-v0
|
31 |
+
# TRANSFORMERS_OFFLINE=1 \
|
32 |
+
# CUDA_VISIBLE_DEVICES=4 python RL_based/train_PPO.py --env_name CliffWalking-v0 --init_summarizer cliffwalking_init_translator --curr_summarizer cliffwalking_basic_translator --decider naive_actor --prompt_level 1\
|
33 |
+
# --trans_model_name /home/ubuntu/LLM-Decider-Bench/RL_based/transformer_offline_distilbert --model_name nn_embedding --max_length 300 --lr 0.0003 --lr_decay --epoch 500 &
|
34 |
+
|
35 |
+
# # ppo for FrozenLake-v1
|
36 |
+
# TRANSFORMERS_OFFLINE=1 \
|
37 |
+
# CUDA_VISIBLE_DEVICES=5 python RL_based/train_PPO.py --env_name FrozenLake-v1 --init_summarizer frozenlake_init_translator --curr_summarizer frozenlake_basic_translator --decider naive_actor --prompt_level 1\
|
38 |
+
# --trans_model_name /home/ubuntu/LLM-Decider-Bench/RL_based/transformer_offline_distilbert --model_name nn_embedding --max_length 300 --lr 0.0003 --lr_decay --epoch 500 &
|
39 |
+
|
RL_based/utils.py
ADDED
@@ -0,0 +1,621 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
from torch import nn
|
5 |
+
sys.path.insert(0, sys.path[0]+"/../")
|
6 |
+
from typing import (
|
7 |
+
Any,
|
8 |
+
Dict,
|
9 |
+
List,
|
10 |
+
Optional,
|
11 |
+
Sequence,
|
12 |
+
Tuple,
|
13 |
+
Type,
|
14 |
+
Union,
|
15 |
+
no_type_check,
|
16 |
+
)
|
17 |
+
import torch.nn as nn
|
18 |
+
from tianshou.utils.net.discrete import NoisyLinear
|
19 |
+
ModuleType = Type[nn.Module]
|
20 |
+
import random
|
21 |
+
from collections import namedtuple, deque
|
22 |
+
from itertools import count
|
23 |
+
import math
|
24 |
+
import torch
|
25 |
+
import torch.optim as optim
|
26 |
+
from transformers import AutoModel, AutoTokenizer
|
27 |
+
import torch.nn.functional as F
|
28 |
+
from tianshou.utils.net.common import ModuleType, Net, MLP
|
29 |
+
|
30 |
+
|
31 |
+
def bert_embedding(x, max_length=512, device='cuda'):
|
32 |
+
from transformers import logging
|
33 |
+
logging.set_verbosity_error()
|
34 |
+
model_name = 'bert-base-uncased'
|
35 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
36 |
+
bert_model = AutoModel.from_pretrained(model_name)
|
37 |
+
text = x
|
38 |
+
if isinstance(text, np.ndarray):
|
39 |
+
text = list(text)
|
40 |
+
tokens = tokenizer(text, max_length=max_length, padding='max_length', truncation=True, return_tensors='pt')
|
41 |
+
input_ids = tokens['input_ids']
|
42 |
+
attention_mask = tokens['attention_mask']
|
43 |
+
with torch.no_grad():
|
44 |
+
outputs = bert_model(input_ids, attention_mask=attention_mask)
|
45 |
+
embeddings = outputs.last_hidden_state
|
46 |
+
return embeddings
|
47 |
+
|
48 |
+
class Net_GRU(nn.Module):
|
49 |
+
|
50 |
+
def __init__(self, input_size, n_actions, hidden_dim, n_layers, dropout, bidirectional):
|
51 |
+
super(Net_GRU, self).__init__()
|
52 |
+
self.input_size = input_size
|
53 |
+
self.hidden_dim = hidden_dim
|
54 |
+
self.num_classes = n_actions
|
55 |
+
self.n_layers = n_layers
|
56 |
+
self.dropout = dropout
|
57 |
+
self.bidirectional = bidirectional
|
58 |
+
|
59 |
+
# Layers
|
60 |
+
self.gru = nn.GRU(self.input_size, self.hidden_dim, self.n_layers,
|
61 |
+
batch_first=True, dropout=self.dropout, bidirectional=self.bidirectional)
|
62 |
+
self.final_layer = nn.Linear(self.hidden_dim*(1 + int(self.bidirectional)), self.num_classes)
|
63 |
+
|
64 |
+
def forward(self, x):
|
65 |
+
# Input shape: (batch_size, seq_length)
|
66 |
+
batch_size, seq_length, emb_size = x.size()
|
67 |
+
|
68 |
+
gru_out, hidden = self.gru(x)
|
69 |
+
|
70 |
+
# Use the final state
|
71 |
+
# hidden -> (num_direction, batch, hidden_size)
|
72 |
+
if self.bidirectional:
|
73 |
+
hidden = hidden.view(self.n_layers, 2, batch_size, self.hidden_dim)
|
74 |
+
final_hidden = torch.cat((hidden[:, -1, :, :].squeeze(0), hidden[:, 0, :, :].squeeze(0)), 1)
|
75 |
+
else:
|
76 |
+
final_hidden = hidden.squeeze(0)
|
77 |
+
|
78 |
+
# final_hidden -> (batch_size, num_classes)
|
79 |
+
logits = self.final_layer(final_hidden)
|
80 |
+
|
81 |
+
return logits
|
82 |
+
|
83 |
+
class MyGRU(nn.Module):
|
84 |
+
def __init__(self, input_size, hidden_dim, n_layers, dropout, bidirectional, output_dim):
|
85 |
+
super(MyGRU, self).__init__()
|
86 |
+
self.input_size = input_size
|
87 |
+
self.hidden_dim = hidden_dim
|
88 |
+
self.n_layers = n_layers
|
89 |
+
self.dropout = dropout
|
90 |
+
self.bidirectional = bidirectional
|
91 |
+
|
92 |
+
# Layers
|
93 |
+
self.gru = nn.GRU(self.input_size, self.hidden_dim, self.n_layers,
|
94 |
+
batch_first=True, dropout=self.dropout, bidirectional=self.bidirectional)
|
95 |
+
self.final_layer = nn.Linear(self.hidden_dim*(1 + int(self.bidirectional)), output_dim)
|
96 |
+
|
97 |
+
def forward(self, x):
|
98 |
+
batch_size, seq_length, emb_size = x.size()
|
99 |
+
|
100 |
+
gru_out, hidden = self.gru(x)
|
101 |
+
|
102 |
+
# Use the final state
|
103 |
+
# hidden -> (num_direction, batch, hidden_size)
|
104 |
+
if self.bidirectional:
|
105 |
+
hidden = hidden.view(self.n_layers, 2, batch_size, self.hidden_dim)
|
106 |
+
final_hidden = torch.cat((hidden[:, -1, :, :].squeeze(0), hidden[:, 0, :, :].squeeze(0)), 1)
|
107 |
+
else:
|
108 |
+
final_hidden = hidden.squeeze(0)
|
109 |
+
|
110 |
+
# final_hidden -> (batch_size, num_classes)
|
111 |
+
logits = self.final_layer(final_hidden)
|
112 |
+
|
113 |
+
return logits
|
114 |
+
|
115 |
+
class MyCNN(nn.Module):
|
116 |
+
def __init__(self,
|
117 |
+
input_dim: int,
|
118 |
+
output_dim: int = 0,
|
119 |
+
hidden_sizes: Sequence[int] = (),
|
120 |
+
norm_layer: Optional[Union[ModuleType, Sequence[ModuleType]]] = None,
|
121 |
+
activation: ModuleType = nn.ReLU,
|
122 |
+
device: Optional[Union[str, int, torch.device]] = None,
|
123 |
+
linear_layer: Type[nn.Linear] = nn.Linear,
|
124 |
+
flatten_input: bool = True,) -> None:
|
125 |
+
super().__init__()
|
126 |
+
self.model = []
|
127 |
+
input_dim_temp = input_dim
|
128 |
+
for h in hidden_sizes:
|
129 |
+
self.model.append(nn.Conv1d(in_channels=input_dim_temp, out_channels=h, kernel_size=3, padding=1))
|
130 |
+
self.model.append(activation())
|
131 |
+
self.model.append(nn.MaxPool1d(kernel_size=2))
|
132 |
+
input_dim_temp = h
|
133 |
+
self.model = nn.Sequential(*self.model)
|
134 |
+
self.fc = nn.Linear(in_features=input_dim_temp, out_features=output_dim)
|
135 |
+
|
136 |
+
def forward(self, x):
|
137 |
+
x = self.model(x.transpose(1, 2))
|
138 |
+
x.transpose_(1, 2)
|
139 |
+
x = self.fc(x)
|
140 |
+
return x
|
141 |
+
|
142 |
+
class Net_GRU_Bert_tianshou(Net):
|
143 |
+
def __init__(
|
144 |
+
self,
|
145 |
+
state_shape: Union[int, Sequence[int]],
|
146 |
+
action_shape: Union[int, Sequence[int]] = 0,
|
147 |
+
hidden_sizes: Sequence[int] = (),
|
148 |
+
norm_layer: Optional[ModuleType] = None,
|
149 |
+
activation: Optional[ModuleType] = nn.ReLU,
|
150 |
+
device: Union[str, int, torch.device] = "cpu",
|
151 |
+
softmax: bool = False,
|
152 |
+
concat: bool = False,
|
153 |
+
num_atoms: int = 1,
|
154 |
+
dueling_param: Optional[Tuple[Dict[str, Any], Dict[str, Any]]] = None,
|
155 |
+
linear_layer: Type[nn.Linear] = nn.Linear,
|
156 |
+
hidden_dim: int = 128,
|
157 |
+
bidirectional: bool = True,
|
158 |
+
dropout: float = 0.,
|
159 |
+
n_layers: int = 1,
|
160 |
+
max_length: int = 512,
|
161 |
+
trans_model_name: str = 'bert-base-uncased',
|
162 |
+
) -> None:
|
163 |
+
nn.Module.__init__(self)
|
164 |
+
self.device = device
|
165 |
+
self.softmax = softmax
|
166 |
+
self.num_atoms = num_atoms
|
167 |
+
self.hidden_dim = hidden_dim
|
168 |
+
self.bidirectional = bidirectional
|
169 |
+
self.dropout = dropout
|
170 |
+
self.n_layers = n_layers
|
171 |
+
self.trans_model_name = trans_model_name
|
172 |
+
self.max_length = max_length
|
173 |
+
|
174 |
+
input_dim = int(np.prod(state_shape))
|
175 |
+
action_dim = int(np.prod(action_shape)) * num_atoms
|
176 |
+
if concat:
|
177 |
+
input_dim += action_dim
|
178 |
+
self.use_dueling = dueling_param is not None
|
179 |
+
output_dim = action_dim if not self.use_dueling and not concat else 0
|
180 |
+
self.output_dim = output_dim or hidden_dim
|
181 |
+
self.model = MyGRU(768, self.hidden_dim, self.n_layers,
|
182 |
+
self.dropout, self.bidirectional, self.output_dim)
|
183 |
+
if self.use_dueling: # dueling DQN
|
184 |
+
q_kwargs, v_kwargs = dueling_param # type: ignore
|
185 |
+
q_output_dim, v_output_dim = 0, 0
|
186 |
+
if not concat:
|
187 |
+
q_output_dim, v_output_dim = action_dim, num_atoms
|
188 |
+
q_kwargs: Dict[str, Any] = {
|
189 |
+
**q_kwargs, "input_dim": self.output_dim,
|
190 |
+
"output_dim": q_output_dim,
|
191 |
+
"device": self.device
|
192 |
+
}
|
193 |
+
v_kwargs: Dict[str, Any] = {
|
194 |
+
**v_kwargs, "input_dim": self.output_dim,
|
195 |
+
"output_dim": v_output_dim,
|
196 |
+
"device": self.device
|
197 |
+
}
|
198 |
+
self.Q, self.V = MLP(**q_kwargs), MLP(**v_kwargs)
|
199 |
+
self.output_dim = self.Q.output_dim
|
200 |
+
self.bert_model = AutoModel.from_pretrained(self.trans_model_name).to(self.device)
|
201 |
+
self.tokenizer = AutoTokenizer.from_pretrained(trans_model_name)
|
202 |
+
from transformers import logging
|
203 |
+
logging.set_verbosity_error()
|
204 |
+
|
205 |
+
def bert_embedding(self, x, max_length=512):
|
206 |
+
text = x
|
207 |
+
if isinstance(text, np.ndarray):
|
208 |
+
text = list(text)
|
209 |
+
tokens = self.tokenizer(text, max_length=max_length, padding='max_length', truncation=True, return_tensors='pt')
|
210 |
+
input_ids = tokens['input_ids'].to(self.device)
|
211 |
+
attention_mask = tokens['attention_mask'].to(self.device)
|
212 |
+
with torch.no_grad():
|
213 |
+
outputs = self.bert_model(input_ids, attention_mask=attention_mask)
|
214 |
+
embeddings = outputs.last_hidden_state
|
215 |
+
return embeddings
|
216 |
+
|
217 |
+
def forward(
|
218 |
+
self,
|
219 |
+
obs: Union[np.ndarray, torch.Tensor],
|
220 |
+
state: Any = None,
|
221 |
+
info: Dict[str, Any] = {},
|
222 |
+
) -> Tuple[torch.Tensor, Any]:
|
223 |
+
"""Mapping: obs -> flatten (inside MLP)-> logits."""
|
224 |
+
embedding = self.bert_embedding(obs, max_length=self.max_length)
|
225 |
+
logits = self.model(embedding)
|
226 |
+
bsz = logits.shape[0]
|
227 |
+
if self.use_dueling: # Dueling DQN
|
228 |
+
q, v = self.Q(logits), self.V(logits)
|
229 |
+
if self.num_atoms > 1:
|
230 |
+
q = q.view(bsz, -1, self.num_atoms)
|
231 |
+
v = v.view(bsz, -1, self.num_atoms)
|
232 |
+
logits = q - q.mean(dim=1, keepdim=True) + v
|
233 |
+
elif self.num_atoms > 1:
|
234 |
+
logits = logits.view(bsz, -1, self.num_atoms)
|
235 |
+
if self.softmax:
|
236 |
+
logits = torch.softmax(logits, dim=-1)
|
237 |
+
return logits, state
|
238 |
+
|
239 |
+
class Net_Bert_CLS_tianshou(Net):
|
240 |
+
def __init__(
|
241 |
+
self,
|
242 |
+
state_shape: Union[int, Sequence[int]],
|
243 |
+
action_shape: Union[int, Sequence[int]] = 0,
|
244 |
+
hidden_sizes: Sequence[int] = (),
|
245 |
+
norm_layer: Optional[ModuleType] = None,
|
246 |
+
activation: Optional[ModuleType] = nn.ReLU,
|
247 |
+
device: Union[str, int, torch.device] = "cpu",
|
248 |
+
softmax: bool = False,
|
249 |
+
concat: bool = False,
|
250 |
+
num_atoms: int = 1,
|
251 |
+
dueling_param: Optional[Tuple[Dict[str, Any], Dict[str, Any]]] = None,
|
252 |
+
linear_layer: Type[nn.Linear] = nn.Linear,
|
253 |
+
hidden_dim: int = 128,
|
254 |
+
bidirectional: bool = True,
|
255 |
+
dropout: float = 0.,
|
256 |
+
n_layers: int = 1,
|
257 |
+
max_length: int = 512,
|
258 |
+
trans_model_name: str = 'bert-base-uncased',
|
259 |
+
) -> None:
|
260 |
+
nn.Module.__init__(self)
|
261 |
+
self.device = device
|
262 |
+
self.softmax = softmax
|
263 |
+
self.num_atoms = num_atoms
|
264 |
+
self.hidden_dim = hidden_dim
|
265 |
+
self.bidirectional = bidirectional
|
266 |
+
self.dropout = dropout
|
267 |
+
self.n_layers = n_layers
|
268 |
+
self.trans_model_name = trans_model_name
|
269 |
+
self.max_length = max_length
|
270 |
+
|
271 |
+
input_dim = int(np.prod(state_shape))
|
272 |
+
action_dim = int(np.prod(action_shape)) * num_atoms
|
273 |
+
if concat:
|
274 |
+
input_dim += action_dim
|
275 |
+
self.use_dueling = dueling_param is not None
|
276 |
+
output_dim = action_dim if not self.use_dueling and not concat else 0
|
277 |
+
self.output_dim = output_dim or hidden_dim
|
278 |
+
self.model = MLP(768, output_dim, hidden_sizes, norm_layer, activation, device, linear_layer)
|
279 |
+
if self.use_dueling: # dueling DQN
|
280 |
+
q_kwargs, v_kwargs = dueling_param # type: ignore
|
281 |
+
q_output_dim, v_output_dim = 0, 0
|
282 |
+
if not concat:
|
283 |
+
q_output_dim, v_output_dim = action_dim, num_atoms
|
284 |
+
q_kwargs: Dict[str, Any] = {
|
285 |
+
**q_kwargs, "input_dim": self.output_dim,
|
286 |
+
"output_dim": q_output_dim,
|
287 |
+
"device": self.device
|
288 |
+
}
|
289 |
+
v_kwargs: Dict[str, Any] = {
|
290 |
+
**v_kwargs, "input_dim": self.output_dim,
|
291 |
+
"output_dim": v_output_dim,
|
292 |
+
"device": self.device
|
293 |
+
}
|
294 |
+
self.Q, self.V = MLP(**q_kwargs), MLP(**v_kwargs)
|
295 |
+
self.output_dim = self.Q.output_dim
|
296 |
+
self.bert_model = AutoModel.from_pretrained(self.trans_model_name).to(self.device)
|
297 |
+
self.tokenizer = AutoTokenizer.from_pretrained(trans_model_name)
|
298 |
+
from transformers import logging
|
299 |
+
logging.set_verbosity_error()
|
300 |
+
|
301 |
+
def bert_CLS_embedding(self, x, max_length=512):
|
302 |
+
text = x
|
303 |
+
if isinstance(text, np.ndarray):
|
304 |
+
text = list(text)
|
305 |
+
tokens = self.tokenizer(text, max_length=max_length, padding='max_length', truncation=True, return_tensors='pt')
|
306 |
+
input_ids = tokens['input_ids'].to(self.device)
|
307 |
+
attention_mask = tokens['attention_mask'].to(self.device)
|
308 |
+
with torch.no_grad():
|
309 |
+
outputs = self.bert_model(input_ids, attention_mask=attention_mask)
|
310 |
+
embeddings = outputs[0][:, 0, :]
|
311 |
+
return embeddings
|
312 |
+
|
313 |
+
def forward(
|
314 |
+
self,
|
315 |
+
obs: Union[np.ndarray, torch.Tensor],
|
316 |
+
state: Any = None,
|
317 |
+
info: Dict[str, Any] = {},
|
318 |
+
) -> Tuple[torch.Tensor, Any]:
|
319 |
+
"""Mapping: obs -> flatten (inside MLP)-> logits."""
|
320 |
+
embedding = self.bert_CLS_embedding(obs, max_length=self.max_length)
|
321 |
+
logits = self.model(embedding)
|
322 |
+
bsz = logits.shape[0]
|
323 |
+
if self.use_dueling: # Dueling DQN
|
324 |
+
q, v = self.Q(logits), self.V(logits)
|
325 |
+
if self.num_atoms > 1:
|
326 |
+
q = q.view(bsz, -1, self.num_atoms)
|
327 |
+
v = v.view(bsz, -1, self.num_atoms)
|
328 |
+
logits = q - q.mean(dim=1, keepdim=True) + v
|
329 |
+
elif self.num_atoms > 1:
|
330 |
+
logits = logits.view(bsz, -1, self.num_atoms)
|
331 |
+
if self.softmax:
|
332 |
+
logits = torch.softmax(logits, dim=-1)
|
333 |
+
return logits, state
|
334 |
+
|
335 |
+
|
336 |
+
class Net_Bert_CNN_tianshou(Net_GRU_Bert_tianshou):
|
337 |
+
def __init__(
|
338 |
+
self,
|
339 |
+
state_shape: Union[int, Sequence[int]],
|
340 |
+
action_shape: Union[int, Sequence[int]] = 0,
|
341 |
+
hidden_sizes: Sequence[int] = (),
|
342 |
+
norm_layer: Optional[ModuleType] = None,
|
343 |
+
activation: Optional[ModuleType] = nn.ReLU,
|
344 |
+
device: Union[str, int, torch.device] = "cpu",
|
345 |
+
softmax: bool = False,
|
346 |
+
concat: bool = False,
|
347 |
+
num_atoms: int = 1,
|
348 |
+
dueling_param: Optional[Tuple[Dict[str, Any], Dict[str, Any]]] = None,
|
349 |
+
linear_layer: Type[nn.Linear] = nn.Linear,
|
350 |
+
hidden_dim: int = 128,
|
351 |
+
bidirectional: bool = True,
|
352 |
+
dropout: float = 0.,
|
353 |
+
n_layers: int = 1,
|
354 |
+
max_length: int = 512,
|
355 |
+
trans_model_name: str = 'bert-base-uncased',
|
356 |
+
) -> None:
|
357 |
+
nn.Module.__init__(self)
|
358 |
+
self.device = device
|
359 |
+
self.softmax = softmax
|
360 |
+
self.num_atoms = num_atoms
|
361 |
+
self.hidden_dim = hidden_dim
|
362 |
+
self.bidirectional = bidirectional
|
363 |
+
self.dropout = dropout
|
364 |
+
self.n_layers = n_layers
|
365 |
+
self.trans_model_name = trans_model_name
|
366 |
+
self.max_length = max_length
|
367 |
+
|
368 |
+
input_dim = int(np.prod(state_shape))
|
369 |
+
action_dim = int(np.prod(action_shape)) * num_atoms
|
370 |
+
if concat:
|
371 |
+
input_dim += action_dim
|
372 |
+
self.use_dueling = dueling_param is not None
|
373 |
+
output_dim = action_dim if not self.use_dueling and not concat else 0
|
374 |
+
self.output_dim = output_dim or hidden_dim
|
375 |
+
self.model = MyCNN(768, output_dim, hidden_sizes, norm_layer, activation, device, linear_layer, flatten_input=False)
|
376 |
+
if self.use_dueling: # dueling DQN
|
377 |
+
q_kwargs, v_kwargs = dueling_param # type: ignore
|
378 |
+
q_output_dim, v_output_dim = 0, 0
|
379 |
+
if not concat:
|
380 |
+
q_output_dim, v_output_dim = action_dim, num_atoms
|
381 |
+
q_kwargs: Dict[str, Any] = {
|
382 |
+
**q_kwargs, "input_dim": self.output_dim,
|
383 |
+
"output_dim": q_output_dim,
|
384 |
+
"device": self.device
|
385 |
+
}
|
386 |
+
v_kwargs: Dict[str, Any] = {
|
387 |
+
**v_kwargs, "input_dim": self.output_dim,
|
388 |
+
"output_dim": v_output_dim,
|
389 |
+
"device": self.device
|
390 |
+
}
|
391 |
+
self.Q, self.V = MLP(**q_kwargs), MLP(**v_kwargs)
|
392 |
+
self.output_dim = self.Q.output_dim
|
393 |
+
self.bert_model = AutoModel.from_pretrained(self.trans_model_name).to(self.device)
|
394 |
+
self.tokenizer = AutoTokenizer.from_pretrained(trans_model_name)
|
395 |
+
from transformers import logging
|
396 |
+
logging.set_verbosity_error()
|
397 |
+
|
398 |
+
class DQN_GRU(nn.Module):
|
399 |
+
"""Reference: Human-level control through deep reinforcement learning.
|
400 |
+
"""
|
401 |
+
|
402 |
+
def __init__(
|
403 |
+
self,
|
404 |
+
state_shape: Union[int, Sequence[int]],
|
405 |
+
action_shape: Sequence[int],
|
406 |
+
device: Union[str, int, torch.device] = "cpu",
|
407 |
+
features_only: bool = False,
|
408 |
+
output_dim: Optional[int] = None,
|
409 |
+
hidden_dim: int = 128,
|
410 |
+
n_layers: int = 1,
|
411 |
+
dropout: float = 0.,
|
412 |
+
bidirectional: bool = True,
|
413 |
+
trans_model_name: str = 'bert-base-uncased',
|
414 |
+
max_length: int = 512,
|
415 |
+
) -> None:
|
416 |
+
super().__init__()
|
417 |
+
self.device = device
|
418 |
+
self.max_length = max_length
|
419 |
+
action_dim = int(np.prod(action_shape))
|
420 |
+
self.net = MyGRU(768, hidden_dim, n_layers, dropout, bidirectional,
|
421 |
+
hidden_dim)
|
422 |
+
if not features_only:
|
423 |
+
self.net = MyGRU(768, hidden_dim, n_layers, dropout, bidirectional,
|
424 |
+
action_dim)
|
425 |
+
self.output_dim = action_dim
|
426 |
+
elif output_dim is not None:
|
427 |
+
self.net = MyGRU(768, hidden_dim, n_layers, dropout, bidirectional,
|
428 |
+
output_dim)
|
429 |
+
self.output_dim = output_dim
|
430 |
+
else:
|
431 |
+
self.net = MyGRU(768, hidden_dim, n_layers, dropout, bidirectional,
|
432 |
+
hidden_dim)
|
433 |
+
self.output_dim = hidden_dim
|
434 |
+
self.trans_model_name = trans_model_name
|
435 |
+
self.bert_model = AutoModel.from_pretrained(self.trans_model_name).to(self.device)
|
436 |
+
self.tokenizer = AutoTokenizer.from_pretrained(trans_model_name)
|
437 |
+
from transformers import logging
|
438 |
+
logging.set_verbosity_error()
|
439 |
+
|
440 |
+
def bert_embedding(self, x, max_length=512):
|
441 |
+
text = x
|
442 |
+
if isinstance(text, np.ndarray):
|
443 |
+
text = list(text)
|
444 |
+
tokens = self.tokenizer(text, max_length=max_length, padding='max_length', truncation=True, return_tensors='pt')
|
445 |
+
input_ids = tokens['input_ids'].to(self.device)
|
446 |
+
attention_mask = tokens['attention_mask'].to(self.device)
|
447 |
+
with torch.no_grad():
|
448 |
+
outputs = self.bert_model(input_ids, attention_mask=attention_mask)
|
449 |
+
embeddings = outputs.last_hidden_state
|
450 |
+
return embeddings
|
451 |
+
|
452 |
+
def forward(
|
453 |
+
self,
|
454 |
+
obs: Union[np.ndarray, torch.Tensor],
|
455 |
+
state: Optional[Any] = None,
|
456 |
+
info: Dict[str, Any] = {},
|
457 |
+
) -> Tuple[torch.Tensor, Any]:
|
458 |
+
r"""Mapping: s -> Q(s, \*)."""
|
459 |
+
embedding = self.bert_embedding(obs, max_length=self.max_length)
|
460 |
+
return self.net(embedding), state
|
461 |
+
|
462 |
+
class Rainbow_GRU(DQN_GRU):
|
463 |
+
"""Reference: Rainbow: Combining Improvements in Deep Reinforcement Learning.
|
464 |
+
"""
|
465 |
+
|
466 |
+
def __init__(
|
467 |
+
self,
|
468 |
+
state_shape: Union[int, Sequence[int]],
|
469 |
+
action_shape: Sequence[int],
|
470 |
+
num_atoms: int = 51,
|
471 |
+
noisy_std: float = 0.5,
|
472 |
+
device: Union[str, int, torch.device] = "cpu",
|
473 |
+
is_dueling: bool = True,
|
474 |
+
is_noisy: bool = True,
|
475 |
+
output_dim: Optional[int] = None,
|
476 |
+
hidden_dim: int = 128,
|
477 |
+
n_layers: int = 1,
|
478 |
+
dropout: float = 0.,
|
479 |
+
bidirectional: bool = True,
|
480 |
+
trans_model_name: str = 'bert-base-uncased',
|
481 |
+
max_length: int = 512,
|
482 |
+
) -> None:
|
483 |
+
super().__init__(state_shape, action_shape, device, features_only=True,
|
484 |
+
output_dim=output_dim, hidden_dim=hidden_dim, n_layers=n_layers,
|
485 |
+
dropout=dropout, bidirectional=bidirectional, trans_model_name=trans_model_name)
|
486 |
+
self.action_num = np.prod(action_shape)
|
487 |
+
self.num_atoms = num_atoms
|
488 |
+
|
489 |
+
def linear(x, y):
|
490 |
+
if is_noisy:
|
491 |
+
return NoisyLinear(x, y, noisy_std)
|
492 |
+
else:
|
493 |
+
return nn.Linear(x, y)
|
494 |
+
|
495 |
+
self.Q = nn.Sequential(
|
496 |
+
linear(self.output_dim, 512), nn.ReLU(inplace=True),
|
497 |
+
linear(512, self.action_num * self.num_atoms)
|
498 |
+
)
|
499 |
+
self._is_dueling = is_dueling
|
500 |
+
if self._is_dueling:
|
501 |
+
self.V = nn.Sequential(
|
502 |
+
linear(self.output_dim, 512), nn.ReLU(inplace=True),
|
503 |
+
linear(512, self.num_atoms)
|
504 |
+
)
|
505 |
+
self.output_dim = self.action_num * self.num_atoms
|
506 |
+
|
507 |
+
def forward(
|
508 |
+
self,
|
509 |
+
obs: Union[np.ndarray, torch.Tensor],
|
510 |
+
state: Optional[Any] = None,
|
511 |
+
info: Dict[str, Any] = {},
|
512 |
+
) -> Tuple[torch.Tensor, Any]:
|
513 |
+
r"""Mapping: x -> Z(x, \*)."""
|
514 |
+
obs, state = super().forward(obs)
|
515 |
+
q = self.Q(obs)
|
516 |
+
q = q.view(-1, self.action_num, self.num_atoms)
|
517 |
+
if self._is_dueling:
|
518 |
+
v = self.V(obs)
|
519 |
+
v = v.view(-1, 1, self.num_atoms)
|
520 |
+
logits = q - q.mean(dim=1, keepdim=True) + v
|
521 |
+
else:
|
522 |
+
logits = q
|
523 |
+
probs = logits.softmax(dim=2)
|
524 |
+
return probs, state
|
525 |
+
|
526 |
+
class Net_GRU_nn_emb_tianshou(Net):
|
527 |
+
|
528 |
+
def __init__(
|
529 |
+
self,
|
530 |
+
action_shape: Union[int, Sequence[int]] = 0,
|
531 |
+
hidden_sizes: Sequence[int] = (),
|
532 |
+
norm_layer: Optional[ModuleType] = None,
|
533 |
+
activation: Optional[ModuleType] = nn.ReLU,
|
534 |
+
device: Union[str, int, torch.device] = "cpu",
|
535 |
+
softmax: bool = False,
|
536 |
+
concat: bool = False,
|
537 |
+
num_atoms: int = 1,
|
538 |
+
dueling_param: Optional[Tuple[Dict[str, Any], Dict[str, Any]]] = None,
|
539 |
+
linear_layer: Type[nn.Linear] = nn.Linear,
|
540 |
+
hidden_dim: int = 128,
|
541 |
+
bidirectional: bool = True,
|
542 |
+
dropout: float = 0.,
|
543 |
+
n_layers: int = 1,
|
544 |
+
max_length: int = 512,
|
545 |
+
trans_model_name: str = 'bert-base-uncased',
|
546 |
+
word_emb_dim: int = 128,
|
547 |
+
) -> None:
|
548 |
+
nn.Module.__init__(self)
|
549 |
+
self.device = device
|
550 |
+
self.softmax = softmax
|
551 |
+
self.num_atoms = num_atoms
|
552 |
+
self.hidden_dim = hidden_dim
|
553 |
+
self.bidirectional = bidirectional
|
554 |
+
self.dropout = dropout
|
555 |
+
self.n_layers = n_layers
|
556 |
+
self.trans_model_name = trans_model_name
|
557 |
+
self.max_length = max_length
|
558 |
+
|
559 |
+
action_dim = int(np.prod(action_shape)) * num_atoms
|
560 |
+
self.use_dueling = dueling_param is not None
|
561 |
+
output_dim = action_dim if not self.use_dueling and not concat else 0
|
562 |
+
self.output_dim = output_dim or hidden_dim
|
563 |
+
|
564 |
+
self.tokenizer = AutoTokenizer.from_pretrained(trans_model_name)
|
565 |
+
from transformers import logging
|
566 |
+
logging.set_verbosity_error()
|
567 |
+
self.vocab_size = self.tokenizer.vocab_size
|
568 |
+
self.embedding = nn.Embedding(self.vocab_size, word_emb_dim)
|
569 |
+
self.model = MyGRU(word_emb_dim, self.hidden_dim, self.n_layers,
|
570 |
+
self.dropout, self.bidirectional, self.output_dim)
|
571 |
+
if self.use_dueling: # dueling DQN
|
572 |
+
q_kwargs, v_kwargs = dueling_param # type: ignore
|
573 |
+
q_output_dim, v_output_dim = 0, 0
|
574 |
+
if not concat:
|
575 |
+
q_output_dim, v_output_dim = action_dim, num_atoms
|
576 |
+
q_kwargs: Dict[str, Any] = {
|
577 |
+
**q_kwargs, "input_dim": self.output_dim,
|
578 |
+
"output_dim": q_output_dim,
|
579 |
+
"device": self.device
|
580 |
+
}
|
581 |
+
v_kwargs: Dict[str, Any] = {
|
582 |
+
**v_kwargs, "input_dim": self.output_dim,
|
583 |
+
"output_dim": v_output_dim,
|
584 |
+
"device": self.device
|
585 |
+
}
|
586 |
+
self.Q, self.V = MLP(**q_kwargs), MLP(**v_kwargs)
|
587 |
+
self.output_dim = self.Q.output_dim
|
588 |
+
|
589 |
+
|
590 |
+
def forward(
|
591 |
+
self,
|
592 |
+
obs: Union[np.ndarray, torch.Tensor],
|
593 |
+
state: Any = None,
|
594 |
+
info: Dict[str, Any] = {},
|
595 |
+
) -> Tuple[torch.Tensor, Any]:
|
596 |
+
"""Mapping: obs -> flatten (inside MLP)-> logits."""
|
597 |
+
if isinstance(obs, np.ndarray):
|
598 |
+
text = list(obs)
|
599 |
+
else:
|
600 |
+
text = obs
|
601 |
+
tokens = self.tokenizer(text, max_length=self.max_length, padding='max_length', truncation=True, return_tensors='pt')
|
602 |
+
input_ids = tokens['input_ids'].to(self.device)
|
603 |
+
attention_mask = tokens['attention_mask'].to(self.device)
|
604 |
+
embedding = self.embedding(input_ids)
|
605 |
+
mask = attention_mask.unsqueeze(-1).expand(embedding.size()).float()
|
606 |
+
embedding = embedding * mask
|
607 |
+
logits = self.model(embedding)
|
608 |
+
bsz = logits.shape[0]
|
609 |
+
if self.use_dueling: # Dueling DQN
|
610 |
+
q, v = self.Q(logits), self.V(logits)
|
611 |
+
if self.num_atoms > 1:
|
612 |
+
q = q.view(bsz, -1, self.num_atoms)
|
613 |
+
v = v.view(bsz, -1, self.num_atoms)
|
614 |
+
logits = q - q.mean(dim=1, keepdim=True) + v
|
615 |
+
elif self.num_atoms > 1:
|
616 |
+
logits = logits.view(bsz, -1, self.num_atoms)
|
617 |
+
if self.softmax:
|
618 |
+
logits = torch.softmax(logits, dim=-1)
|
619 |
+
return logits, state
|
620 |
+
|
621 |
+
|
deciders/__init__.py
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
from .act import NaiveAct, RandomAct
|
3 |
+
from .selfask import SelfAskAct
|
4 |
+
from .pal import PAL
|
5 |
+
from .cot import ChainOfThought
|
6 |
+
from .self_consistency import SelfConsistency
|
7 |
+
from .spp import SPP
|
8 |
+
from .reflexion import Reflexion
|
9 |
+
from .jarvis import Jarvis
|
10 |
+
from .jarvis_without_insights import JarvisWithoutInsight
|
11 |
+
from .jarvis_without_suggestions import JarvisWithoutSuggestions
|
12 |
+
from .jarvis_without_shortmem import JarvisWithoutShortMem
|
13 |
+
|
14 |
+
REGISTRY = {}
|
15 |
+
REGISTRY['random_actor'] = RandomAct
|
16 |
+
REGISTRY['naive_actor'] = NaiveAct
|
17 |
+
REGISTRY['selfask_actor'] = SelfAskAct
|
18 |
+
REGISTRY['pal_actor'] = PAL
|
19 |
+
REGISTRY['cot_actor'] = ChainOfThought
|
20 |
+
REGISTRY['self_consistency_actor'] = SelfConsistency
|
21 |
+
REGISTRY['spp_actor'] = SPP
|
22 |
+
REGISTRY['reflexion_actor'] = Reflexion
|
23 |
+
REGISTRY['jarvis_actor'] = Jarvis
|
24 |
+
REGISTRY['jarvis_actor_woi'] = JarvisWithoutInsight
|
25 |
+
REGISTRY['jarvis_actor_wosug'] = JarvisWithoutSuggestions
|
26 |
+
REGISTRY['jarvis_actor_wosh'] = JarvisWithoutShortMem
|
deciders/act.py
ADDED
@@ -0,0 +1,248 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# This file contains functions for interacting with the ChatGPT model
|
2 |
+
|
3 |
+
import openai
|
4 |
+
from .gpt import gpt
|
5 |
+
from loguru import logger
|
6 |
+
from .parser import PARSERS
|
7 |
+
from langchain.output_parsers import PydanticOutputParser
|
8 |
+
from langchain.output_parsers import OutputFixingParser
|
9 |
+
from langchain.chat_models import AzureChatOpenAI, ChatOpenAI
|
10 |
+
from memory.env_history import EnvironmentHistory
|
11 |
+
import tiktoken
|
12 |
+
import json
|
13 |
+
import re
|
14 |
+
from .utils import run_chain
|
15 |
+
|
16 |
+
class RandomAct():
|
17 |
+
def __init__(self, action_space):
|
18 |
+
self.action_space = action_space
|
19 |
+
|
20 |
+
def act(self, state_description, action_description, env_info, game_description=None, goal_description=None):
|
21 |
+
return self.action_space.sample()+1, '', '', '', 0, 0
|
22 |
+
|
23 |
+
class NaiveAct(gpt):
|
24 |
+
def __init__(self, action_space, args, prompts, distiller, temperature=0.0, max_tokens=512, logger=None):
|
25 |
+
self.action_space = action_space
|
26 |
+
self.temperature = temperature
|
27 |
+
self.action_desc_dict = args.action_desc_dict
|
28 |
+
self.args = args
|
29 |
+
self.prompts = prompts
|
30 |
+
self.max_tokens = max_tokens
|
31 |
+
self.prompt_level = args.prompt_level
|
32 |
+
if args.gpt_version == "gpt-35-turbo":
|
33 |
+
model = "gpt-3.5-turbo"
|
34 |
+
else:
|
35 |
+
model = args.gpt_version
|
36 |
+
self.encoding = tiktoken.encoding_for_model(model)
|
37 |
+
super().__init__()
|
38 |
+
self.distiller = distiller
|
39 |
+
self.fewshot_example_initialization(args.prompt_level, args.prompt_path, distiller = self.distiller)
|
40 |
+
self.default_action = 1
|
41 |
+
self.parser = self._parser_initialization()
|
42 |
+
self.irr_game_description = ''
|
43 |
+
self.memory = []
|
44 |
+
self.env_history = EnvironmentHistory()
|
45 |
+
self.first_call = True
|
46 |
+
self.logger = logger
|
47 |
+
if self.prompt_level in [2, 4]:
|
48 |
+
self.memory = self.summarized_fewshot_example
|
49 |
+
if args.use_short_mem == 1:
|
50 |
+
self.use_short_mem = True
|
51 |
+
self.mem_num = self.args.trajectories_num
|
52 |
+
else:
|
53 |
+
self.use_short_mem = False
|
54 |
+
self.mem_num = 0
|
55 |
+
|
56 |
+
def num_tokens_from_string(self,string: str) -> int:
|
57 |
+
"""Returns the number of tokens in a text string."""
|
58 |
+
num_tokens = len(self.encoding.encode(string))
|
59 |
+
return num_tokens
|
60 |
+
|
61 |
+
def update_mem(self,):
|
62 |
+
traj = "Firstly, the description and the goal of the task will be provided. Please pay close attention to comprehend the information presented below.\n"
|
63 |
+
traj += "Task Description: " + self.game_description + '\n'
|
64 |
+
traj += "Goal Description: " + self.goal_description + '\n'
|
65 |
+
traj += self.action_description
|
66 |
+
traj += "Below is the historical data for this round of the game, which includes the state and corresponding action for each step.\n"
|
67 |
+
traj += str(self.env_history)
|
68 |
+
# print(traj)
|
69 |
+
self._update_mem(traj)
|
70 |
+
|
71 |
+
def _update_mem(self, traj):
|
72 |
+
my_reflection = self.distiller.generate(traj, self.memory)
|
73 |
+
self.memory.append(my_reflection)
|
74 |
+
self.env_history.reset()
|
75 |
+
|
76 |
+
def clear_mem(self):
|
77 |
+
self.pre_memory = []
|
78 |
+
self.post_memory = []
|
79 |
+
self.is_first = True
|
80 |
+
self._update_mem(None)
|
81 |
+
|
82 |
+
|
83 |
+
def _parser_initialization(self):
|
84 |
+
if hasattr(self.action_space, 'n'):
|
85 |
+
assert self.action_space.n in PARSERS.keys(), f'Action space {self.action_space} is not supported.'
|
86 |
+
num_action = self.action_space.n
|
87 |
+
else:
|
88 |
+
num_action = 1
|
89 |
+
|
90 |
+
# autofixing_chat = AzureChatOpenAI(
|
91 |
+
# openai_api_type=openai.api_type,
|
92 |
+
# openai_api_version=openai.api_version,
|
93 |
+
# openai_api_base=openai.api_base,
|
94 |
+
# openai_api_key=openai.api_key,
|
95 |
+
# deployment_name="gpt-35-turbo",
|
96 |
+
# temperature=self.temperature,
|
97 |
+
# max_tokens=self.max_tokens
|
98 |
+
# )
|
99 |
+
autofixing_chat = ChatOpenAI(temperature=0, openai_api_key=openai.api_key)
|
100 |
+
|
101 |
+
parser = PydanticOutputParser(pydantic_object=PARSERS[num_action])
|
102 |
+
autofixing_parser = OutputFixingParser.from_llm(
|
103 |
+
llm=autofixing_chat, parser=parser)
|
104 |
+
|
105 |
+
return autofixing_parser
|
106 |
+
|
107 |
+
def fewshot_example_initialization(self, level, path=None, distiller=None):
|
108 |
+
self.fewshot_example = []
|
109 |
+
self.irr_few_shot_examples = []
|
110 |
+
self.prompt_level = level
|
111 |
+
self.expert_knowledge = None
|
112 |
+
if level in [1,3]:
|
113 |
+
self.irr_few_shot_examples = self.prompts.TASK_IRRELEVANT_PROMPTS
|
114 |
+
elif level == 5:
|
115 |
+
if hasattr(self.prompts, "expert_prompt"):
|
116 |
+
self.expert_knowledge = self.prompts.expert_prompt
|
117 |
+
self.fewshot_example = self.prompts.PERCEPTRON_BASIC_FS_EXAMPLES
|
118 |
+
else:
|
119 |
+
self.irr_few_shot_examples = self.prompts.TASK_IRRELEVANT_PROMPTS
|
120 |
+
json_file = f'{path}_l{level}.json'
|
121 |
+
with open(json_file, 'r') as infile:
|
122 |
+
data = json.load(infile)
|
123 |
+
max_step_num = 0
|
124 |
+
for traj in data:
|
125 |
+
traj_text = traj[0]['game_description']
|
126 |
+
traj_text += traj[0]['goal_description']
|
127 |
+
for i, transition in enumerate(traj):
|
128 |
+
traj_text += transition['observation']
|
129 |
+
traj_text += f"> {transition['action']}"
|
130 |
+
one_traj_token = self.num_tokens_from_string(traj_text)
|
131 |
+
if one_traj_token > 5000:
|
132 |
+
max_step_num = i+1
|
133 |
+
break
|
134 |
+
traj_text += f"Your performance is: {transition['cum_reward']}"
|
135 |
+
if not max_step_num:
|
136 |
+
max_step_num = 200
|
137 |
+
self.summarized_fewshot_example = self.distiller.generate_from_file(json_file,max_step_num=max_step_num)
|
138 |
+
|
139 |
+
def response(self, state_description, action_description, env_info, game_description=None, goal_description=None, fewshot_examples=None):
|
140 |
+
if env_info['future_summary']:
|
141 |
+
prompt = f"{game_description}\n{goal_description}\n{fewshot_examples}\n{state_description}\n{env_info['future_summary']}\n{action_description} "
|
142 |
+
else:
|
143 |
+
prompt = f"{game_description}\n{goal_description}\n{fewshot_examples}\nCurrent {state_description}\n{action_description} "
|
144 |
+
prompt += "Please select an action based on the current game state and the information you get. You must select the appropriate action from the given action descriptions and cannot refrain from taking action or performing any prohibited actions. Your Action is: "
|
145 |
+
print(f"prompt is {prompt}")
|
146 |
+
res = openai.Completion.create(
|
147 |
+
engine=self.args.gpt_version,
|
148 |
+
prompt=prompt,
|
149 |
+
temperature=self.temperature,
|
150 |
+
max_tokens=self.max_tokens,
|
151 |
+
)
|
152 |
+
return prompt, res
|
153 |
+
|
154 |
+
def _add_history_before_action(self, game_description, goal_description, state_description):
|
155 |
+
self.game_description = game_description
|
156 |
+
self.goal_description = goal_description
|
157 |
+
self.env_history.add("observation", state_description)
|
158 |
+
# print(self.env_history)
|
159 |
+
if len(self.env_history) >= 2:
|
160 |
+
one_history_token = self.num_tokens_from_string(self.env_history.get_one_history())
|
161 |
+
self.env_history.set_history(6000 // one_history_token)
|
162 |
+
|
163 |
+
def act(self, state_description, action_description, env_info, game_description=None, goal_description=None, logfile=None):
|
164 |
+
self._add_history_before_action(game_description, goal_description, state_description)
|
165 |
+
asking_round = 0
|
166 |
+
res = None
|
167 |
+
action = None
|
168 |
+
prompt = None
|
169 |
+
if not self.logger:
|
170 |
+
logger.remove()
|
171 |
+
self.logger = logger.add(logfile, colorize=True, enqueue=True)
|
172 |
+
|
173 |
+
if self.args.prompt_level == 5:
|
174 |
+
my_mem = ""
|
175 |
+
if self.fewshot_example:
|
176 |
+
my_mem += "Here are some examples of how you should complete a task."
|
177 |
+
for examples in self.fewshot_example:
|
178 |
+
my_mem += "\nQuestion: \n" + examples['question'] + "Answer: \n" + examples['answer']
|
179 |
+
my_mem += '\nNow you are in the task.\n'
|
180 |
+
elif self.args.prompt_level in [2,3,4]:
|
181 |
+
my_mem = ""
|
182 |
+
if self.prompt_level == 2:
|
183 |
+
my_mem += 'I have collected a few trajectories from a random policy, and the summaries are listed below.'
|
184 |
+
elif self.prompt_level == 3:
|
185 |
+
my_mem += 'I have collected a few trajectories before, and the summaries are listed below.'
|
186 |
+
elif self.prompt_level == 4:
|
187 |
+
my_mem += 'I have collected a few trajectories from an expert policy, and the summaries are listed below.'
|
188 |
+
my_mem += self._read_mem()
|
189 |
+
else:
|
190 |
+
my_mem = ""
|
191 |
+
|
192 |
+
if self.use_short_mem:
|
193 |
+
if len(self.env_history) > 1:
|
194 |
+
my_mem += '\nSubsequently, I will offer pertinent guidance or information about the task. Please utilize this instruction to accomplish the given task effectively.'
|
195 |
+
my_mem += f"\nBelow are the latest {min(self.args.short_mem_num,len(self.env_history)//2)} historical data entries:\n"
|
196 |
+
my_mem += f"{self.env_history.get_histories(self.mem_num)}"
|
197 |
+
|
198 |
+
while asking_round < 3:
|
199 |
+
prompt, res = self.response(state_description, action_description, env_info, game_description, goal_description, my_mem)
|
200 |
+
action_str = res.choices[0].text.strip()
|
201 |
+
print(f'my anwser is {action_str}')
|
202 |
+
# import pdb; pdb.set_trace()
|
203 |
+
try:
|
204 |
+
if "Continuous" in self.args.env_name:
|
205 |
+
action = float(re.findall(r"[-+]?\d*\.\d+", action_str)[0])
|
206 |
+
|
207 |
+
else:
|
208 |
+
action = int(re.findall(r"\d+", action_str)[0])
|
209 |
+
except:
|
210 |
+
action = None
|
211 |
+
asking_round += 1
|
212 |
+
continue
|
213 |
+
|
214 |
+
if "Continuous" not in self.args.env_name:
|
215 |
+
if (action-1) in self.action_space:
|
216 |
+
break
|
217 |
+
else:
|
218 |
+
asking_round += 1
|
219 |
+
action = None
|
220 |
+
else:
|
221 |
+
if action >= self.action_space.low and action <= self.action_space.high:
|
222 |
+
break
|
223 |
+
else:
|
224 |
+
asking_round += 1
|
225 |
+
action = None
|
226 |
+
|
227 |
+
if action is None:
|
228 |
+
print('err on selecting action')
|
229 |
+
action = self.default_action
|
230 |
+
self._add_history_after_action(action)
|
231 |
+
self.logger.info(f'\n{prompt}')
|
232 |
+
self.logger.info(f'The GPT response is: {res}.')
|
233 |
+
self.logger.info(f'The optimal action is: {action}.')
|
234 |
+
return action, prompt, res, 0, 0
|
235 |
+
|
236 |
+
def _read_mem(self, ):
|
237 |
+
memory = self.memory
|
238 |
+
mem_str = ""
|
239 |
+
if len(memory) > 5:
|
240 |
+
memory = memory[-5:]
|
241 |
+
if len(memory) > 0:
|
242 |
+
mem_str += '\nYour memory for the task below:'
|
243 |
+
for i, m in enumerate(memory):
|
244 |
+
mem_str += f'\nTrial {i}:\n{m.strip()}'
|
245 |
+
return mem_str
|
246 |
+
|
247 |
+
def _add_history_after_action(self, action):
|
248 |
+
self.env_history.add('action', action)
|
deciders/cot.py
ADDED
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import openai
|
2 |
+
from .misc import history_to_str
|
3 |
+
from langchain.chat_models import AzureChatOpenAI
|
4 |
+
from langchain.prompts.chat import (
|
5 |
+
PromptTemplate,
|
6 |
+
ChatPromptTemplate,
|
7 |
+
SystemMessagePromptTemplate,
|
8 |
+
HumanMessagePromptTemplate,
|
9 |
+
)
|
10 |
+
from langchain.prompts.few_shot import FewShotPromptTemplate
|
11 |
+
from langchain import LLMChain
|
12 |
+
from loguru import logger
|
13 |
+
from langchain.callbacks import FileCallbackHandler
|
14 |
+
from langchain.callbacks import get_openai_callback
|
15 |
+
from .act import NaiveAct
|
16 |
+
from .utils import run_chain
|
17 |
+
|
18 |
+
|
19 |
+
class ChainOfThought(NaiveAct):
|
20 |
+
def __init__(self, action_space, args, prompts, distiller, temperature=0.1, max_tokens=None, logger=None):
|
21 |
+
super().__init__(action_space, args, prompts, distiller, temperature, max_tokens,logger)
|
22 |
+
|
23 |
+
def act(
|
24 |
+
self,
|
25 |
+
state_description,
|
26 |
+
action_description,
|
27 |
+
env_info,
|
28 |
+
game_description,
|
29 |
+
goal_description,
|
30 |
+
logfile=None,
|
31 |
+
):
|
32 |
+
self.action_description = action_description
|
33 |
+
self._add_history_before_action(game_description, goal_description, state_description)
|
34 |
+
chat = AzureChatOpenAI(
|
35 |
+
openai_api_type=openai.api_type,
|
36 |
+
openai_api_version=openai.api_version,
|
37 |
+
openai_api_base=openai.api_base,
|
38 |
+
openai_api_key=openai.api_key,
|
39 |
+
deployment_name=self.args.gpt_version,
|
40 |
+
temperature=self.temperature,
|
41 |
+
max_tokens=self.max_tokens
|
42 |
+
)
|
43 |
+
|
44 |
+
suffix_flag = False
|
45 |
+
reply_format_description = \
|
46 |
+
"Your response should choose an optimal action from a valid action list and terminate with the following format: "
|
47 |
+
|
48 |
+
# System Message
|
49 |
+
human_template = "Now, you are completing a challenging task. You must carefully understand the Chain-of-Thought method you will use and apply it to the following task.\n"
|
50 |
+
|
51 |
+
# task-irrelevant SystemMessage
|
52 |
+
if self.irr_few_shot_examples:
|
53 |
+
human_template += 'In the following example, I shall present a set of question and answer with the Chain-of-Thought method. Please adhere to the format and reasoning of the provided response when addressing the subsequent task.\n'
|
54 |
+
for i, examples in enumerate(self.irr_few_shot_examples):
|
55 |
+
human_template += f"\nExample {i+1}:\n"
|
56 |
+
human_template += "Question: \n" + examples['question'] + "\nAnswer: \n" + examples['answer']
|
57 |
+
|
58 |
+
# task-irrelevant few shot if have
|
59 |
+
if self.irr_few_shot_examples:
|
60 |
+
human_template += "\nMoving forward, I will describe the task, the goal, and the actions you may execute. Please pay close attention to comprehend the information presented below.\n"
|
61 |
+
|
62 |
+
if self.fewshot_example:
|
63 |
+
human_template += "I will describe the task, the goal, and the actions you may execute. Please pay close attention to comprehend the information presented below."
|
64 |
+
# print(fewshot_example_prompt.format(**fewshot_examples[0]))
|
65 |
+
human_template += '\nTask Description: {game_description} \n'
|
66 |
+
human_template += 'Goal Description: {goal_description}\n'
|
67 |
+
human_template += 'Actions Description: {action_description}\n'
|
68 |
+
|
69 |
+
if self.fewshot_example:
|
70 |
+
human_template += "Here, I will provide you with some guidance to help you better understand the rules of the task. Next are some examples: "
|
71 |
+
for i, examples in enumerate(self.fewshot_example):
|
72 |
+
human_template += f"\nExample {i+1}:\n"
|
73 |
+
human_template += "Question: \n" + examples['question'] + "\nAnswer: \n" + examples['answer']
|
74 |
+
|
75 |
+
if self.prompt_level in [2, 3, 4]:
|
76 |
+
if self.memory:
|
77 |
+
human_template += '\nSubsequently, I will offer pertinent guidance or information about the task. Please utilize this instruction to accomplish the given task effectively.\n'
|
78 |
+
suffix_flag = True
|
79 |
+
if self.prompt_level == 2:
|
80 |
+
human_template += 'I have collected a few trajectories from a random policy, and the summaries are listed below.'
|
81 |
+
elif self.prompt_level == 3:
|
82 |
+
human_template += 'I have collected a few trajectories before, and the summaries are listed below.'
|
83 |
+
elif self.prompt_level == 4:
|
84 |
+
human_template += 'I have collected a few trajectories from an expert policy, and the summaries are listed below.'
|
85 |
+
human_template += self._read_mem() + "\n"
|
86 |
+
|
87 |
+
if self.use_short_mem:
|
88 |
+
if len(self.env_history) > 1:
|
89 |
+
if not suffix_flag:
|
90 |
+
human_template += '\nSubsequently, I will offer pertinent guidance or information about the task. Please utilize this instruction to accomplish the given task effectively.'
|
91 |
+
human_template += f"\nBelow are the latest {self.args.short_mem_num} historical data entries:\n"
|
92 |
+
human_template += f"{self.env_history.get_histories(self.mem_num)}"
|
93 |
+
human_template += '\nNext is the observation that the agent gets:\nCurrent {state_description}\n'
|
94 |
+
human_template += 'Please select an action based on the current game state and the information you get. You must select the appropriate action from the given action descriptions and cannot refrain from taking action or performing any prohibited actions. Here is the action description below:\n{action_description}\n'
|
95 |
+
human_template += 'Please note that you need to carefully lay out your thought process on the question, not just give an answer. You need to write the corresponding logic of your thinking following the example above. Also, please keep in mind not to answer with any redundant and irrelevant content.\n'
|
96 |
+
human_template += "Finally, you also need to normalize your output according to the reply format description.\n"
|
97 |
+
human_template += 'Reply format description: {reply_format_description}{format_instructions}\n'
|
98 |
+
|
99 |
+
human_message_prompt = PromptTemplate(
|
100 |
+
template=human_template,
|
101 |
+
input_variables=[
|
102 |
+
'state_description', 'goal_description', 'game_description',
|
103 |
+
'action_description', 'reply_format_description'],
|
104 |
+
partial_variables={'format_instructions': self.parser.get_format_instructions()}
|
105 |
+
)
|
106 |
+
|
107 |
+
human_message_prompt = HumanMessagePromptTemplate(prompt=human_message_prompt)
|
108 |
+
|
109 |
+
chat_prompt = ChatPromptTemplate.from_messages([human_message_prompt])
|
110 |
+
|
111 |
+
if not self.logger:
|
112 |
+
logger.remove()
|
113 |
+
self.logger = logger.add(logfile, colorize=True, enqueue=True)
|
114 |
+
handler = FileCallbackHandler(logfile)
|
115 |
+
|
116 |
+
chain = LLMChain(llm=chat, prompt=chat_prompt, callbacks=[handler], verbose=False)
|
117 |
+
|
118 |
+
text_prompt = chat_prompt.format_messages(
|
119 |
+
game_description=game_description,
|
120 |
+
state_description=state_description,
|
121 |
+
goal_description=goal_description,
|
122 |
+
action_description=action_description,
|
123 |
+
reply_format_description=reply_format_description
|
124 |
+
)
|
125 |
+
texts = ""
|
126 |
+
for text in text_prompt:
|
127 |
+
texts += text.content + "\n"
|
128 |
+
|
129 |
+
with get_openai_callback() as cb:
|
130 |
+
response = run_chain(
|
131 |
+
chain,
|
132 |
+
game_description=game_description,
|
133 |
+
state_description=state_description,
|
134 |
+
goal_description=goal_description,
|
135 |
+
action_description=action_description,
|
136 |
+
reply_format_description=reply_format_description
|
137 |
+
)
|
138 |
+
total_tokens = cb.total_tokens
|
139 |
+
total_cost = cb.total_cost
|
140 |
+
action = self.parser.parse(response).action
|
141 |
+
self._add_history_after_action(action)
|
142 |
+
self.logger.info(f'The GPT response is: {response}.')
|
143 |
+
self.logger.info(f'The optimal action is: {action}.')
|
144 |
+
if env_info.get('history'):
|
145 |
+
self.logger.info(f'History: {history_to_str(env_info["history"])}')
|
146 |
+
|
147 |
+
return action, texts, response, total_tokens, total_cost
|
deciders/jarvis.py
ADDED
@@ -0,0 +1,177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import openai
|
2 |
+
from .misc import history_to_str
|
3 |
+
from langchain.chat_models import AzureChatOpenAI, ChatOpenAI
|
4 |
+
from langchain.prompts.chat import (
|
5 |
+
PromptTemplate,
|
6 |
+
ChatPromptTemplate,
|
7 |
+
SystemMessagePromptTemplate,
|
8 |
+
HumanMessagePromptTemplate,
|
9 |
+
)
|
10 |
+
from langchain.prompts.few_shot import FewShotPromptTemplate
|
11 |
+
from langchain import LLMChain
|
12 |
+
from langchain.callbacks import FileCallbackHandler
|
13 |
+
from langchain.callbacks import get_openai_callback
|
14 |
+
from .act import NaiveAct
|
15 |
+
from memory.env_history import EnvironmentHistory
|
16 |
+
import tiktoken
|
17 |
+
from .utils import run_chain
|
18 |
+
from loguru import logger
|
19 |
+
|
20 |
+
|
21 |
+
|
22 |
+
class Jarvis(NaiveAct):
|
23 |
+
def __init__(self, action_space, args, prompts, distiller, temperature=0., max_tokens=None, logger=None, fixed_suggestion=None, fixed_insight=None):
|
24 |
+
super().__init__(action_space, args, prompts, distiller, temperature, max_tokens, logger)
|
25 |
+
self.pre_memory = []
|
26 |
+
self.post_memory = []
|
27 |
+
self.is_first = True
|
28 |
+
self.num_trails = args.num_trails
|
29 |
+
self.game_description = args.game_description
|
30 |
+
self.goal_description = args.goal_description
|
31 |
+
self.action_description = args.action_description
|
32 |
+
self.action_desc_dict = args.action_desc_dict
|
33 |
+
self.mem_num = args.trajectories_num
|
34 |
+
self.temperature = temperature
|
35 |
+
self.fixed_suggestion = fixed_suggestion
|
36 |
+
self.fixed_insight = fixed_insight
|
37 |
+
self._update_mem(None)
|
38 |
+
self.insight = ""
|
39 |
+
|
40 |
+
def num_tokens_from_string(self,string: str) -> int:
|
41 |
+
"""Returns the number of tokens in a text string."""
|
42 |
+
num_tokens = len(self.encoding.encode(string))
|
43 |
+
return num_tokens
|
44 |
+
|
45 |
+
def update_mem(self,):
|
46 |
+
traj = self.game_description
|
47 |
+
traj += self.goal_description
|
48 |
+
traj += self.action_description
|
49 |
+
traj += str(self.env_history)
|
50 |
+
self._update_mem(traj)
|
51 |
+
|
52 |
+
def clear_mem(self):
|
53 |
+
self.pre_memory = []
|
54 |
+
self.post_memory = []
|
55 |
+
self.is_first = True
|
56 |
+
self._update_mem(None)
|
57 |
+
|
58 |
+
def _update_mem(self, traj):
|
59 |
+
if self.memory:
|
60 |
+
self.post_memory = self.memory
|
61 |
+
self.insight = self.distiller.generate_insight(self.post_memory)
|
62 |
+
else:
|
63 |
+
if not self.is_first:
|
64 |
+
summary = self.distiller.generate_summary(traj, self.post_memory)
|
65 |
+
self.post_memory.append(summary)
|
66 |
+
self.insight = self.distiller.generate_insight(self.post_memory)
|
67 |
+
else:
|
68 |
+
self.is_first = False
|
69 |
+
self.insight = ""
|
70 |
+
suggestion = self.distiller.generate_suggestion(self.game_description, self.goal_description, self.action_description, self.pre_memory, self.post_memory, self.insight, self.num_trails)
|
71 |
+
if self.fixed_suggestion:
|
72 |
+
suggestion = self.fixed_suggestion
|
73 |
+
if self.fixed_insight:
|
74 |
+
self.insight = self.fixed_insight
|
75 |
+
self.pre_memory.append(suggestion)
|
76 |
+
self.env_history.reset()
|
77 |
+
|
78 |
+
def _read_mem(self, ):
|
79 |
+
insight_str = ""
|
80 |
+
if self.insight:
|
81 |
+
insight_str += "The insights of the game are listed below: "
|
82 |
+
insight_str += f"{self.insight}\n"
|
83 |
+
suggestion_str = "The suggestions are listed below:" + self.pre_memory[-1]
|
84 |
+
return insight_str + suggestion_str
|
85 |
+
def act(
|
86 |
+
self,
|
87 |
+
state_description,
|
88 |
+
action_description,
|
89 |
+
env_info,
|
90 |
+
game_description,
|
91 |
+
goal_description,
|
92 |
+
logfile=None,
|
93 |
+
):
|
94 |
+
self.game_description = game_description
|
95 |
+
self.goal_description = goal_description
|
96 |
+
self.env_history.add("observation", state_description)
|
97 |
+
chat = ChatOpenAI(temperature=0.5, openai_api_key=openai.api_key, model=self.args.gpt_version)
|
98 |
+
# print(self.logger)
|
99 |
+
reply_format_description = \
|
100 |
+
"Your response should choose an optimal action from valid action list, and terminated with following format: "
|
101 |
+
# only task relevant examplesA
|
102 |
+
template = "Now you are completing a task."
|
103 |
+
template += "You need to carefully understand the description of the game. "
|
104 |
+
# TODO: few shot example handle
|
105 |
+
if self.irr_few_shot_examples:
|
106 |
+
template += "Here are some examples of how you should completing a task."
|
107 |
+
for examples in self.irr_few_shot_examples:
|
108 |
+
template += "\nQuestion: \n" + examples['question'] + "Answer: \n" + examples['answer']
|
109 |
+
|
110 |
+
template += "\n\nNow you are in the task."
|
111 |
+
template += " {game_description} {action_description} {goal_description}"
|
112 |
+
template += "You are observing something and " \
|
113 |
+
"you need to choose the optimal action acoordingly."
|
114 |
+
template += 'Response and interact using the format: {reply_format_description}{format_instructions}\n'
|
115 |
+
|
116 |
+
template += self._read_mem()
|
117 |
+
system_message_prompt = SystemMessagePromptTemplate.from_template(template)
|
118 |
+
|
119 |
+
short_memory_template = HumanMessagePromptTemplate.from_template("{history}")
|
120 |
+
chat_prompt = ChatPromptTemplate.from_messages(
|
121 |
+
[system_message_prompt, short_memory_template])
|
122 |
+
if self.logger:
|
123 |
+
pass
|
124 |
+
else:
|
125 |
+
if logfile:
|
126 |
+
# logger.remove()
|
127 |
+
if self.first_call:
|
128 |
+
self.logger = logger.add(logfile, colorize=True, enqueue=True, filter=lambda x: '[Reflexion Memory]' not in x['message'])
|
129 |
+
self.first_call = False
|
130 |
+
handler = FileCallbackHandler(logfile)
|
131 |
+
total_tokens, total_cost = 0, 0
|
132 |
+
max_think_times = 1
|
133 |
+
# TODO: ADD REACT Support
|
134 |
+
# print(str(self.env_history))
|
135 |
+
if self.use_short_mem:
|
136 |
+
my_history = str(self.env_history)
|
137 |
+
else:
|
138 |
+
my_history = ""
|
139 |
+
for i_think in range(max_think_times):
|
140 |
+
# chain = LLMChain(llm=chat, prompt=chat_prompt, callbacks=[handler], verbose=True)
|
141 |
+
chain = LLMChain(llm=chat, prompt=chat_prompt, callbacks=[handler], verbose=False)
|
142 |
+
with get_openai_callback() as cb:
|
143 |
+
response = run_chain(
|
144 |
+
chain,
|
145 |
+
game_description=game_description,
|
146 |
+
goal_description=goal_description,
|
147 |
+
action_description=action_description,
|
148 |
+
# state_description = self.env_history.get_last_history(),
|
149 |
+
history=self.env_history.get_histories_with_last(self.mem_num),
|
150 |
+
format_instructions=self.parser.get_format_instructions(),
|
151 |
+
reply_format_description=reply_format_description,
|
152 |
+
max_token=3000
|
153 |
+
)
|
154 |
+
|
155 |
+
total_tokens += cb.total_tokens
|
156 |
+
total_cost += cb.total_cost
|
157 |
+
action = self.parser.parse(response).action
|
158 |
+
self._add_history_after_action(action)
|
159 |
+
self.logger.info(f'The GPT response is: {response}.')
|
160 |
+
self.logger.info(f'The optimal action is: {action}.')
|
161 |
+
if self.pre_memory:
|
162 |
+
self.logger.info(f'The suggestion is: {self.pre_memory[-1]}.')
|
163 |
+
if self.post_memory:
|
164 |
+
self.logger.info(f'The summary is: {self.post_memory[-1]}.')
|
165 |
+
if env_info.get('history'):
|
166 |
+
self.logger.info(f'History: {history_to_str(env_info["history"])}')
|
167 |
+
text_prompt = chat_prompt.format_messages(
|
168 |
+
game_description=game_description,
|
169 |
+
goal_description=goal_description,
|
170 |
+
action_description=action_description,
|
171 |
+
# state_description = self.env_history.get_last_history(),
|
172 |
+
history=self.env_history.get_histories_with_last(self.mem_num),
|
173 |
+
format_instructions=self.parser.get_format_instructions(),
|
174 |
+
reply_format_description=reply_format_description,
|
175 |
+
)
|
176 |
+
text_prompt = f'{text_prompt[0].content}\n{text_prompt[1].content}'
|
177 |
+
return action, text_prompt, response, total_tokens, total_cost
|
deciders/jarvis_without_insights.py
ADDED
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import openai
|
2 |
+
from .misc import history_to_str
|
3 |
+
from langchain.chat_models import AzureChatOpenAI
|
4 |
+
from langchain.prompts.chat import (
|
5 |
+
PromptTemplate,
|
6 |
+
ChatPromptTemplate,
|
7 |
+
SystemMessagePromptTemplate,
|
8 |
+
HumanMessagePromptTemplate,
|
9 |
+
)
|
10 |
+
from langchain.prompts.few_shot import FewShotPromptTemplate
|
11 |
+
from langchain import LLMChain
|
12 |
+
from loguru import logger
|
13 |
+
from langchain.callbacks import FileCallbackHandler
|
14 |
+
from langchain.callbacks import get_openai_callback
|
15 |
+
from .act import NaiveAct
|
16 |
+
from memory.env_history import EnvironmentHistory
|
17 |
+
import tiktoken
|
18 |
+
from .utils import run_chain
|
19 |
+
|
20 |
+
|
21 |
+
class JarvisWithoutInsight(NaiveAct):
|
22 |
+
def __init__(self, action_space, args, prompts, distiller, temperature=0.1, max_tokens=None):
|
23 |
+
super().__init__(action_space, args, prompts, distiller, temperature, max_tokens)
|
24 |
+
self.pre_memory = []
|
25 |
+
self.post_memory = []
|
26 |
+
self.is_first = True
|
27 |
+
self.num_trails = args.num_trails
|
28 |
+
self.game_description = args.game_description
|
29 |
+
self.goal_description = args.goal_description
|
30 |
+
self.action_description = args.action_description
|
31 |
+
self._update_mem(None)
|
32 |
+
|
33 |
+
def update_mem(self,):
|
34 |
+
traj = self.game_description
|
35 |
+
traj += self.goal_description
|
36 |
+
max_step_num = min(14000 // self.num_tokens_from_string(self.env_history.get_one_history()),200)
|
37 |
+
traj += self.env_history.get_histories(max_step_num)
|
38 |
+
self._update_mem(traj)
|
39 |
+
|
40 |
+
def _update_mem(self, traj):
|
41 |
+
if not self.is_first:
|
42 |
+
summary = self.distiller.generate_summary(traj, self.post_memory)
|
43 |
+
self.post_memory.append(summary)
|
44 |
+
self.insight = self.distiller.generate_insight(self.post_memory)
|
45 |
+
else:
|
46 |
+
self.is_first = False
|
47 |
+
suggestion = self.distiller.generate_suggestion(self.game_description, self.goal_description, self.action_description, self.pre_memory, self.post_memory, self.num_trails)
|
48 |
+
self.pre_memory.append(suggestion)
|
49 |
+
self.env_history.reset()
|
50 |
+
|
51 |
+
def _read_mem(self, ):
|
52 |
+
insight_str = ""
|
53 |
+
suggestion_str = "The suggestions are listed below:" + self.pre_memory[-1]
|
54 |
+
return insight_str + suggestion_str
|
55 |
+
|
56 |
+
def act(
|
57 |
+
self,
|
58 |
+
state_description,
|
59 |
+
action_description,
|
60 |
+
env_info,
|
61 |
+
game_description,
|
62 |
+
goal_description,
|
63 |
+
logfile=None,
|
64 |
+
):
|
65 |
+
self.game_description = game_description
|
66 |
+
self.goal_description = goal_description
|
67 |
+
self.env_history.add("observation", state_description)
|
68 |
+
chat = AzureChatOpenAI(
|
69 |
+
openai_api_type=openai.api_type,
|
70 |
+
openai_api_version=openai.api_version,
|
71 |
+
openai_api_base=openai.api_base,
|
72 |
+
openai_api_key=openai.api_key,
|
73 |
+
deployment_name=self.args.gpt_version,
|
74 |
+
temperature=self.temperature,
|
75 |
+
max_tokens=self.max_tokens,
|
76 |
+
)
|
77 |
+
reply_format_description = \
|
78 |
+
"Your response should choose an optimal action from valid action list, and terminated with following format: "
|
79 |
+
# only task relevant examplesA
|
80 |
+
template = "Now you are completing a task. "
|
81 |
+
template += "You need to carefully understand the description of the game. "
|
82 |
+
# TODO: few shot example handle
|
83 |
+
if self.irr_few_shot_examples:
|
84 |
+
template += "Here are some examples of how you should completing a task."
|
85 |
+
for examples in self.irr_few_shot_examples:
|
86 |
+
template += "\nQuestion: \n" + examples['question'] + "Answer: \n" + examples['answer']
|
87 |
+
|
88 |
+
if self.fewshot_example:
|
89 |
+
if self.expert_knowledge:
|
90 |
+
template += "Here, I will provide you with some expert knowledge to help you better understand the rules of the task."
|
91 |
+
template += self.expert_knowledge + '\n'
|
92 |
+
template += "Next are some examples: "
|
93 |
+
system_message_prompt = SystemMessagePromptTemplate.from_template(template)
|
94 |
+
|
95 |
+
human_template = ""
|
96 |
+
human_template += "\n\nNow you are in the task.\n"
|
97 |
+
human_template += "{game_description}\n{action_description}\n{goal_description}\n"
|
98 |
+
human_template += "You are observing something and " \
|
99 |
+
"you need to choose the optimal action acoordingly. "
|
100 |
+
human_template += 'Response and interact using the format: {reply_format_description}{format_instructions}\n'
|
101 |
+
human_template += self._read_mem()
|
102 |
+
human_template += "\n\nHere are some history states listed below:\n"
|
103 |
+
|
104 |
+
fewshot_example_prompt = PromptTemplate(
|
105 |
+
input_variables=["question", "answer"],
|
106 |
+
template="Question: \n{question}\n{answer}"
|
107 |
+
)
|
108 |
+
human_message_prompt = FewShotPromptTemplate(
|
109 |
+
examples=self.fewshot_example,
|
110 |
+
example_prompt=fewshot_example_prompt,
|
111 |
+
suffix=human_template,
|
112 |
+
input_variables=[
|
113 |
+
'game_description', 'goal_description',
|
114 |
+
'action_description', 'reply_format_description'],
|
115 |
+
partial_variables={'format_instructions': self.parser.get_format_instructions()}
|
116 |
+
)
|
117 |
+
human_message_prompt = HumanMessagePromptTemplate(prompt=human_message_prompt)
|
118 |
+
|
119 |
+
short_memory_template = HumanMessagePromptTemplate.from_template("{history} Please select an action based on the current game state:")
|
120 |
+
|
121 |
+
chat_prompt = ChatPromptTemplate.from_messages(
|
122 |
+
[system_message_prompt, human_message_prompt, short_memory_template])
|
123 |
+
|
124 |
+
|
125 |
+
if logfile:
|
126 |
+
# logger.remove()
|
127 |
+
if self.first_call:
|
128 |
+
logger.add(logfile, colorize=True, enqueue=True, filter=lambda x: '[Reflexion Memory]' not in x['message'])
|
129 |
+
self.first_call = False
|
130 |
+
handler = FileCallbackHandler(logfile)
|
131 |
+
total_tokens, total_cost = 0, 0
|
132 |
+
max_think_times = 1
|
133 |
+
# TODO: ADD REACT Support
|
134 |
+
# print(str(self.env_history))
|
135 |
+
if self.use_short_mem:
|
136 |
+
my_history = str(self.env_history)
|
137 |
+
else:
|
138 |
+
my_history = ""
|
139 |
+
for i_think in range(max_think_times):
|
140 |
+
chain = LLMChain(llm=chat, prompt=chat_prompt, callbacks=[handler], verbose=False)
|
141 |
+
with get_openai_callback() as cb:
|
142 |
+
response = run_chain(
|
143 |
+
chain,
|
144 |
+
game_description=game_description,
|
145 |
+
goal_description=goal_description,
|
146 |
+
action_description=action_description,
|
147 |
+
history=str(self.env_history),
|
148 |
+
format_instructions=self.parser.get_format_instructions(),
|
149 |
+
reply_format_description=reply_format_description,
|
150 |
+
max_token = 3000
|
151 |
+
)
|
152 |
+
|
153 |
+
total_tokens += cb.total_tokens
|
154 |
+
total_cost += cb.total_cost
|
155 |
+
action = self.parser.parse(response).action
|
156 |
+
|
157 |
+
text_prompt = chat_prompt.format_messages(
|
158 |
+
game_description=game_description,
|
159 |
+
goal_description=goal_description,
|
160 |
+
action_description=action_description,
|
161 |
+
history=str(self.env_history),
|
162 |
+
format_instructions=self.parser.get_format_instructions(),
|
163 |
+
reply_format_description=reply_format_description,
|
164 |
+
)
|
165 |
+
texts = ""
|
166 |
+
for text in text_prompt:
|
167 |
+
texts += text.content + "\n"
|
168 |
+
|
169 |
+
self._add_history_after_action(action)
|
170 |
+
logger.info(f'The GPT response is: {response}.')
|
171 |
+
logger.info(f'The optimal action is: {action}.')
|
172 |
+
if self.pre_memory:
|
173 |
+
logger.info(f'The suggestion is: {self.pre_memory[-1]}.')
|
174 |
+
if self.post_memory:
|
175 |
+
logger.info(f'The summary is: {self.post_memory[-1]}.')
|
176 |
+
if env_info.get('history'):
|
177 |
+
logger.info(f'History: {history_to_str(env_info["history"])}')
|
178 |
+
|
179 |
+
return action, texts, response, logger, total_tokens, total_cost
|
deciders/jarvis_without_shortmem.py
ADDED
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import openai
|
2 |
+
from .misc import history_to_str
|
3 |
+
from langchain.chat_models import AzureChatOpenAI
|
4 |
+
from langchain.prompts.chat import (
|
5 |
+
PromptTemplate,
|
6 |
+
ChatPromptTemplate,
|
7 |
+
SystemMessagePromptTemplate,
|
8 |
+
HumanMessagePromptTemplate,
|
9 |
+
)
|
10 |
+
from langchain.prompts.few_shot import FewShotPromptTemplate
|
11 |
+
from langchain import LLMChain
|
12 |
+
from loguru import logger
|
13 |
+
from langchain.callbacks import FileCallbackHandler
|
14 |
+
from langchain.callbacks import get_openai_callback
|
15 |
+
from .act import NaiveAct
|
16 |
+
from memory.env_history import EnvironmentHistory
|
17 |
+
import tiktoken
|
18 |
+
from .utils import run_chain
|
19 |
+
|
20 |
+
|
21 |
+
class JarvisWithoutShortMem(NaiveAct):
|
22 |
+
def __init__(self, action_space, args, prompts, distiller, temperature=0.1, max_tokens=None):
|
23 |
+
super().__init__(action_space, args, prompts, distiller, temperature, max_tokens)
|
24 |
+
self.pre_memory = []
|
25 |
+
self.post_memory = []
|
26 |
+
self.is_first = True
|
27 |
+
self.num_trails = args.num_trails
|
28 |
+
self.game_description = args.game_description
|
29 |
+
self.goal_description = args.goal_description
|
30 |
+
self.action_description = args.action_description
|
31 |
+
self._update_mem(None)
|
32 |
+
|
33 |
+
def update_mem(self,):
|
34 |
+
traj = self.game_description
|
35 |
+
traj += self.goal_description
|
36 |
+
max_step_num = min(14000 // self.num_tokens_from_string(self.env_history.get_one_history()),200)
|
37 |
+
traj += self.env_history.get_histories(max_step_num)
|
38 |
+
self._update_mem(traj)
|
39 |
+
|
40 |
+
def _update_mem(self, traj):
|
41 |
+
if not self.is_first:
|
42 |
+
summary = self.distiller.generate_summary(traj, self.post_memory)
|
43 |
+
self.post_memory.append(summary)
|
44 |
+
self.insight = self.distiller.generate_insight(self.post_memory)
|
45 |
+
else:
|
46 |
+
self.is_first = False
|
47 |
+
suggestion = self.distiller.generate_suggestion(self.game_description, self.goal_description, self.action_description, self.pre_memory, self.post_memory, self.num_trails)
|
48 |
+
self.pre_memory.append(suggestion)
|
49 |
+
self.env_history.reset()
|
50 |
+
|
51 |
+
def _read_mem(self, ):
|
52 |
+
insight_str = ""
|
53 |
+
if len(self.post_memory) > 0:
|
54 |
+
insight_str += "The insights of the game are listed below: "
|
55 |
+
insight_str += f"{self.insight}\n"
|
56 |
+
suggestion_str = "The suggestions are listed below:" + self.pre_memory[-1]
|
57 |
+
return insight_str + suggestion_str
|
58 |
+
|
59 |
+
def act(
|
60 |
+
self,
|
61 |
+
state_description,
|
62 |
+
action_description,
|
63 |
+
env_info,
|
64 |
+
game_description,
|
65 |
+
goal_description,
|
66 |
+
logfile=None,
|
67 |
+
):
|
68 |
+
self.game_description = game_description
|
69 |
+
self.goal_description = goal_description
|
70 |
+
self.env_history.add("observation", state_description)
|
71 |
+
chat = AzureChatOpenAI(
|
72 |
+
openai_api_type=openai.api_type,
|
73 |
+
openai_api_version=openai.api_version,
|
74 |
+
openai_api_base=openai.api_base,
|
75 |
+
openai_api_key=openai.api_key,
|
76 |
+
deployment_name=self.args.gpt_version,
|
77 |
+
temperature=self.temperature,
|
78 |
+
max_tokens=self.max_tokens,
|
79 |
+
)
|
80 |
+
reply_format_description = \
|
81 |
+
"Your response should choose an optimal action from valid action list, and terminated with following format: "
|
82 |
+
# only task relevant examplesA
|
83 |
+
template = "Now you are completing a task. "
|
84 |
+
template += "You need to carefully understand the description of the game. "
|
85 |
+
# TODO: few shot example handle
|
86 |
+
if self.irr_few_shot_examples:
|
87 |
+
template += "Here are some examples of how you should completing a task."
|
88 |
+
for examples in self.irr_few_shot_examples:
|
89 |
+
template += "\nQuestion: \n" + examples['question'] + "Answer: \n" + examples['answer']
|
90 |
+
|
91 |
+
if self.fewshot_example:
|
92 |
+
if self.expert_knowledge:
|
93 |
+
template += "Here, I will provide you with some expert knowledge to help you better understand the rules of the task."
|
94 |
+
template += self.expert_knowledge + '\n'
|
95 |
+
template += "Next are some examples: "
|
96 |
+
system_message_prompt = SystemMessagePromptTemplate.from_template(template)
|
97 |
+
|
98 |
+
human_template = ""
|
99 |
+
human_template += "\n\nNow you are in the task.\n"
|
100 |
+
human_template += "{game_description}\n{action_description}\n{goal_description}\n"
|
101 |
+
human_template += "You are observing something and " \
|
102 |
+
"you need to choose the optimal action acoordingly. "
|
103 |
+
human_template += 'Response and interact using the format: {reply_format_description}{format_instructions}\n'
|
104 |
+
human_template += self._read_mem()
|
105 |
+
human_template += "\n\nHere are some history states listed below:\n"
|
106 |
+
|
107 |
+
fewshot_example_prompt = PromptTemplate(
|
108 |
+
input_variables=["question", "answer"],
|
109 |
+
template="Question: \n{question}\n{answer}"
|
110 |
+
)
|
111 |
+
human_message_prompt = FewShotPromptTemplate(
|
112 |
+
examples=self.fewshot_example,
|
113 |
+
example_prompt=fewshot_example_prompt,
|
114 |
+
suffix=human_template,
|
115 |
+
input_variables=[
|
116 |
+
'game_description', 'goal_description',
|
117 |
+
'action_description', 'reply_format_description'],
|
118 |
+
partial_variables={'format_instructions': self.parser.get_format_instructions()}
|
119 |
+
)
|
120 |
+
human_message_prompt = HumanMessagePromptTemplate(prompt=human_message_prompt)
|
121 |
+
|
122 |
+
short_memory_template = HumanMessagePromptTemplate.from_template("{history} Please select an action based on the current game state:")
|
123 |
+
|
124 |
+
chat_prompt = ChatPromptTemplate.from_messages(
|
125 |
+
[system_message_prompt, human_message_prompt, short_memory_template])
|
126 |
+
|
127 |
+
|
128 |
+
if logfile:
|
129 |
+
# logger.remove()
|
130 |
+
if self.first_call:
|
131 |
+
logger.add(logfile, colorize=True, enqueue=True, filter=lambda x: '[Reflexion Memory]' not in x['message'])
|
132 |
+
self.first_call = False
|
133 |
+
handler = FileCallbackHandler(logfile)
|
134 |
+
total_tokens, total_cost = 0, 0
|
135 |
+
max_think_times = 1
|
136 |
+
# TODO: ADD REACT Support
|
137 |
+
# print(str(self.env_history))
|
138 |
+
if self.use_short_mem:
|
139 |
+
my_history = str(self.env_history)
|
140 |
+
else:
|
141 |
+
my_history = ""
|
142 |
+
for i_think in range(max_think_times):
|
143 |
+
chain = LLMChain(llm=chat, prompt=chat_prompt, callbacks=[handler], verbose=False)
|
144 |
+
with get_openai_callback() as cb:
|
145 |
+
response = run_chain(
|
146 |
+
chain,
|
147 |
+
game_description=game_description,
|
148 |
+
goal_description=goal_description,
|
149 |
+
action_description=action_description,
|
150 |
+
history=self.env_history.get_last_history(),
|
151 |
+
format_instructions=self.parser.get_format_instructions(),
|
152 |
+
reply_format_description=reply_format_description,
|
153 |
+
max_token = 3000
|
154 |
+
)
|
155 |
+
|
156 |
+
total_tokens += cb.total_tokens
|
157 |
+
total_cost += cb.total_cost
|
158 |
+
action = self.parser.parse(response).action
|
159 |
+
|
160 |
+
text_prompt = chat_prompt.format_messages(
|
161 |
+
game_description=game_description,
|
162 |
+
goal_description=goal_description,
|
163 |
+
action_description=action_description,
|
164 |
+
history=self.env_history.get_last_history(),
|
165 |
+
format_instructions=self.parser.get_format_instructions(),
|
166 |
+
reply_format_description=reply_format_description,
|
167 |
+
)
|
168 |
+
texts = ""
|
169 |
+
for text in text_prompt:
|
170 |
+
texts += text.content + "\n"
|
171 |
+
|
172 |
+
self._add_history_after_action(action)
|
173 |
+
logger.info(f'The GPT response is: {response}.')
|
174 |
+
logger.info(f'The optimal action is: {action}.')
|
175 |
+
if self.pre_memory:
|
176 |
+
logger.info(f'The suggestion is: {self.pre_memory[-1]}.')
|
177 |
+
if self.post_memory:
|
178 |
+
logger.info(f'The summary is: {self.post_memory[-1]}.')
|
179 |
+
if env_info.get('history'):
|
180 |
+
logger.info(f'History: {history_to_str(env_info["history"])}')
|
181 |
+
|
182 |
+
return action, texts, response, logger, total_tokens, total_cost
|
deciders/jarvis_without_suggestions.py
ADDED
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import openai
|
2 |
+
from .misc import history_to_str
|
3 |
+
from langchain.chat_models import AzureChatOpenAI
|
4 |
+
from langchain.prompts.chat import (
|
5 |
+
PromptTemplate,
|
6 |
+
ChatPromptTemplate,
|
7 |
+
SystemMessagePromptTemplate,
|
8 |
+
HumanMessagePromptTemplate,
|
9 |
+
)
|
10 |
+
from langchain.prompts.few_shot import FewShotPromptTemplate
|
11 |
+
from langchain import LLMChain
|
12 |
+
from loguru import logger
|
13 |
+
from langchain.callbacks import FileCallbackHandler
|
14 |
+
from langchain.callbacks import get_openai_callback
|
15 |
+
from .act import NaiveAct
|
16 |
+
from memory.env_history import EnvironmentHistory
|
17 |
+
import tiktoken
|
18 |
+
from .utils import run_chain
|
19 |
+
|
20 |
+
|
21 |
+
class JarvisWithoutSuggestions(NaiveAct):
|
22 |
+
def __init__(self, action_space, args, prompts, distiller, temperature=0.1, max_tokens=None):
|
23 |
+
super().__init__(action_space, args, prompts, distiller, temperature, max_tokens)
|
24 |
+
self.pre_memory = []
|
25 |
+
self.post_memory = []
|
26 |
+
self.is_first = True
|
27 |
+
self.num_trails = args.num_trails
|
28 |
+
self.game_description = args.game_description
|
29 |
+
self.goal_description = args.goal_description
|
30 |
+
self.action_description = args.action_description
|
31 |
+
self._update_mem(None)
|
32 |
+
|
33 |
+
def update_mem(self,):
|
34 |
+
traj = self.game_description
|
35 |
+
traj += self.goal_description
|
36 |
+
max_step_num = min(14000 // self.num_tokens_from_string(self.env_history.get_one_history()),200)
|
37 |
+
traj += self.env_history.get_histories(max_step_num)
|
38 |
+
self._update_mem(traj)
|
39 |
+
|
40 |
+
def _update_mem(self, traj):
|
41 |
+
if not self.is_first:
|
42 |
+
summary = self.distiller.generate_summary(traj, self.post_memory)
|
43 |
+
self.post_memory.append(summary)
|
44 |
+
self.insight = self.distiller.generate_insight(self.post_memory)
|
45 |
+
else:
|
46 |
+
self.is_first = False
|
47 |
+
suggestion = self.distiller.generate_suggestion(self.game_description, self.goal_description, self.action_description, self.pre_memory, self.post_memory, self.num_trails)
|
48 |
+
self.pre_memory.append(suggestion)
|
49 |
+
self.env_history.reset()
|
50 |
+
|
51 |
+
def _read_mem(self, ):
|
52 |
+
insight_str = ""
|
53 |
+
if len(self.post_memory) > 0:
|
54 |
+
insight_str += "The insights of the game are listed below: "
|
55 |
+
insight_str += f"{self.insight}\n"
|
56 |
+
suggestion_str = "\n"
|
57 |
+
return insight_str + suggestion_str
|
58 |
+
|
59 |
+
def act(
|
60 |
+
self,
|
61 |
+
state_description,
|
62 |
+
action_description,
|
63 |
+
env_info,
|
64 |
+
game_description,
|
65 |
+
goal_description,
|
66 |
+
logfile=None,
|
67 |
+
):
|
68 |
+
self.game_description = game_description
|
69 |
+
self.goal_description = goal_description
|
70 |
+
self.env_history.add("observation", state_description)
|
71 |
+
chat = AzureChatOpenAI(
|
72 |
+
openai_api_type=openai.api_type,
|
73 |
+
openai_api_version=openai.api_version,
|
74 |
+
openai_api_base=openai.api_base,
|
75 |
+
openai_api_key=openai.api_key,
|
76 |
+
deployment_name=self.args.gpt_version,
|
77 |
+
temperature=self.temperature,
|
78 |
+
max_tokens=self.max_tokens,
|
79 |
+
)
|
80 |
+
reply_format_description = \
|
81 |
+
"Your response should choose an optimal action from valid action list, and terminated with following format: "
|
82 |
+
# only task relevant examplesA
|
83 |
+
template = "Now you are completing a task. "
|
84 |
+
template += "You need to carefully understand the description of the game. "
|
85 |
+
# TODO: few shot example handle
|
86 |
+
if self.irr_few_shot_examples:
|
87 |
+
template += "Here are some examples of how you should completing a task."
|
88 |
+
for examples in self.irr_few_shot_examples:
|
89 |
+
template += "\nQuestion: \n" + examples['question'] + "Answer: \n" + examples['answer']
|
90 |
+
|
91 |
+
if self.fewshot_example:
|
92 |
+
if self.expert_knowledge:
|
93 |
+
template += "Here, I will provide you with some expert knowledge to help you better understand the rules of the task."
|
94 |
+
template += self.expert_knowledge + '\n'
|
95 |
+
template += "Next are some examples: "
|
96 |
+
system_message_prompt = SystemMessagePromptTemplate.from_template(template)
|
97 |
+
|
98 |
+
human_template = ""
|
99 |
+
human_template += "\n\nNow you are in the task.\n"
|
100 |
+
human_template += "{game_description}\n{action_description}\n{goal_description}\n"
|
101 |
+
human_template += "You are observing something and " \
|
102 |
+
"you need to choose the optimal action acoordingly. "
|
103 |
+
human_template += 'Response and interact using the format: {reply_format_description}{format_instructions}\n'
|
104 |
+
human_template += self._read_mem()
|
105 |
+
human_template += "\n\nHere are some history states listed below:\n"
|
106 |
+
|
107 |
+
fewshot_example_prompt = PromptTemplate(
|
108 |
+
input_variables=["question", "answer"],
|
109 |
+
template="Question: \n{question}\n{answer}"
|
110 |
+
)
|
111 |
+
human_message_prompt = FewShotPromptTemplate(
|
112 |
+
examples=self.fewshot_example,
|
113 |
+
example_prompt=fewshot_example_prompt,
|
114 |
+
suffix=human_template,
|
115 |
+
input_variables=[
|
116 |
+
'game_description', 'goal_description',
|
117 |
+
'action_description', 'reply_format_description'],
|
118 |
+
partial_variables={'format_instructions': self.parser.get_format_instructions()}
|
119 |
+
)
|
120 |
+
human_message_prompt = HumanMessagePromptTemplate(prompt=human_message_prompt)
|
121 |
+
|
122 |
+
short_memory_template = HumanMessagePromptTemplate.from_template("{history} Please select an action based on the current game state:")
|
123 |
+
|
124 |
+
chat_prompt = ChatPromptTemplate.from_messages(
|
125 |
+
[system_message_prompt, human_message_prompt, short_memory_template])
|
126 |
+
|
127 |
+
|
128 |
+
if logfile:
|
129 |
+
# logger.remove()
|
130 |
+
if self.first_call:
|
131 |
+
logger.add(logfile, colorize=True, enqueue=True, filter=lambda x: '[Reflexion Memory]' not in x['message'])
|
132 |
+
self.first_call = False
|
133 |
+
handler = FileCallbackHandler(logfile)
|
134 |
+
total_tokens, total_cost = 0, 0
|
135 |
+
max_think_times = 1
|
136 |
+
# TODO: ADD REACT Support
|
137 |
+
# print(str(self.env_history))
|
138 |
+
if self.use_short_mem:
|
139 |
+
my_history = str(self.env_history)
|
140 |
+
else:
|
141 |
+
my_history = ""
|
142 |
+
for i_think in range(max_think_times):
|
143 |
+
chain = LLMChain(llm=chat, prompt=chat_prompt, callbacks=[handler], verbose=False)
|
144 |
+
with get_openai_callback() as cb:
|
145 |
+
response = run_chain(
|
146 |
+
chain,
|
147 |
+
game_description=game_description,
|
148 |
+
goal_description=goal_description,
|
149 |
+
action_description=action_description,
|
150 |
+
history=str(self.env_history),
|
151 |
+
format_instructions=self.parser.get_format_instructions(),
|
152 |
+
reply_format_description=reply_format_description,
|
153 |
+
max_token = 3000
|
154 |
+
)
|
155 |
+
|
156 |
+
total_tokens += cb.total_tokens
|
157 |
+
total_cost += cb.total_cost
|
158 |
+
action = self.parser.parse(response).action
|
159 |
+
|
160 |
+
text_prompt = chat_prompt.format_messages(
|
161 |
+
game_description=game_description,
|
162 |
+
goal_description=goal_description,
|
163 |
+
action_description=action_description,
|
164 |
+
history=str(self.env_history),
|
165 |
+
format_instructions=self.parser.get_format_instructions(),
|
166 |
+
reply_format_description=reply_format_description,
|
167 |
+
)
|
168 |
+
texts = ""
|
169 |
+
for text in text_prompt:
|
170 |
+
texts += text.content + "\n"
|
171 |
+
|
172 |
+
self._add_history_after_action(action)
|
173 |
+
logger.info(f'The GPT response is: {response}.')
|
174 |
+
logger.info(f'The optimal action is: {action}.')
|
175 |
+
if self.post_memory:
|
176 |
+
logger.info(f'The summary is: {self.post_memory[-1]}.')
|
177 |
+
if env_info.get('history'):
|
178 |
+
logger.info(f'History: {history_to_str(env_info["history"])}')
|
179 |
+
|
180 |
+
return action, texts, response, logger, total_tokens, total_cost
|
deciders/jarvis_without_summary.py
ADDED
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import openai
|
2 |
+
from .misc import history_to_str
|
3 |
+
from langchain.chat_models import AzureChatOpenAI
|
4 |
+
from langchain.prompts.chat import (
|
5 |
+
PromptTemplate,
|
6 |
+
ChatPromptTemplate,
|
7 |
+
SystemMessagePromptTemplate,
|
8 |
+
HumanMessagePromptTemplate,
|
9 |
+
)
|
10 |
+
from langchain.prompts.few_shot import FewShotPromptTemplate
|
11 |
+
from langchain import LLMChain
|
12 |
+
from loguru import logger
|
13 |
+
from langchain.callbacks import FileCallbackHandler
|
14 |
+
from langchain.callbacks import get_openai_callback
|
15 |
+
from .act import NaiveAct
|
16 |
+
from memory.env_history import EnvironmentHistory
|
17 |
+
import tiktoken
|
18 |
+
|
19 |
+
|
20 |
+
class Jarvis(NaiveAct):
|
21 |
+
def __init__(self, action_space, args, prompts, distiller, temperature=0.1, max_tokens=None):
|
22 |
+
super().__init__(action_space, args, prompts, distiller, temperature, max_tokens)
|
23 |
+
self.pre_memory = []
|
24 |
+
self.post_memory = []
|
25 |
+
self.is_first = True
|
26 |
+
self.num_trails = args.num_trails
|
27 |
+
self.game_description = args.game_description
|
28 |
+
self.goal_description = args.goal_description
|
29 |
+
self.action_description = args.action_description
|
30 |
+
self._update_mem(None)
|
31 |
+
|
32 |
+
def update_mem(self,):
|
33 |
+
traj = self.game_description
|
34 |
+
traj += self.goal_description
|
35 |
+
max_step_num = min(14000 // self.num_tokens_from_string(self.env_history.get_one_history()),200)
|
36 |
+
traj += self.env_history.get_histories(max_step_num)
|
37 |
+
self._update_mem(traj)
|
38 |
+
|
39 |
+
def _update_mem(self, traj):
|
40 |
+
if not self.is_first:
|
41 |
+
summary = self.distiller.generate_summary(traj, self.post_memory)
|
42 |
+
self.post_memory.append(summary)
|
43 |
+
self.insight = self.distiller.generate_insight(self.post_memory)
|
44 |
+
else:
|
45 |
+
self.is_first = False
|
46 |
+
suggestion = self.distiller.generate_suggestion(self.game_description, self.goal_description, self.action_description, self.pre_memory, self.post_memory, self.num_trails)
|
47 |
+
self.pre_memory.append(suggestion)
|
48 |
+
self.env_history.reset()
|
49 |
+
|
50 |
+
def _read_mem(self, ):
|
51 |
+
insight_str = ""
|
52 |
+
if len(self.post_memory) > 0:
|
53 |
+
insight_str += "The insights of the game are listed below: "
|
54 |
+
insight_str += f"{self.insight}\n"
|
55 |
+
suggestion_str = "The suggestions are listed below:" + self.pre_memory[-1]
|
56 |
+
return insight_str + suggestion_str
|
57 |
+
|
58 |
+
def act(
|
59 |
+
self,
|
60 |
+
state_description,
|
61 |
+
action_description,
|
62 |
+
env_info,
|
63 |
+
game_description,
|
64 |
+
goal_description,
|
65 |
+
logfile=None,
|
66 |
+
):
|
67 |
+
self.game_description = game_description
|
68 |
+
self.goal_description = goal_description
|
69 |
+
self.env_history.add("observation", state_description)
|
70 |
+
chat = AzureChatOpenAI(
|
71 |
+
openai_api_type=openai.api_type,
|
72 |
+
openai_api_version=openai.api_version,
|
73 |
+
openai_api_base=openai.api_base,
|
74 |
+
openai_api_key=openai.api_key,
|
75 |
+
deployment_name=self.args.gpt_version,
|
76 |
+
temperature=self.temperature,
|
77 |
+
max_tokens=self.max_tokens,
|
78 |
+
)
|
79 |
+
reply_format_description = \
|
80 |
+
"Your response should choose an optimal action from valid action list, and terminated with following format: "
|
81 |
+
# only task relevant examplesA
|
82 |
+
template = "Now you are completing a task. "
|
83 |
+
template += "You need to carefully understand the description of the game. "
|
84 |
+
# TODO: few shot example handle
|
85 |
+
if self.irr_few_shot_examples:
|
86 |
+
template += "Here are some examples of how you should completing a task."
|
87 |
+
for examples in self.irr_few_shot_examples:
|
88 |
+
template += "\nQuestion: \n" + examples['question'] + "Answer: \n" + examples['answer']
|
89 |
+
|
90 |
+
if self.fewshot_example:
|
91 |
+
if self.expert_knowledge:
|
92 |
+
template += "Here, I will provide you with some expert knowledge to help you better understand the rules of the task."
|
93 |
+
template += self.expert_knowledge + '\n'
|
94 |
+
template += "Next are some examples: "
|
95 |
+
system_message_prompt = SystemMessagePromptTemplate.from_template(template)
|
96 |
+
|
97 |
+
human_template = ""
|
98 |
+
human_template += "\n"
|
99 |
+
human_template += "{game_description}\n{action_description}\n{goal_description}\n"
|
100 |
+
human_template += "You are observing something and " \
|
101 |
+
"you need to choose the optimal action acoordingly. "
|
102 |
+
human_template += 'Response and interact using the format: {reply_format_description}{format_instructions}\n'
|
103 |
+
human_template += self._read_mem()
|
104 |
+
human_template += "\n\nHere are some history states listed below:\n"
|
105 |
+
|
106 |
+
fewshot_example_prompt = PromptTemplate(
|
107 |
+
input_variables=["question", "answer"],
|
108 |
+
template="Question: \n{question}\n{answer}"
|
109 |
+
)
|
110 |
+
human_message_prompt = FewShotPromptTemplate(
|
111 |
+
examples=self.fewshot_example,
|
112 |
+
example_prompt=fewshot_example_prompt,
|
113 |
+
suffix=human_template,
|
114 |
+
input_variables=[
|
115 |
+
'game_description', 'goal_description',
|
116 |
+
'action_description', 'reply_format_description'],
|
117 |
+
partial_variables={'format_instructions': self.parser.get_format_instructions()}
|
118 |
+
)
|
119 |
+
human_message_prompt = HumanMessagePromptTemplate(prompt=human_message_prompt)
|
120 |
+
|
121 |
+
short_memory_template = HumanMessagePromptTemplate.from_template("{history} Please select an action based on the current game state. You must select the appropriate action from the given action descriptions and cannot refrain from taking action or perform any prohibited actions. Here's the action description below: \n {action_description}\n")
|
122 |
+
|
123 |
+
chat_prompt = ChatPromptTemplate.from_messages(
|
124 |
+
[system_message_prompt, human_message_prompt, short_memory_template])
|
125 |
+
|
126 |
+
if logfile:
|
127 |
+
# logger.remove()
|
128 |
+
if self.first_call:
|
129 |
+
logger.add(logfile, colorize=True, enqueue=True, filter=lambda x: '[Reflexion Memory]' not in x['message'])
|
130 |
+
self.first_call = False
|
131 |
+
handler = FileCallbackHandler(logfile)
|
132 |
+
total_tokens, total_cost = 0, 0
|
133 |
+
max_think_times = 1
|
134 |
+
# TODO: ADD REACT Support
|
135 |
+
# print(str(self.env_history))
|
136 |
+
if self.use_short_mem:
|
137 |
+
my_history = str(self.env_history)
|
138 |
+
else:
|
139 |
+
my_history = ""
|
140 |
+
for i_think in range(max_think_times):
|
141 |
+
chain = LLMChain(llm=chat, prompt=chat_prompt, callbacks=[handler], verbose=False)
|
142 |
+
with get_openai_callback() as cb:
|
143 |
+
response = chain.run(
|
144 |
+
game_description=game_description,
|
145 |
+
goal_description=goal_description,
|
146 |
+
action_description=action_description,
|
147 |
+
history=self.env_history.get_histories(11),
|
148 |
+
format_instructions=self.parser.get_format_instructions(),
|
149 |
+
reply_format_description=reply_format_description,
|
150 |
+
max_token = 3000
|
151 |
+
)
|
152 |
+
|
153 |
+
total_tokens += cb.total_tokens
|
154 |
+
total_cost += cb.total_cost
|
155 |
+
action = self.parser.parse(response).action
|
156 |
+
|
157 |
+
text_prompt = chat_prompt.format_messages(
|
158 |
+
game_description=game_description,
|
159 |
+
goal_description=goal_description,
|
160 |
+
action_description=action_description,
|
161 |
+
history=self.env_history.get_histories(11),
|
162 |
+
format_instructions=self.parser.get_format_instructions(),
|
163 |
+
reply_format_description=reply_format_description,
|
164 |
+
)
|
165 |
+
texts = ""
|
166 |
+
for text in text_prompt:
|
167 |
+
texts += text.content + "\n"
|
168 |
+
|
169 |
+
self._add_history_after_action(action)
|
170 |
+
logger.info(f'The GPT response is: {response}.')
|
171 |
+
logger.info(f'The optimal action is: {action}.')
|
172 |
+
if self.pre_memory:
|
173 |
+
logger.info(f'The suggestion is: {self.pre_memory[-1]}.')
|
174 |
+
if self.post_memory:
|
175 |
+
logger.info(f'The summary is: {self.post_memory[-1]}.')
|
176 |
+
if env_info.get('history'):
|
177 |
+
logger.info(f'History: {history_to_str(env_info["history"])}')
|
178 |
+
|
179 |
+
return action, texts, response, logger, total_tokens, total_cost
|
deciders/misc.py
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
def history_to_str(history):
|
2 |
+
history_str = ""
|
3 |
+
for d in history:
|
4 |
+
history_str += f"state: {d['state']}, action: {d['action']}, reward: {d['reward']}\n"
|
5 |
+
return history_str
|
6 |
+
|
7 |
+
def get_majority_vote(actions):
|
8 |
+
return max(set(actions), key=actions.count)
|
9 |
+
|
10 |
+
def test_get_majority_vote():
|
11 |
+
assert get_majority_vote([1, 1, 1, 2, 2]) == 1
|
12 |
+
assert get_majority_vote([1, 1, 2, 2, 2]) == 2
|
13 |
+
assert get_majority_vote([1, 1, 2, 2, 3]) == 1
|
14 |
+
assert get_majority_vote([1, 2, 3, 4, 5]) == 1
|
15 |
+
assert get_majority_vote([1, 2, 3, 4, 5, 1, 1, 1, 1, 1]) == 1
|
16 |
+
assert get_majority_vote([1, 2, 3, 4, 5, 1, 1, 1, 1, 2]) == 1
|
17 |
+
assert get_majority_vote([1, 2, 3, 4, 5, 1, 1, 1, 2, 2]) == 1
|
18 |
+
assert get_majority_vote([1, 2, 3, 4, 5, 1, 1, 2, 2, 2]) == 2
|
19 |
+
|
20 |
+
if __name__ == "__main__":
|
21 |
+
test_get_majority_vote()
|
deciders/pal.py
ADDED
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import openai
|
2 |
+
from .misc import history_to_str
|
3 |
+
from langchain.chat_models import AzureChatOpenAI
|
4 |
+
from langchain.prompts.chat import (
|
5 |
+
PromptTemplate,
|
6 |
+
ChatPromptTemplate,
|
7 |
+
SystemMessagePromptTemplate,
|
8 |
+
HumanMessagePromptTemplate,
|
9 |
+
)
|
10 |
+
from langchain.prompts.few_shot import FewShotPromptTemplate
|
11 |
+
from langchain import LLMChain
|
12 |
+
from loguru import logger
|
13 |
+
from langchain.callbacks import FileCallbackHandler
|
14 |
+
from langchain.callbacks import get_openai_callback
|
15 |
+
from .act import NaiveAct
|
16 |
+
from .utils import run_chain
|
17 |
+
|
18 |
+
def get_last_n_lines(text, n):
|
19 |
+
lines = text.splitlines()
|
20 |
+
return '\n'.join(lines[-n:])
|
21 |
+
|
22 |
+
class PAL(NaiveAct):
|
23 |
+
def __init__(self, action_space, args, prompts, distiller, temperature=0.1, max_tokens=None, logger=None):
|
24 |
+
super().__init__(action_space, args, prompts, distiller, temperature, max_tokens, logger)
|
25 |
+
|
26 |
+
def act(
|
27 |
+
self,
|
28 |
+
state_description,
|
29 |
+
action_description,
|
30 |
+
env_info,
|
31 |
+
game_description,
|
32 |
+
goal_description,
|
33 |
+
logfile=None,
|
34 |
+
):
|
35 |
+
self._add_history_before_action(game_description, goal_description, state_description)
|
36 |
+
chat = AzureChatOpenAI(
|
37 |
+
openai_api_type=openai.api_type,
|
38 |
+
openai_api_version=openai.api_version,
|
39 |
+
openai_api_base=openai.api_base,
|
40 |
+
openai_api_key=openai.api_key,
|
41 |
+
deployment_name=self.args.gpt_version,
|
42 |
+
temperature=self.temperature,
|
43 |
+
max_tokens=self.max_tokens
|
44 |
+
)
|
45 |
+
|
46 |
+
suffix_flag = False
|
47 |
+
reply_format_description = \
|
48 |
+
"Your response should choose an optimal action from a valid action list and terminate with the following format: "
|
49 |
+
|
50 |
+
# System Message
|
51 |
+
human_template = "Now, you are completing a challenging task. You must carefully understand the Program-aided Language method you will use and apply it to the following task.\n"
|
52 |
+
|
53 |
+
# task-irrelevant SystemMessage
|
54 |
+
if self.irr_few_shot_examples:
|
55 |
+
human_template += 'In the following example, I shall present a set of question and answer with the Program-aided Language method. Please adhere to the format and reasoning of the provided response when addressing the subsequent task.\n'
|
56 |
+
for i, examples in enumerate(self.irr_few_shot_examples):
|
57 |
+
human_template += f"\nExample {i+1}:\n"
|
58 |
+
human_template += "Question: \n" + examples['question'] + "\nAnswer: \n" + examples['answer']
|
59 |
+
|
60 |
+
# task-irrelevant few shot if have
|
61 |
+
if self.irr_few_shot_examples:
|
62 |
+
human_template += "\nMoving forward, I will describe the task, the goal, and the actions you may execute. Please pay close attention to comprehend the information presented below.\n"
|
63 |
+
|
64 |
+
if self.fewshot_example:
|
65 |
+
human_template += "I will describe the task, the goal, and the actions you may execute. Please pay close attention to comprehend the information presented below."
|
66 |
+
# print(fewshot_example_prompt.format(**fewshot_examples[0]))
|
67 |
+
human_template += '\nTask Description: {game_description} \n'
|
68 |
+
human_template += 'Goal Description: {goal_description}\n'
|
69 |
+
human_template += 'Actions Description: {action_description}\n'
|
70 |
+
|
71 |
+
if self.fewshot_example:
|
72 |
+
human_template += "Here, I will provide you with some guidance to help you better understand the rules of the task. Next are some examples: "
|
73 |
+
for i, examples in enumerate(self.fewshot_example):
|
74 |
+
human_template += f"\nExample {i+1}:\n"
|
75 |
+
human_template += "Question: \n" + examples['question'] + "\nAnswer: \n" + examples['answer']
|
76 |
+
|
77 |
+
if self.prompt_level in [2, 3, 4]:
|
78 |
+
if self.memory:
|
79 |
+
human_template += '\nSubsequently, I will offer pertinent guidance or information about the task. Please utilize this instruction to accomplish the given task effectively.\n'
|
80 |
+
suffix_flag = True
|
81 |
+
if self.prompt_level == 2:
|
82 |
+
human_template += 'I have collected a few trajectories from a random policy, and the summaries are listed below.'
|
83 |
+
elif self.prompt_level == 3:
|
84 |
+
human_template += 'I have collected a few trajectories before, and the summaries are listed below.'
|
85 |
+
elif self.prompt_level == 4:
|
86 |
+
human_template += 'I have collected a few trajectories from an expert policy, and the summaries are listed below.'
|
87 |
+
human_template += self._read_mem() + "\n"
|
88 |
+
|
89 |
+
if self.use_short_mem:
|
90 |
+
if len(self.env_history) > 1:
|
91 |
+
if not suffix_flag:
|
92 |
+
human_template += '\nSubsequently, I will offer pertinent guidance or information about the task. Please utilize this instruction to accomplish the given task effectively.'
|
93 |
+
human_template += f"\nBelow are the latest {min(self.args.short_mem_num,len(self.env_history)//2)} historical data entries:\n"
|
94 |
+
human_template += f"{self.env_history.get_histories(self.mem_num)}"
|
95 |
+
human_template += '\nNext is the observation that the agent gets:\nCurrent {state_description}\n'
|
96 |
+
human_template += 'Please select an action based on the current game state and the information you get. You must select the appropriate action from the given action descriptions and cannot refrain from taking action or performing any prohibited actions. Here is the action description below:\n{action_description}\n'
|
97 |
+
human_template += 'Please generate Python program as answers to given questions, similar to the provided examples.\n'
|
98 |
+
human_template += 'And You should calculate the final result based on the program ,not just give a code script alone!\n'
|
99 |
+
|
100 |
+
human_message_prompt = PromptTemplate(
|
101 |
+
template=human_template,
|
102 |
+
input_variables=[
|
103 |
+
'state_description', 'goal_description', 'game_description',
|
104 |
+
'action_description'],
|
105 |
+
)
|
106 |
+
|
107 |
+
human_message_prompt = HumanMessagePromptTemplate(prompt=human_message_prompt)
|
108 |
+
|
109 |
+
chat_prompt = ChatPromptTemplate.from_messages([human_message_prompt])
|
110 |
+
|
111 |
+
if not self.logger:
|
112 |
+
logger.remove()
|
113 |
+
self.logger = logger.add(logfile, colorize=True, enqueue=True)
|
114 |
+
handler = FileCallbackHandler(logfile)
|
115 |
+
|
116 |
+
chain = LLMChain(llm=chat, prompt=chat_prompt, callbacks=[handler], verbose=False)
|
117 |
+
|
118 |
+
with get_openai_callback() as cb:
|
119 |
+
response = run_chain(
|
120 |
+
chain,
|
121 |
+
game_description=game_description,
|
122 |
+
state_description=state_description,
|
123 |
+
goal_description=goal_description,
|
124 |
+
action_description=action_description,
|
125 |
+
)
|
126 |
+
total_tokens = cb.total_tokens
|
127 |
+
total_cost = cb.total_cost
|
128 |
+
_response = get_last_n_lines(response, 2)
|
129 |
+
|
130 |
+
|
131 |
+
action = self.parser.parse(_response).action
|
132 |
+
|
133 |
+
text_prompt = chat_prompt.format_messages(
|
134 |
+
game_description=game_description,
|
135 |
+
state_description=state_description,
|
136 |
+
goal_description=goal_description,
|
137 |
+
action_description=action_description,
|
138 |
+
)
|
139 |
+
texts = ""
|
140 |
+
for text in text_prompt:
|
141 |
+
texts += text.content + "\n"
|
142 |
+
|
143 |
+
self._add_history_after_action(action)
|
144 |
+
self.logger.info(f'The GPT response is: {response}.')
|
145 |
+
self.logger.info(f'The optimal action is: {action}.')
|
146 |
+
if env_info.get('history'):
|
147 |
+
self.logger.info(f'History: {history_to_str(env_info["history"])}')
|
148 |
+
|
149 |
+
return action, texts, response, total_tokens, total_cost
|
deciders/parser.py
ADDED
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pydantic import BaseModel, Field, validator
|
2 |
+
|
3 |
+
# Define your desired data structure.
|
4 |
+
class TwoAction(BaseModel):
|
5 |
+
action: int = Field(description="the choosed action to perform")
|
6 |
+
|
7 |
+
# You can add custom validation logic easily with Pydantic.
|
8 |
+
@validator('action')
|
9 |
+
def action_is_valid(cls, field):
|
10 |
+
if field not in [1, 2]:
|
11 |
+
raise ValueError("Action is not valid ([1, 2])!")
|
12 |
+
return field
|
13 |
+
|
14 |
+
class ThreeAction(BaseModel):
|
15 |
+
action: int = Field(description="the choosed action to perform")
|
16 |
+
|
17 |
+
# You can add custom validation logic easily with Pydantic.
|
18 |
+
@validator('action')
|
19 |
+
def action_is_valid(cls, field):
|
20 |
+
if field not in [1, 2, 3]:
|
21 |
+
raise ValueError("Action is not valid ([1, 2, 3])!")
|
22 |
+
return field
|
23 |
+
|
24 |
+
class FourAction(BaseModel):
|
25 |
+
action: int = Field(description="the choosed action to perform")
|
26 |
+
|
27 |
+
# You can add custom validation logic easily with Pydantic.
|
28 |
+
@validator('action')
|
29 |
+
def action_is_valid(cls, field):
|
30 |
+
if field not in [1, 2, 3, 4]:
|
31 |
+
raise ValueError("Action is not valid ([1, 2, 3, 4])!")
|
32 |
+
return field
|
33 |
+
|
34 |
+
class SixAction(BaseModel):
|
35 |
+
action: int = Field(description="the choosed action to perform")
|
36 |
+
|
37 |
+
# You can add custom validation logic easily with Pydantic.
|
38 |
+
@validator('action')
|
39 |
+
def action_is_valid(cls, field):
|
40 |
+
if field not in [1, 2, 3, 4, 5, 6]:
|
41 |
+
raise ValueError("Action is not valid ([1, 2, 3, 4, 5, 6])!")
|
42 |
+
return field
|
43 |
+
|
44 |
+
class ContinuousAction(BaseModel):
|
45 |
+
action: float = Field(description="the choosed action to perform")
|
46 |
+
# You can add custom validation logic easily with Pydantic.
|
47 |
+
@validator('action')
|
48 |
+
def action_is_valid(cls, field):
|
49 |
+
if not (field >= -1 and field <= 1):
|
50 |
+
raise ValueError("Action is not valid ([-1,1])!")
|
51 |
+
return field
|
52 |
+
|
53 |
+
PARSERS = {1:ContinuousAction, 2: TwoAction, 3: ThreeAction, 4: FourAction, 6: SixAction}
|
deciders/reflexion.py
ADDED
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import openai
|
2 |
+
from .misc import history_to_str
|
3 |
+
from langchain.chat_models import AzureChatOpenAI
|
4 |
+
from langchain.prompts.chat import (
|
5 |
+
PromptTemplate,
|
6 |
+
ChatPromptTemplate,
|
7 |
+
SystemMessagePromptTemplate,
|
8 |
+
HumanMessagePromptTemplate,
|
9 |
+
)
|
10 |
+
from langchain.prompts.few_shot import FewShotPromptTemplate
|
11 |
+
from langchain import LLMChain
|
12 |
+
from loguru import logger
|
13 |
+
from langchain.callbacks import FileCallbackHandler
|
14 |
+
from langchain.callbacks import get_openai_callback
|
15 |
+
from .act import NaiveAct
|
16 |
+
from memory.env_history import EnvironmentHistory
|
17 |
+
import tiktoken
|
18 |
+
from .utils import run_chain
|
19 |
+
|
20 |
+
|
21 |
+
class Reflexion(NaiveAct):
|
22 |
+
def __init__(self, action_space, args, prompts, distiller, temperature=0.1, max_tokens=None, logger=None):
|
23 |
+
super().__init__(action_space, args, prompts, distiller, temperature, max_tokens, logger)
|
24 |
+
|
25 |
+
def num_tokens_from_string(self,string: str) -> int:
|
26 |
+
"""Returns the number of tokens in a text string."""
|
27 |
+
num_tokens = len(self.encoding.encode(string))
|
28 |
+
return num_tokens
|
29 |
+
|
30 |
+
def update_mem(self,):
|
31 |
+
traj = self.game_description
|
32 |
+
traj += self.goal_description
|
33 |
+
one_history_token = self.num_tokens_from_string(self.env_history.get_one_history())
|
34 |
+
history_num = 4000 // one_history_token
|
35 |
+
traj += self.env_history.get_histories_with_last(history_num)
|
36 |
+
self._update_mem(traj)
|
37 |
+
|
38 |
+
def _update_mem(self, traj):
|
39 |
+
my_reflection = self.distiller.generate(traj, self.memory)
|
40 |
+
self.memory.append(my_reflection)
|
41 |
+
self.env_history.reset()
|
42 |
+
|
43 |
+
def act(
|
44 |
+
self,
|
45 |
+
state_description,
|
46 |
+
action_description,
|
47 |
+
env_info,
|
48 |
+
game_description,
|
49 |
+
goal_description,
|
50 |
+
logfile=None,
|
51 |
+
):
|
52 |
+
self.action_description = action_description
|
53 |
+
self.game_description = game_description
|
54 |
+
self.goal_description = goal_description
|
55 |
+
self.env_history.add("observation", state_description)
|
56 |
+
chat = AzureChatOpenAI(
|
57 |
+
openai_api_type=openai.api_type,
|
58 |
+
openai_api_version=openai.api_version,
|
59 |
+
openai_api_base=openai.api_base,
|
60 |
+
openai_api_key=openai.api_key,
|
61 |
+
deployment_name=self.args.gpt_version,
|
62 |
+
temperature=self.temperature,
|
63 |
+
max_tokens=self.max_tokens,
|
64 |
+
)
|
65 |
+
suffix_flag = False
|
66 |
+
reply_format_description = \
|
67 |
+
"Your response should choose an optimal action from a valid action list and terminate with the following format: "
|
68 |
+
|
69 |
+
# System Message
|
70 |
+
human_template = "Now, you are completing a challenging task. You must carefully understand the Reflexion method you will use and apply it to the following task.\n"
|
71 |
+
|
72 |
+
# task-irrelevant SystemMessage
|
73 |
+
if self.irr_few_shot_examples:
|
74 |
+
human_template += 'In the following example, I shall present a set of question and answer about the Reflexion method. Please adhere to the format and reasoning of the provided response when addressing the subsequent task.\n'
|
75 |
+
for i, examples in enumerate(self.irr_few_shot_examples):
|
76 |
+
human_template += f"\nExample {i+1}:\n"
|
77 |
+
human_template += "Question: \n" + examples['question'] + "\nAnswer: \n" + examples['answer']
|
78 |
+
|
79 |
+
# task-irrelevant few shot if have
|
80 |
+
if self.irr_few_shot_examples:
|
81 |
+
human_template += "\nMoving forward, I will describe the task, the goal, and the actions you may execute. Please pay close attention to comprehend the information presented below.\n"
|
82 |
+
|
83 |
+
if self.fewshot_example:
|
84 |
+
human_template += "I will describe the task, the goal, and the actions you may execute. Please pay close attention to comprehend the information presented below."
|
85 |
+
# print(fewshot_example_prompt.format(**fewshot_examples[0]))
|
86 |
+
human_template += '\nTask Description: {game_description} \n'
|
87 |
+
human_template += 'Goal Description: {goal_description}\n'
|
88 |
+
human_template += 'Actions Description: {action_description}\n'
|
89 |
+
|
90 |
+
if self.fewshot_example:
|
91 |
+
human_template += "Here, I will provide you with some guidance to help you better understand the rules of the task. Next are some examples: "
|
92 |
+
for i, examples in enumerate(self.fewshot_example):
|
93 |
+
human_template += f"\nExample {i+1}:\n"
|
94 |
+
human_template += "Question: \n" + examples['question'] + "\nAnswer: \n" + examples['answer']
|
95 |
+
|
96 |
+
if self.prompt_level in [2, 3, 4]:
|
97 |
+
if self.memory:
|
98 |
+
human_template += '\nSubsequently, I will offer pertinent guidance or information about the task. Please utilize this instruction to accomplish the given task effectively.\n'
|
99 |
+
suffix_flag = True
|
100 |
+
if self.prompt_level == 2:
|
101 |
+
human_template += 'I have collected a few trajectories from a random policy, and the summaries are listed below.'
|
102 |
+
elif self.prompt_level == 3:
|
103 |
+
human_template += 'I have collected a few trajectories before, and the summaries are listed below.'
|
104 |
+
elif self.prompt_level == 4:
|
105 |
+
human_template += 'I have collected a few trajectories from an expert policy, and the summaries are listed below.'
|
106 |
+
human_template += self._read_mem() + "\n"
|
107 |
+
|
108 |
+
if self.use_short_mem:
|
109 |
+
if len(self.env_history) > 1:
|
110 |
+
if not suffix_flag:
|
111 |
+
human_template += '\nSubsequently, I will offer pertinent guidance or information about the task. Please utilize this instruction to accomplish the given task effectively.'
|
112 |
+
human_template += f"\nBelow are the latest {self.mem_num} historical data entries:\n"
|
113 |
+
human_template += f"{self.env_history.get_histories(self.mem_num)}"
|
114 |
+
human_template += '\nNext is the observation that the agent gets:\nCurrent {state_description}\n'
|
115 |
+
human_template += 'Please select an action based on the current game state and the information you get. You must select the appropriate action from the given action descriptions and cannot refrain from taking action or performing any prohibited actions. Here is the action description below:\n{action_description}\n'
|
116 |
+
human_template += 'Also, please keep in mind not to answer with any redundant and irrelevant content.\n'
|
117 |
+
human_template += "Finally, you also need to normalize your output according to the reply format description.\n"
|
118 |
+
human_template += 'Reply format description: {reply_format_description}{format_instructions}\n'
|
119 |
+
|
120 |
+
human_message_prompt = PromptTemplate(
|
121 |
+
template=human_template,
|
122 |
+
input_variables=[
|
123 |
+
'state_description', 'goal_description', 'game_description',
|
124 |
+
'action_description', 'reply_format_description'],
|
125 |
+
partial_variables={'format_instructions': self.parser.get_format_instructions()}
|
126 |
+
)
|
127 |
+
|
128 |
+
human_message_prompt = HumanMessagePromptTemplate(prompt=human_message_prompt)
|
129 |
+
|
130 |
+
chat_prompt = ChatPromptTemplate.from_messages([human_message_prompt])
|
131 |
+
if not self.logger:
|
132 |
+
# logger.remove()
|
133 |
+
if self.first_call:
|
134 |
+
self.logger = logger.add(logfile, colorize=True, enqueue=True, filter=lambda x: '[Reflexion Memory]' not in x['message'])
|
135 |
+
self.first_call = False
|
136 |
+
handler = FileCallbackHandler(logfile)
|
137 |
+
total_tokens, total_cost = 0, 0
|
138 |
+
max_think_times = 1
|
139 |
+
# TODO: ADD REACT Support
|
140 |
+
# print(str(self.env_history))
|
141 |
+
|
142 |
+
for i_think in range(max_think_times):
|
143 |
+
chain = LLMChain(llm=chat, prompt=chat_prompt, callbacks=[handler], verbose=False)
|
144 |
+
with get_openai_callback() as cb:
|
145 |
+
response = run_chain(
|
146 |
+
chain,
|
147 |
+
state_description=self.env_history.get_last_history(),
|
148 |
+
game_description=game_description,
|
149 |
+
goal_description=goal_description,
|
150 |
+
action_description=action_description,
|
151 |
+
format_instructions=self.parser.get_format_instructions(),
|
152 |
+
reply_format_description=reply_format_description,
|
153 |
+
max_token = 3000
|
154 |
+
)
|
155 |
+
|
156 |
+
total_tokens += cb.total_tokens
|
157 |
+
total_cost += cb.total_cost
|
158 |
+
action = self.parser.parse(response).action
|
159 |
+
text_prompt = chat_prompt.format_messages(
|
160 |
+
state_description=self.env_history.get_last_history(),
|
161 |
+
game_description=game_description,
|
162 |
+
goal_description=goal_description,
|
163 |
+
action_description=action_description,
|
164 |
+
format_instructions=self.parser.get_format_instructions(),
|
165 |
+
reply_format_description=reply_format_description,
|
166 |
+
)
|
167 |
+
texts = ""
|
168 |
+
for text in text_prompt:
|
169 |
+
texts += text.content + "\n"
|
170 |
+
|
171 |
+
self._add_history_after_action(action)
|
172 |
+
self.logger.info(f'The GPT response is: {response}.')
|
173 |
+
self.logger.info(f'The optimal action is: {action}.')
|
174 |
+
if self.memory:
|
175 |
+
self.logger.info(f'The memory is: {self.memory[-1]}.')
|
176 |
+
if env_info.get('history'):
|
177 |
+
self.logger.info(f'History: {history_to_str(env_info["history"])}')
|
178 |
+
|
179 |
+
return action, texts, response, total_tokens, total_cost
|
deciders/self_consistency.py
ADDED
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import openai
|
2 |
+
from .misc import history_to_str
|
3 |
+
from langchain.chat_models import AzureChatOpenAI
|
4 |
+
from langchain.prompts.chat import (
|
5 |
+
PromptTemplate,
|
6 |
+
ChatPromptTemplate,
|
7 |
+
SystemMessagePromptTemplate,
|
8 |
+
HumanMessagePromptTemplate,
|
9 |
+
)
|
10 |
+
from langchain.prompts.few_shot import FewShotPromptTemplate
|
11 |
+
from langchain import LLMChain
|
12 |
+
from loguru import logger
|
13 |
+
from langchain.callbacks import FileCallbackHandler
|
14 |
+
from langchain.callbacks import get_openai_callback
|
15 |
+
from .act import NaiveAct
|
16 |
+
from .utils import run_chain
|
17 |
+
|
18 |
+
|
19 |
+
class SelfConsistency(NaiveAct):
|
20 |
+
def __init__(self, action_space, args, prompts, distiller, temperature=0.1, max_tokens=None, logger=None):
|
21 |
+
temperature = 0.7
|
22 |
+
super().__init__(action_space, args, prompts, distiller, temperature, max_tokens, logger)
|
23 |
+
self.temperature = temperature
|
24 |
+
|
25 |
+
def act(
|
26 |
+
self,
|
27 |
+
state_description,
|
28 |
+
action_description,
|
29 |
+
env_info,
|
30 |
+
game_description,
|
31 |
+
goal_description,
|
32 |
+
logfile=None,
|
33 |
+
):
|
34 |
+
# print(self.temperature)
|
35 |
+
self.action_description = action_description
|
36 |
+
self._add_history_before_action(game_description, goal_description, state_description)
|
37 |
+
chat = AzureChatOpenAI(
|
38 |
+
openai_api_type=openai.api_type,
|
39 |
+
openai_api_version=openai.api_version,
|
40 |
+
openai_api_base=openai.api_base,
|
41 |
+
openai_api_key=openai.api_key,
|
42 |
+
deployment_name=self.args.gpt_version,
|
43 |
+
temperature=self.temperature,
|
44 |
+
max_tokens=self.max_tokens
|
45 |
+
)
|
46 |
+
|
47 |
+
suffix_flag = False
|
48 |
+
reply_format_description = \
|
49 |
+
"Your response should choose an optimal action from a valid action list and terminate with the following format: "
|
50 |
+
|
51 |
+
# System Message
|
52 |
+
human_template = "Now, you are completing a challenging task. You must carefully understand the Self-Consistency method you will use and apply it to the following task.\n"
|
53 |
+
|
54 |
+
# task-irrelevant SystemMessage
|
55 |
+
if self.irr_few_shot_examples:
|
56 |
+
human_template += 'In the following example, I shall present a set of question and answer with the Self-Consistency method. Please adhere to the format and reasoning of the provided response when addressing the subsequent task.\n'
|
57 |
+
for i, examples in enumerate(self.irr_few_shot_examples):
|
58 |
+
human_template += f"\nExample {i+1}:\n"
|
59 |
+
human_template += "Question: \n" + examples['question'] + "\nAnswer: \n" + examples['answer']
|
60 |
+
|
61 |
+
# task-irrelevant few shot if have
|
62 |
+
if self.irr_few_shot_examples:
|
63 |
+
human_template += "\nMoving forward, I will describe the task, the goal, and the actions you may execute. Please pay close attention to comprehend the information presented below.\n"
|
64 |
+
|
65 |
+
if self.fewshot_example:
|
66 |
+
human_template += "I will describe the task, the goal, and the actions you may execute. Please pay close attention to comprehend the information presented below."
|
67 |
+
# print(fewshot_example_prompt.format(**fewshot_examples[0]))
|
68 |
+
human_template += '\nTask Description: {game_description} \n'
|
69 |
+
human_template += 'Goal Description: {goal_description}\n'
|
70 |
+
human_template += 'Actions Description: {action_description}\n'
|
71 |
+
|
72 |
+
if self.fewshot_example:
|
73 |
+
human_template += "Here, I will provide you with some guidance to help you better understand the rules of the task. Next are some examples: "
|
74 |
+
for i, examples in enumerate(self.fewshot_example):
|
75 |
+
human_template += f"\nExample {i+1}:\n"
|
76 |
+
human_template += "Question: \n" + examples['question'] + "\nAnswer: \n" + examples['answer']
|
77 |
+
|
78 |
+
if self.prompt_level in [2, 3, 4]:
|
79 |
+
if self.memory:
|
80 |
+
human_template += '\nSubsequently, I will offer pertinent guidance or information about the task. Please utilize this instruction to accomplish the given task effectively.\n'
|
81 |
+
suffix_flag = True
|
82 |
+
if self.prompt_level == 2:
|
83 |
+
human_template += 'I have collected a few trajectories from a random policy, and the summaries are listed below.'
|
84 |
+
elif self.prompt_level == 3:
|
85 |
+
human_template += 'I have collected a few trajectories before, and the summaries are listed below.'
|
86 |
+
elif self.prompt_level == 4:
|
87 |
+
human_template += 'I have collected a few trajectories from an expert policy, and the summaries are listed below.'
|
88 |
+
human_template += self._read_mem() + "\n"
|
89 |
+
|
90 |
+
if self.use_short_mem:
|
91 |
+
if len(self.env_history) > 1:
|
92 |
+
if not suffix_flag:
|
93 |
+
human_template += '\nSubsequently, I will offer pertinent guidance or information about the task. Please utilize this instruction to accomplish the given task effectively.'
|
94 |
+
human_template += f"\nBelow are the latest {self.args.short_mem_num} historical data entries:\n"
|
95 |
+
human_template += f"{self.env_history.get_histories(self.mem_num)}"
|
96 |
+
human_template += '\nNext is the observation that the agent gets:\nCurrent {state_description}\n'
|
97 |
+
human_template += 'Please select an action based on the current game state and the information you get. You must select the appropriate action from the given action descriptions and cannot refrain from taking action or performing any prohibited actions. Here is the action description below:\n{action_description}\n'
|
98 |
+
human_template += 'Please note that you need to carefully lay out your thought process on the question, not just give an answer. You need to write the corresponding logic of your thinking following the example above. Also, please keep in mind not to answer with any redundant and irrelevant content.\n'
|
99 |
+
human_template += "Finally, you also need to normalize your output according to the reply format description.\n"
|
100 |
+
human_template += 'Reply format description: {reply_format_description}{format_instructions}\n'
|
101 |
+
|
102 |
+
human_message_prompt = PromptTemplate(
|
103 |
+
template=human_template,
|
104 |
+
input_variables=[
|
105 |
+
'state_description', 'goal_description', 'game_description',
|
106 |
+
'action_description', 'reply_format_description'],
|
107 |
+
partial_variables={'format_instructions': self.parser.get_format_instructions()}
|
108 |
+
)
|
109 |
+
|
110 |
+
human_message_prompt = HumanMessagePromptTemplate(prompt=human_message_prompt)
|
111 |
+
|
112 |
+
chat_prompt = ChatPromptTemplate.from_messages([human_message_prompt])
|
113 |
+
|
114 |
+
if not self.logger:
|
115 |
+
logger.remove()
|
116 |
+
self.logger = logger.add(logfile, colorize=True, enqueue=True)
|
117 |
+
handler = FileCallbackHandler(logfile)
|
118 |
+
|
119 |
+
chain = LLMChain(llm=chat, prompt=chat_prompt, callbacks=[handler], verbose=False)
|
120 |
+
|
121 |
+
text_prompt = chat_prompt.format_messages(
|
122 |
+
game_description=game_description,
|
123 |
+
state_description=state_description,
|
124 |
+
goal_description=goal_description,
|
125 |
+
action_description=action_description,
|
126 |
+
reply_format_description=reply_format_description
|
127 |
+
)
|
128 |
+
texts = ""
|
129 |
+
for text in text_prompt:
|
130 |
+
texts += text.content + "\n"
|
131 |
+
|
132 |
+
actions = []
|
133 |
+
response_dict = {}
|
134 |
+
error_flag = True
|
135 |
+
for i in range(5):
|
136 |
+
try:
|
137 |
+
with get_openai_callback() as cb:
|
138 |
+
response = run_chain(
|
139 |
+
chain,
|
140 |
+
game_description=game_description,
|
141 |
+
state_description=state_description,
|
142 |
+
goal_description=goal_description,
|
143 |
+
action_description=action_description,
|
144 |
+
reply_format_description=reply_format_description
|
145 |
+
)
|
146 |
+
total_tokens = cb.total_tokens
|
147 |
+
total_cost = cb.total_cost
|
148 |
+
action = self.parser.parse(response).action
|
149 |
+
actions.append(action)
|
150 |
+
response_dict[action] = response
|
151 |
+
|
152 |
+
self.logger.info(f'The GPT response is: {response}.')
|
153 |
+
self.logger.info(f'The optimal action is: {action}.\n')
|
154 |
+
except:
|
155 |
+
continue
|
156 |
+
|
157 |
+
action = max(set(actions), key=actions.count)
|
158 |
+
# print(actions)
|
159 |
+
# print(action)
|
160 |
+
if actions:
|
161 |
+
self._add_history_after_action(action)
|
162 |
+
self.logger.info(f'The action list is: {actions}.')
|
163 |
+
self.logger.info(f'The GPT response is: {response_dict[action]}.')
|
164 |
+
self.logger.info(f'The optimal action is: {action}.')
|
165 |
+
if env_info.get('history'):
|
166 |
+
self.logger.info(f'History: {history_to_str(env_info["history"])}')
|
167 |
+
else:
|
168 |
+
raise Exception("No valid Actions!")
|
169 |
+
|
170 |
+
return action, texts, response, total_tokens, total_cost
|
deciders/selfask.py
ADDED
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import openai
|
2 |
+
from .misc import history_to_str
|
3 |
+
from langchain.chat_models import AzureChatOpenAI
|
4 |
+
from langchain.prompts.chat import (
|
5 |
+
PromptTemplate,
|
6 |
+
ChatPromptTemplate,
|
7 |
+
SystemMessagePromptTemplate,
|
8 |
+
HumanMessagePromptTemplate,
|
9 |
+
)
|
10 |
+
from langchain.prompts.few_shot import FewShotPromptTemplate
|
11 |
+
from langchain import LLMChain
|
12 |
+
from loguru import logger
|
13 |
+
from langchain.callbacks import FileCallbackHandler
|
14 |
+
from langchain.callbacks import get_openai_callback
|
15 |
+
from .act import NaiveAct
|
16 |
+
from .utils import run_chain
|
17 |
+
|
18 |
+
|
19 |
+
class SelfAskAct(NaiveAct):
|
20 |
+
def __init__(self, action_space, args, prompts, distiller, temperature=0.1, max_tokens=None, logger=None):
|
21 |
+
super().__init__(action_space, args, prompts, distiller, temperature, max_tokens,logger)
|
22 |
+
|
23 |
+
def act(
|
24 |
+
self,
|
25 |
+
state_description,
|
26 |
+
action_description,
|
27 |
+
env_info,
|
28 |
+
game_description,
|
29 |
+
goal_description,
|
30 |
+
logfile=None,
|
31 |
+
):
|
32 |
+
self.action_description = action_description
|
33 |
+
self._add_history_before_action(game_description, goal_description, state_description)
|
34 |
+
chat = AzureChatOpenAI(
|
35 |
+
openai_api_type=openai.api_type,
|
36 |
+
openai_api_version=openai.api_version,
|
37 |
+
openai_api_base=openai.api_base,
|
38 |
+
openai_api_key=openai.api_key,
|
39 |
+
deployment_name=self.args.gpt_version,
|
40 |
+
temperature=self.temperature,
|
41 |
+
max_tokens=self.max_tokens
|
42 |
+
)
|
43 |
+
|
44 |
+
suffix_flag = False
|
45 |
+
reply_format_description = \
|
46 |
+
"Your response should choose an optimal action from a valid action list and terminate with the following format: "
|
47 |
+
|
48 |
+
# System Message
|
49 |
+
human_template = "Now, you are completing a challenging task. You must carefully understand the self-ask method you will use and apply it to the following task.\n"
|
50 |
+
|
51 |
+
# task-irrelevant SystemMessage
|
52 |
+
if self.irr_few_shot_examples:
|
53 |
+
human_template += 'In the following example, I shall present a set of question and answer with the self-ask method. Please adhere to the format and reasoning of the provided response when addressing the subsequent task.\n'
|
54 |
+
for i, examples in enumerate(self.irr_few_shot_examples):
|
55 |
+
human_template += f"\nExample {i+1}:\n"
|
56 |
+
human_template += "Question: \n" + examples['question'] + "\nAnswer: \n" + examples['answer']
|
57 |
+
|
58 |
+
# task-irrelevant few shot if have
|
59 |
+
if self.irr_few_shot_examples:
|
60 |
+
human_template += "\nMoving forward, I will describe the task, the goal, and the actions you may execute. Please pay close attention to comprehend the information presented below.\n"
|
61 |
+
|
62 |
+
if self.fewshot_example:
|
63 |
+
human_template += "I will describe the task, the goal, and the actions you may execute. Please pay close attention to comprehend the information presented below."
|
64 |
+
# print(fewshot_example_prompt.format(**fewshot_examples[0]))
|
65 |
+
human_template += '\nTask Description: {game_description} \n'
|
66 |
+
human_template += 'Goal Description: {goal_description}\n'
|
67 |
+
human_template += 'Actions Description: {action_description}\n'
|
68 |
+
|
69 |
+
if self.fewshot_example:
|
70 |
+
human_template += "Here, I will provide you with some guidance to help you better understand the rules of the task. Next are some examples: "
|
71 |
+
for i, examples in enumerate(self.fewshot_example):
|
72 |
+
human_template += f"\nExample {i+1}:\n"
|
73 |
+
human_template += "Question: \n" + examples['question'] + "\nAnswer: \n" + examples['answer']
|
74 |
+
|
75 |
+
if self.prompt_level in [2, 3, 4]:
|
76 |
+
if self.memory:
|
77 |
+
human_template += '\nSubsequently, I will offer pertinent guidance or information about the task. Please utilize this instruction to accomplish the given task effectively.\n'
|
78 |
+
suffix_flag = True
|
79 |
+
if self.prompt_level == 2:
|
80 |
+
human_template += 'I have collected a few trajectories from a random policy, and the summaries are listed below.'
|
81 |
+
elif self.prompt_level == 3:
|
82 |
+
human_template += 'I have collected a few trajectories before, and the summaries are listed below.'
|
83 |
+
elif self.prompt_level == 4:
|
84 |
+
human_template += 'I have collected a few trajectories from an expert policy, and the summaries are listed below.'
|
85 |
+
human_template += self._read_mem() + "\n"
|
86 |
+
|
87 |
+
if self.use_short_mem:
|
88 |
+
if len(self.env_history) > 1:
|
89 |
+
if not suffix_flag:
|
90 |
+
human_template += '\nSubsequently, I will offer pertinent guidance or information about the task. Please utilize this instruction to accomplish the given task effectively.'
|
91 |
+
human_template += f"\nBelow are the latest {self.args.short_mem_num} historical data entries:\n"
|
92 |
+
human_template += f"{self.env_history.get_histories(self.mem_num)}"
|
93 |
+
human_template += '\nNext is the observation that the agent gets:\nCurrent {state_description}\n'
|
94 |
+
human_template += 'Please select an action based on the current game state and the information you get. You must select the appropriate action from the given action descriptions and cannot refrain from taking action or performing any prohibited actions. Here is the action description below:\n{action_description}\n'
|
95 |
+
human_template += 'You must utilize a multi-turn dialogue approach, just as the format illustrated in the example above(like "Follow up" and "Intermediate answer"). And you need to write down the thought process during the self-ask process. Also, please keep in mind not to answer with any redundant and irrelevant content.\n'
|
96 |
+
human_template += "Finally, you also need to normalize your output according to the reply format description.\n"
|
97 |
+
human_template += 'Reply format description: {reply_format_description}{format_instructions}\n'
|
98 |
+
|
99 |
+
human_message_prompt = PromptTemplate(
|
100 |
+
template=human_template,
|
101 |
+
input_variables=[
|
102 |
+
'state_description', 'goal_description', 'game_description',
|
103 |
+
'action_description', 'reply_format_description'],
|
104 |
+
partial_variables={'format_instructions': self.parser.get_format_instructions()}
|
105 |
+
)
|
106 |
+
|
107 |
+
human_message_prompt = HumanMessagePromptTemplate(prompt=human_message_prompt)
|
108 |
+
|
109 |
+
chat_prompt = ChatPromptTemplate.from_messages([human_message_prompt])
|
110 |
+
|
111 |
+
if not self.logger:
|
112 |
+
logger.remove()
|
113 |
+
self.logger = logger.add(logfile, colorize=True, enqueue=True)
|
114 |
+
handler = FileCallbackHandler(logfile)
|
115 |
+
|
116 |
+
chain = LLMChain(
|
117 |
+
llm=chat, prompt=chat_prompt, callbacks=[handler], verbose=False)
|
118 |
+
|
119 |
+
with get_openai_callback() as cb:
|
120 |
+
response = run_chain(
|
121 |
+
chain,
|
122 |
+
game_description=game_description,
|
123 |
+
state_description=state_description,
|
124 |
+
goal_description=goal_description,
|
125 |
+
action_description=action_description,
|
126 |
+
reply_format_description=reply_format_description
|
127 |
+
)
|
128 |
+
total_tokens = cb.total_tokens
|
129 |
+
total_cost = cb.total_cost
|
130 |
+
action = self.parser.parse(response).action
|
131 |
+
|
132 |
+
text_prompt = chat_prompt.format_messages(
|
133 |
+
game_description=game_description,
|
134 |
+
state_description=state_description,
|
135 |
+
goal_description=goal_description,
|
136 |
+
action_description=action_description,
|
137 |
+
reply_format_description=reply_format_description
|
138 |
+
)
|
139 |
+
texts = ""
|
140 |
+
for text in text_prompt:
|
141 |
+
texts += text.content + "\n"
|
142 |
+
|
143 |
+
self._add_history_after_action(action)
|
144 |
+
|
145 |
+
self.logger.info(f'The GPT response is: {response}.')
|
146 |
+
self.logger.info(f'The optimal action is: {action}.')
|
147 |
+
if env_info.get('history'):
|
148 |
+
self.logger.info(f'History: {history_to_str(env_info["history"])}')
|
149 |
+
|
150 |
+
return action, texts, response, total_tokens, total_cost
|
deciders/spp.py
ADDED
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import openai
|
2 |
+
from .misc import history_to_str
|
3 |
+
from langchain.chat_models import AzureChatOpenAI
|
4 |
+
from langchain.prompts.chat import (
|
5 |
+
PromptTemplate,
|
6 |
+
ChatPromptTemplate,
|
7 |
+
SystemMessagePromptTemplate,
|
8 |
+
HumanMessagePromptTemplate,
|
9 |
+
)
|
10 |
+
from langchain.prompts.few_shot import FewShotPromptTemplate
|
11 |
+
from langchain import LLMChain
|
12 |
+
from loguru import logger
|
13 |
+
from langchain.callbacks import FileCallbackHandler
|
14 |
+
from langchain.callbacks import get_openai_callback
|
15 |
+
from .act import NaiveAct
|
16 |
+
from .utils import run_chain
|
17 |
+
|
18 |
+
class SPP(NaiveAct):
|
19 |
+
def __init__(self, action_space, args, prompts, distiller, temperature=0.1, max_tokens=None, logger=None):
|
20 |
+
super().__init__(action_space, args, prompts, distiller, temperature, max_tokens, logger)
|
21 |
+
|
22 |
+
def act(
|
23 |
+
self,
|
24 |
+
state_description,
|
25 |
+
action_description,
|
26 |
+
env_info,
|
27 |
+
game_description,
|
28 |
+
goal_description,
|
29 |
+
logfile=None,
|
30 |
+
):
|
31 |
+
self.action_description = action_description
|
32 |
+
self._add_history_before_action(game_description, goal_description, state_description)
|
33 |
+
chat = AzureChatOpenAI(
|
34 |
+
openai_api_type=openai.api_type,
|
35 |
+
openai_api_version=openai.api_version,
|
36 |
+
openai_api_base=openai.api_base,
|
37 |
+
openai_api_key=openai.api_key,
|
38 |
+
deployment_name=self.args.gpt_version,
|
39 |
+
temperature=self.temperature,
|
40 |
+
max_tokens=self.max_tokens
|
41 |
+
)
|
42 |
+
|
43 |
+
self.fewshot_example = self.irr_few_shot_examples if not self.fewshot_example else self.fewshot_example
|
44 |
+
self.irr_few_shot_examples = self.irr_few_shot_examples if not self.fewshot_example else self.fewshot_example
|
45 |
+
suffix_flag = False
|
46 |
+
reply_format_description = \
|
47 |
+
"Your response should choose an optimal action from a valid action list and terminate with the following format: "
|
48 |
+
|
49 |
+
# System Message
|
50 |
+
human_template = "When faced with a task, begin by identifying the participants who will contribute to solving the task. Then, initiate a multi-round collaboration process until a final solution is reached. The participants will give critical comments and detailed suggestions whenever necessary.\n"
|
51 |
+
human_template += "Now, you are completing a challenging task. You must carefully understand the Solo-Performance-Prompting method you will use and apply it to the following task.\n"
|
52 |
+
|
53 |
+
# task-irrelevant SystemMessage
|
54 |
+
if self.irr_few_shot_examples:
|
55 |
+
human_template += 'In the following example, I shall present a set of question and answer with the Solo-Performance-Prompting method. Please adhere to the format and reasoning of the provided response when addressing the subsequent task.\n'
|
56 |
+
for i, examples in enumerate(self.irr_few_shot_examples):
|
57 |
+
human_template += f"\nExample {i+1}:\n"
|
58 |
+
human_template += "Question: \n" + examples['question'] + "\nAnswer: \n" + examples['answer']
|
59 |
+
|
60 |
+
# task-irrelevant few shot if have
|
61 |
+
if self.irr_few_shot_examples:
|
62 |
+
human_template += "\nMoving forward, I will describe the task, the goal, and the actions you may execute. Please pay close attention to comprehend the information presented below.\n"
|
63 |
+
|
64 |
+
human_template += '\nTask Description: {game_description} \n'
|
65 |
+
human_template += 'Goal Description: {goal_description}\n'
|
66 |
+
human_template += 'Actions Description: {action_description}\n'
|
67 |
+
|
68 |
+
if self.prompt_level in [2, 3, 4]:
|
69 |
+
if self.memory:
|
70 |
+
human_template += '\nSubsequently, I will offer pertinent guidance or information about the task. Please utilize this instruction to accomplish the given task effectively.\n'
|
71 |
+
suffix_flag = True
|
72 |
+
if self.prompt_level == 2:
|
73 |
+
human_template += 'I have collected a few trajectories from a random policy, and the summaries are listed below.'
|
74 |
+
elif self.prompt_level == 3:
|
75 |
+
human_template += 'I have collected a few trajectories before, and the summaries are listed below.'
|
76 |
+
elif self.prompt_level == 4:
|
77 |
+
human_template += 'I have collected a few trajectories from an expert policy, and the summaries are listed below.'
|
78 |
+
human_template += self._read_mem() + "\n"
|
79 |
+
|
80 |
+
if self.use_short_mem:
|
81 |
+
if len(self.env_history) > 1:
|
82 |
+
if not suffix_flag:
|
83 |
+
human_template += '\nSubsequently, I will offer pertinent guidance or information about the task. Please utilize this instruction to accomplish the given task effectively.'
|
84 |
+
human_template += f"\nBelow are the latest {self.args.short_mem_num} historical data entries:\n"
|
85 |
+
human_template += f"{self.env_history.get_histories(self.mem_num)}"
|
86 |
+
human_template += '\nNext is the observation that the agent gets:\nCurrent {state_description}\n'
|
87 |
+
human_template += 'Please select an action based on the current game state and the information you get. You must select the appropriate action from the given action descriptions and cannot refrain from taking action or performing any prohibited actions. Here is the action description below:\n{action_description}\n'
|
88 |
+
human_template += 'Please note that you need to carefully lay out the participants who will contribute to solving the task and initiate a multi-round collaboration process until a final solution is reached. Now, identify the participants and collaboratively solve the following task step by step.Also, please keep in mind not to answer with any redundant and irrelevant content.\n'
|
89 |
+
human_template += "Finally, you also need to normalize your output according to the reply format description.\n"
|
90 |
+
human_template += 'Reply format description: {reply_format_description}{format_instructions}\n'
|
91 |
+
|
92 |
+
human_message_prompt = PromptTemplate(
|
93 |
+
template=human_template,
|
94 |
+
input_variables=[
|
95 |
+
'state_description', 'goal_description', 'game_description',
|
96 |
+
'action_description', 'reply_format_description'],
|
97 |
+
partial_variables={'format_instructions': self.parser.get_format_instructions()}
|
98 |
+
)
|
99 |
+
|
100 |
+
human_message_prompt = HumanMessagePromptTemplate(prompt=human_message_prompt)
|
101 |
+
|
102 |
+
chat_prompt = ChatPromptTemplate.from_messages([human_message_prompt])
|
103 |
+
|
104 |
+
if not self.logger:
|
105 |
+
logger.remove()
|
106 |
+
self.logger = logger.add(logfile, colorize=True, enqueue=True)
|
107 |
+
handler = FileCallbackHandler(logfile)
|
108 |
+
|
109 |
+
chain = LLMChain(llm=chat, prompt=chat_prompt, callbacks=[handler], verbose=False)
|
110 |
+
|
111 |
+
with get_openai_callback() as cb:
|
112 |
+
response = run_chain(
|
113 |
+
chain,
|
114 |
+
game_description=game_description,
|
115 |
+
state_description=state_description,
|
116 |
+
goal_description=goal_description,
|
117 |
+
action_description=action_description,
|
118 |
+
reply_format_description=reply_format_description
|
119 |
+
)
|
120 |
+
total_tokens = cb.total_tokens
|
121 |
+
total_cost = cb.total_cost
|
122 |
+
action = self.parser.parse(response).action
|
123 |
+
|
124 |
+
text_prompt = chat_prompt.format_messages(
|
125 |
+
game_description=game_description,
|
126 |
+
state_description=state_description,
|
127 |
+
goal_description=goal_description,
|
128 |
+
action_description=action_description,
|
129 |
+
reply_format_description=reply_format_description
|
130 |
+
)
|
131 |
+
texts = ""
|
132 |
+
for text in text_prompt:
|
133 |
+
texts += text.content + "\n"
|
134 |
+
|
135 |
+
self._add_history_after_action(action)
|
136 |
+
|
137 |
+
self.logger.info(f'The GPT response is: {response}.')
|
138 |
+
self.logger.info(f'The optimal action is: {action}.')
|
139 |
+
if env_info.get('history'):
|
140 |
+
self.logger.info(f'History: {history_to_str(env_info["history"])}')
|
141 |
+
|
142 |
+
return action, texts, response, total_tokens, total_cost
|
deciders/utils.py
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
import openai
|
4 |
+
from openai import OpenAI
|
5 |
+
from tenacity import (
|
6 |
+
retry,
|
7 |
+
stop_after_attempt, # type: ignore
|
8 |
+
wait_random_exponential, # type: ignore
|
9 |
+
)
|
10 |
+
|
11 |
+
from typing import Optional, List
|
12 |
+
if sys.version_info >= (3, 8):
|
13 |
+
from typing import Literal
|
14 |
+
else:
|
15 |
+
from typing_extensions import Literal
|
16 |
+
|
17 |
+
|
18 |
+
Model = Literal["gpt-4", "gpt-35-turbo", "text-davinci-003"]
|
19 |
+
|
20 |
+
from .gpt import gpt
|
21 |
+
gpt().__init__()
|
22 |
+
|
23 |
+
import timeout_decorator
|
24 |
+
@timeout_decorator.timeout(30)
|
25 |
+
def run_chain(chain, *args, **kwargs):
|
26 |
+
return chain.run(*args, **kwargs)
|
27 |
+
|
28 |
+
@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
|
29 |
+
def get_completion(prompt: str, engine: str = "gpt-35-turbo", temperature: float = 0.0, max_tokens: int = 256, stop_strs: Optional[List[str]] = None) -> str:
|
30 |
+
|
31 |
+
client = OpenAI(api_key=openai.api_key)
|
32 |
+
response = client.chat.completions.create(
|
33 |
+
model=engine,
|
34 |
+
prompt=prompt,
|
35 |
+
temperature=temperature,
|
36 |
+
max_tokens=max_tokens,
|
37 |
+
top_p=1,
|
38 |
+
frequency_penalty=0.0,
|
39 |
+
presence_penalty=0.0,
|
40 |
+
stop=stop_strs,
|
41 |
+
# request_timeout = 1
|
42 |
+
)
|
43 |
+
return response.choices[0].text
|
44 |
+
|
45 |
+
# @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
|
46 |
+
def get_chat(prompt: str, model: str = "gpt-35-turbo", engine: str = "gpt-35-turbo", temperature: float = 0.0, max_tokens: int = 256, stop_strs: Optional[List[str]] = None, is_batched: bool = False) -> str:
|
47 |
+
assert model != "text-davinci-003"
|
48 |
+
messages = [
|
49 |
+
{
|
50 |
+
"role": "user",
|
51 |
+
"content": prompt
|
52 |
+
}
|
53 |
+
]
|
54 |
+
# import pdb;pdb.set_trace()
|
55 |
+
client = OpenAI(api_key=openai.api_key)
|
56 |
+
|
57 |
+
response = client.chat.completions.create(
|
58 |
+
model=model,
|
59 |
+
messages=messages,
|
60 |
+
max_tokens=max_tokens,
|
61 |
+
stop=stop_strs,
|
62 |
+
temperature=temperature,
|
63 |
+
# request_timeout = 1
|
64 |
+
)
|
65 |
+
return response.choices[0].message.content
|
distillers/__init__.py
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .raw_prompt_generator import RawPromptGenerator
|
2 |
+
from .self_reflection import RefletionGenerator
|
3 |
+
from .traj_prompt_summarizer import TrajPromptSummarizer
|
4 |
+
from .guider import Guidance_Generator
|
5 |
+
|
6 |
+
REGISTRY = {}
|
7 |
+
REGISTRY['raw_distiller'] = RawPromptGenerator
|
8 |
+
REGISTRY['reflect_distiller'] = RefletionGenerator
|
9 |
+
REGISTRY['traj_distiller'] = TrajPromptSummarizer
|
10 |
+
REGISTRY['guide_generator'] = Guidance_Generator
|
distillers/guidance_summary_few_shot_examples.txt
ADDED
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
You are in the middle of a room. Looking quickly around you, you see a cabinet 10, a cabinet 9, a cabinet 8, a cabinet 7, a cabinet 6, a cabinet 5, a cabinet 4, a cabinet 3, a cabinet 2, a cabinet 1, a coffeemachine 1, a countertop 2, a countertop 1, a diningtable 2, a diningtable 1, a drawer 3, a drawer 2, a drawer 1, a fridge 1, a garbagecan 1, a microwave 1, a sinkbasin 1, a stoveburner 4, a stoveburner 3, a stoveburner 2, a stoveburner 1, and a toaster 1.
|
2 |
+
Your task is to: heat some mug and put it in coffeemachine.
|
3 |
+
Action: think: To solve the task, I need to find and take a mug, then heat it with stoveburner, then put it in coffeemachine.
|
4 |
+
Ok.
|
5 |
+
Action: look
|
6 |
+
You are in the middle of a room. Looking quickly around you, you see nothing.
|
7 |
+
Action: think: First I need to find a mug. A mug is more likely to appear in countertop (1-2), diningtable (1-2), drawer (1-3), fridge (1), sinkbasin (1), stoveburner (1-4), toaster (1). I can check one by one, starting with countertop 1.
|
8 |
+
Ok.
|
9 |
+
Action: look
|
10 |
+
You are in the middle of a room. Looking quickly around you, you see nothing.
|
11 |
+
Action: go to countertop 1
|
12 |
+
On the countertop 1, you see a butterknife 1, a cup 1, a fork 2, and a mug 1.
|
13 |
+
Action: think: Now I find a mug (1). Next, I need to take it.
|
14 |
+
Ok.
|
15 |
+
Action: take mug 1 from countertop 1
|
16 |
+
You pick up the mug 1 from the countertop 1.
|
17 |
+
Action: think: Now I take a mug (1). Next, I need to heat it with stoveburner (1-4).
|
18 |
+
Ok.
|
19 |
+
Action: go to stoveburner 1
|
20 |
+
On the stoveburner 1, you see a pan 2.
|
21 |
+
Action: examine stoveburner 1
|
22 |
+
On the stoveburner 1, you see a pan 2.
|
23 |
+
Action: examine stoveburner 1
|
24 |
+
On the stoveburner 1, you see a pan 2.
|
25 |
+
Action: examine stoveburner 1
|
26 |
+
On the stoveburner 1, you see a pan 2.
|
27 |
+
Action: examine stoveburner 1
|
28 |
+
On the stoveburner 1, you see a pan 2.
|
29 |
+
STATUS: FAIL
|
30 |
+
Summary and Key Findings: Summary of Trajectory and Outcomes:
|
31 |
+
The task was to heat a mug and put it in the coffeemachine. The initial plan was to find and take a mug, heat it with a stoveburner, and then put it in the coffeemachine. The user successfully found and took a mug from countertop 1. However, upon reaching stoveburner 1, the user got stuck in a loop of examining the stoveburner without making any progress in heating the mug. The task ultimately ended in failure.
|
32 |
+
Key Findings:
|
33 |
+
1. The user was able to identify potential locations for finding a mug and successfully located one on countertop 1.
|
34 |
+
2. The user got stuck in a loop of examining stoveburner 1 without taking any action to heat the mug.
|
35 |
+
3. The task was not completed successfully due to the user's inability to progress beyond examining the stoveburner.
|
36 |
+
|
37 |
+
You are in the middle of a room. Looking quickly around you, you see a bed 1, a desk 2, a desk 1, a drawer 6, a drawer 5, a drawer 4, a drawer 3, a drawer 2, a drawer 1, a garbagecan 1, a laundryhamper 1, a safe 1, a shelf 6, a shelf 5, a shelf 4, a shelf 3, a shelf 2, and a shelf 1.
|
38 |
+
Your task is to: look at bowl under the desklamp.
|
39 |
+
Action: think: To solve the task, I need to find and take a bowl, then find and use a desklamp.
|
40 |
+
OK.
|
41 |
+
Action: think: First I need to find a bowl. A bowl is more likely to appear in drawer (1-6), desk (1-2), shelf (1-6), garbagecan (1), laundryhamper (1). I can check one by one, starting with drawer 1.
|
42 |
+
OK.
|
43 |
+
Action: go to drawer 1
|
44 |
+
The drawer 1 is closed.
|
45 |
+
Action: open drawer 1
|
46 |
+
You open the drawer 1. The drawer 1 is open. In it, you see a cellphone 2.
|
47 |
+
Action: go to drawer 2
|
48 |
+
On the drawer 2, you see a keychain 1.
|
49 |
+
Action: go to drawer 3
|
50 |
+
Nothing happens.
|
51 |
+
Action: go to drawer 4
|
52 |
+
The drawer 4 is closed.
|
53 |
+
Action: open drawer 4
|
54 |
+
You open the drawer 4. The drawer 4 is open. In it, you see nothing.
|
55 |
+
Action: go to drawer 5
|
56 |
+
On the drawer 5, you see nothing.
|
57 |
+
Action: go to drawer 6
|
58 |
+
The drawer 6 is closed.
|
59 |
+
Action: open drawer 6
|
60 |
+
You open the drawer 6. The drawer 6 is open. In it, you see a book 2, and a pen 2.
|
61 |
+
Action: go to desk 1
|
62 |
+
On the desk 1, you see a alarmclock 1, a desklamp 1, a mug 1, and a pencil 1.
|
63 |
+
Action: think: Now I find a desklamp (1). Next, I need to use it.
|
64 |
+
OK.
|
65 |
+
Action: use desklamp 1
|
66 |
+
You turn on the desklamp 1.
|
67 |
+
Action:
|
68 |
+
Nothing happens.
|
69 |
+
Action: go to desk 2
|
70 |
+
On the desk 2, you see a bowl 2, a bowl 1, a cd 2, a cd 1, and a creditcard 3.
|
71 |
+
Action: think: Now I find a bowl (2). Next, I need to take it.
|
72 |
+
OK.
|
73 |
+
Action: take bowl 2 from desk 2
|
74 |
+
You pick up the bowl 2 from the desk 2.
|
75 |
+
Action:
|
76 |
+
Nothing happens.
|
77 |
+
Action:
|
78 |
+
Nothing happens.
|
79 |
+
STATUS: FAIL
|
80 |
+
Summary: Summary of Trajectory and Outcomes:
|
81 |
+
The task was to look at a bowl under the desklamp. The initial plan was to find and take a bowl, then find and use a desklamp. The user successfully found a desklamp on desk 1 and turned it on. They also found a bowl on desk 2 and picked it up. However, the user did not complete the task of looking at the bowl under the desklamp and instead got stuck in a state of inaction. The task ultimately ended in failure.
|
82 |
+
Key Findings:
|
83 |
+
1. The user was able to identify potential locations for finding a bowl and a desklamp and successfully located them on desk 2 and desk 1, respectively.
|
84 |
+
2. The user turned on the desklamp but did not proceed to look at the bowl under it.
|
85 |
+
3. The task was not completed successfully due to the user's inability to progress beyond taking the bowl and turning on the desklamp.
|
distillers/guider.py
ADDED
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from deciders.utils import get_completion, get_chat
|
2 |
+
|
3 |
+
from typing import List, Dict, Any
|
4 |
+
from loguru import logger
|
5 |
+
import random
|
6 |
+
import json
|
7 |
+
class Guidance_Generator():
|
8 |
+
def __init__(self,logfile="",args=None):
|
9 |
+
self.args = args
|
10 |
+
with open("./distillers/guidance_summary_few_shot_examples.txt", 'r') as f:
|
11 |
+
self.SUMMARY_FEW_SHOT_EXAMPLES = f.read()
|
12 |
+
# with open("./distillers/exploration_few_shot_examples.txt", 'r') as f:
|
13 |
+
# self.SUGGEST_FEW_SHOT_EXAMPLES = f.read()
|
14 |
+
self.insight = ""
|
15 |
+
self.suggestion = ""
|
16 |
+
if logfile:
|
17 |
+
# logger.remove()
|
18 |
+
logger.add(logfile, colorize=True, enqueue=True, filter=lambda x: '[Reflexion Memory]' in x['message'])
|
19 |
+
|
20 |
+
def generate_from_file(self, file_path,max_step_num=200):
|
21 |
+
mem = []
|
22 |
+
with open(file_path, 'r') as infile:
|
23 |
+
data = json.load(infile)
|
24 |
+
for traj in data:
|
25 |
+
traj_text = traj[0]['game_description']
|
26 |
+
traj_text += traj[0]['goal_description']
|
27 |
+
for transition in traj[-max_step_num:]:
|
28 |
+
traj_text += transition['observation']
|
29 |
+
traj_text += f"Action: {transition['action']}"
|
30 |
+
summary = self.generate_summary(traj_text, mem)
|
31 |
+
mem.append(summary)
|
32 |
+
return mem
|
33 |
+
|
34 |
+
def _generate_summary_query(self, traj, post_memory):
|
35 |
+
"""
|
36 |
+
Generates an exploration guidance query for GPT-3.5 based on given trajectory and memory.
|
37 |
+
|
38 |
+
Parameters:
|
39 |
+
- traj: Trajectory of the new experience.
|
40 |
+
- post_memory: List of memory items to summarize.
|
41 |
+
|
42 |
+
Returns:
|
43 |
+
- query: Formulated query string for GPT-3.5.
|
44 |
+
"""
|
45 |
+
segments = []
|
46 |
+
|
47 |
+
# Summarization memory
|
48 |
+
# if post_memory:
|
49 |
+
# segments.append('Your summarization memory is as below:')
|
50 |
+
# segments.extend([f'Episode #{i}: {m}' for i, m in enumerate(post_memory)])
|
51 |
+
|
52 |
+
# Trajectory
|
53 |
+
segments.append(f"Your new collected trajectory is as below:\n {traj}")
|
54 |
+
segments.append(f"The suggestion to guide the trajectory is:\n{self.suggestion}")
|
55 |
+
# Questions
|
56 |
+
questions = """
|
57 |
+
Please answer the following questions directly, without additional explanation:
|
58 |
+
1. Based on the collected trajectory, infer the specific values of game-relevant knowledge proposed in the suggestion with json format.
|
59 |
+
2. Summarize the policy behavior and its performance.
|
60 |
+
Provide concise responses.
|
61 |
+
"""
|
62 |
+
segments.append(questions)
|
63 |
+
|
64 |
+
# Construct the final query
|
65 |
+
query = '\n'.join(segments)
|
66 |
+
return query
|
67 |
+
|
68 |
+
# def _generate_summary_query(self, traj, post_memory):
|
69 |
+
# """Allows the Agent to generate exploration guidance."""
|
70 |
+
# query = ""
|
71 |
+
# if len(post_memory) > 0:
|
72 |
+
# query += '\Your summarization memory is as below:\n'
|
73 |
+
# for i, m in enumerate(post_memory):
|
74 |
+
# query += f'Episode #{i}: {m}\n'
|
75 |
+
# query += f"""
|
76 |
+
# {traj}
|
77 |
+
# Above is the trajectory of the new experience.
|
78 |
+
# """
|
79 |
+
# query += '\n Anwser the following questions.\n 1. What is the performance of this policy and does it improve the performance compared to before? 2. Summarize the main reason that makes the policy improve or reduce the performance; 3. What new information of the task can be inferred compared to the memory?'
|
80 |
+
# return query
|
81 |
+
|
82 |
+
def generate_summary(self, traj, post_memory):
|
83 |
+
query = self._generate_summary_query(traj, post_memory)
|
84 |
+
summary = get_chat(query,model=self.args.gpt_version, engine=self.args.gpt_version)
|
85 |
+
logger.info(f'[Reflexion Memory]The summary prompt is: {query}.')
|
86 |
+
logger.info(f'[Reflexion Memory]The summary response is: {summary}.')
|
87 |
+
return summary
|
88 |
+
|
89 |
+
def generate_insight(self, post_memory):
|
90 |
+
query: str = f"""As an AI assistant, you are helping a six-year-old player who has never played this game before. The experiences you have are as follows:"""
|
91 |
+
if len(post_memory) > 0:
|
92 |
+
for i, m in enumerate(post_memory):
|
93 |
+
query += f'Episode #{i}: {m}\n'
|
94 |
+
query += '\n Identify and summarize the key information that can be exploited to improve performance of the player.'
|
95 |
+
insight = get_chat(query,model=self.args.gpt_version, engine=self.args.gpt_version)
|
96 |
+
logger.info(f'[Reflexion Memory]The insight prompt is: {query}.')
|
97 |
+
logger.info(f'[Reflexion Memory]The insight response is: {insight}.')
|
98 |
+
return insight
|
99 |
+
|
100 |
+
def generate_suggestion(self, game_description, goal_description, action_description, pre_memory, post_memory, insight, max_num_trials):
|
101 |
+
query: str = f"""You are an AI assitant that help a human player win the following game.
|
102 |
+
The game is \n"{game_description}" \n, the action space is described as {action_description},\n the player's goal is \n "{goal_description}".\n
|
103 |
+
The player can play for {max_num_trials} episodes. The main aim for you is to help the player win the game in the last episode. """
|
104 |
+
if len(post_memory) > 0:
|
105 |
+
query += f"""You have obtained experience as below """
|
106 |
+
for i, m in enumerate(post_memory):
|
107 |
+
query += f'Episode #{i}: {m}\n'
|
108 |
+
# if max_num_trials - len(post_memory) == 1:
|
109 |
+
# query = (f"\n The main goal is to aid the human player in winning the game in the next episode. "
|
110 |
+
# f"This is his {len(post_memory) + 1} try out of {max(max_num_trials, 1)} episodes. "
|
111 |
+
# "Your suggestions should be simple, executable with heuristic policy, and suitable for an LLM agent. "
|
112 |
+
# "Reply in an item list format. Specifically, focus on:"
|
113 |
+
# "\n1. How to achieve optimal performance (exploitation) using the obtained knowledge?"
|
114 |
+
# "\nNote: Stress the importance of prioritizing performance without exploration.")
|
115 |
+
# suggestion = get_chat(query) + "\n Remember, in this attempt, aim solely for high performance without exploration."
|
116 |
+
# else:
|
117 |
+
# if max_num_trials-len(post_memory) == 1:
|
118 |
+
# query += f"\n The main aim for you is to help the human player win the game in the last episode. The next episode is the last episode. You can give suggestions before each episode. Then what is your suggestion for his next episode? Note that this is the last try and he should not explore which may decrease his performance. The suggestions should be simple to follow, executable with heuristic policy, easy to use for an llm agent,and reply in item list format. The answer should instruct him to exploit all the knowlegde to gain the highest performance (exploitation) in the next episode. "
|
119 |
+
# else:
|
120 |
+
query += f"\n The main aim for you is to help the human player win the game in the last episode. He has only {max(max_num_trials-len(post_memory), 1)} episodes left to try.You can give suggestions before each episode. Then what is your suggestion for his next episode? Please provide simple, concise answers suitable for a six-year-old child, focusing on the following in item list format: 1. What game-relevant knowledge is critical to determine the optimal policy. Notice that the knowledge should be obtainable by interacting with the environment and helpful for the decisions.\n 2. How should the player conduct exploration in the next episode to acquire this information?\n3. How can the player exploit the information obtained to achieve higher performance in subsequent episodes?\n 4. How should exploration and exploitation be balanced to improve performance in the next episode?\n"
|
121 |
+
# query += (f"\n The primary goal is to assist the human player in winning the game in the final episode. "
|
122 |
+
# f"This is his {len(post_memory) + 1} try out of {max(max_num_trials, 1)} episodes. "
|
123 |
+
# "Provide suggestions for the next episode that balance both exploration and exploitation. "
|
124 |
+
# "The suggestions should be in item list format, easy to follow, aligned with heuristic policy, and usable for an LLM agent. Address:"
|
125 |
+
# "\n1. Which information the player should gather via exploration and the best ways to explore?"
|
126 |
+
# "\n2. Strategies to refine the policy for enhanced performance (exploitation)?"
|
127 |
+
# "\n3. How should exploration and exploitation be weighted in the next episode?")
|
128 |
+
|
129 |
+
# TODO: consider the inconsistency between past suggestion and past memory.
|
130 |
+
suggestion = get_chat(query,model=self.args.gpt_version, engine=self.args.gpt_version)
|
131 |
+
self.suggestion = suggestion
|
132 |
+
logger.info(f'[Reflexion Memory]The suggestion prompt is: {query}.')
|
133 |
+
logger.info(f'[Reflexion Memory]The suggestion response is: {suggestion}.')
|
134 |
+
return suggestion
|
135 |
+
|
136 |
+
def generate(self, traj, memory, max_len_mem=5):
|
137 |
+
if len(memory)> max_len_mem:
|
138 |
+
reflection_query = self._generate_summary_query(traj, memory[-max_len_mem:])
|
139 |
+
else:
|
140 |
+
reflection_query = self._generate_summary_query(traj, memory)
|
141 |
+
reflection = get_completion(reflection_query,engine=self.args.gpt_version)
|
142 |
+
logger.info(f'[Reflexion Memory]The reflexion prompt is: {reflection_query}.')
|
143 |
+
logger.info(f'[Reflexion Memory]The reflexion response is: {reflection}.')
|
144 |
+
return reflection
|
distillers/raw_prompt_generator.py
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
import json
|
3 |
+
class RawPromptGenerator():
|
4 |
+
def __init__(self,args=None):
|
5 |
+
self.args = args
|
6 |
+
pass
|
7 |
+
|
8 |
+
def generate_from_file(self, file_path, choice_num=1):
|
9 |
+
with open(file_path, 'r') as infile:
|
10 |
+
data = json.load(infile)
|
11 |
+
result = []
|
12 |
+
for my_data in data[0]:
|
13 |
+
result.append({'question': my_data['question'], 'answer': my_data['answer']})
|
14 |
+
selected_index = random.sample(range(len(result)), choice_num)
|
15 |
+
selected_result = [result[index] for index in selected_index]
|
16 |
+
return selected_result
|
distillers/reflexion_few_shot_examples.txt
ADDED
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
You are in the middle of a room. Looking quickly around you, you see a cabinet 10, a cabinet 9, a cabinet 8, a cabinet 7, a cabinet 6, a cabinet 5, a cabinet 4, a cabinet 3, a cabinet 2, a cabinet 1, a coffeemachine 1, a countertop 2, a countertop 1, a diningtable 2, a diningtable 1, a drawer 3, a drawer 2, a drawer 1, a fridge 1, a garbagecan 1, a microwave 1, a sinkbasin 1, a stoveburner 4, a stoveburner 3, a stoveburner 2, a stoveburner 1, and a toaster 1.
|
2 |
+
Your task is to: heat some mug and put it in coffeemachine.
|
3 |
+
> think: To solve the task, I need to find and take a mug, then heat it with stoveburner, then put it in coffeemachine.
|
4 |
+
Ok.
|
5 |
+
> look
|
6 |
+
You are in the middle of a room. Looking quickly around you, you see nothing.
|
7 |
+
> think: First I need to find a mug. A mug is more likely to appear in countertop (1-2), diningtable (1-2), drawer (1-3), fridge (1), sinkbasin (1), stoveburner (1-4), toaster (1). I can check one by one, starting with countertop 1.
|
8 |
+
Ok.
|
9 |
+
> look
|
10 |
+
You are in the middle of a room. Looking quickly around you, you see nothing.
|
11 |
+
> go to countertop 1
|
12 |
+
On the countertop 1, you see a butterknife 1, a cup 1, a fork 2, and a mug 1.
|
13 |
+
> think: Now I find a mug (1). Next, I need to take it.
|
14 |
+
Ok.
|
15 |
+
> take mug 1 from countertop 1
|
16 |
+
You pick up the mug 1 from the countertop 1.
|
17 |
+
> think: Now I take a mug (1). Next, I need to heat it with stoveburner (1-4).
|
18 |
+
Ok.
|
19 |
+
> go to stoveburner 1
|
20 |
+
On the stoveburner 1, you see a pan 2.
|
21 |
+
> examine stoveburner 1
|
22 |
+
On the stoveburner 1, you see a pan 2.
|
23 |
+
> examine stoveburner 1
|
24 |
+
On the stoveburner 1, you see a pan 2.
|
25 |
+
> examine stoveburner 1
|
26 |
+
On the stoveburner 1, you see a pan 2.
|
27 |
+
> examine stoveburner 1
|
28 |
+
On the stoveburner 1, you see a pan 2.
|
29 |
+
STATUS: FAIL
|
30 |
+
New plan: I was stuck in a loop in which I continually examined stoveburner 1 instead of heating mug 1 with stoveburner 1. I should have taken mug 1 from countertop 1, then heated it with stoveburner 1, then put it in coffeemachine 1. It did not help to execute two identical actions in a row. I will try to execute a different action if I am stuck in a loop again.
|
31 |
+
|
32 |
+
You are in the middle of a room. Looking quickly around you, you see a bed 1, a desk 2, a desk 1, a drawer 6, a drawer 5, a drawer 4, a drawer 3, a drawer 2, a drawer 1, a garbagecan 1, a laundryhamper 1, a safe 1, a shelf 6, a shelf 5, a shelf 4, a shelf 3, a shelf 2, and a shelf 1.
|
33 |
+
Your task is to: look at bowl under the desklamp.
|
34 |
+
> think: To solve the task, I need to find and take a bowl, then find and use a desklamp.
|
35 |
+
OK.
|
36 |
+
> think: First I need to find a bowl. A bowl is more likely to appear in drawer (1-6), desk (1-2), shelf (1-6), garbagecan (1), laundryhamper (1). I can check one by one, starting with drawer 1.
|
37 |
+
OK.
|
38 |
+
> go to drawer 1
|
39 |
+
The drawer 1 is closed.
|
40 |
+
> open drawer 1
|
41 |
+
You open the drawer 1. The drawer 1 is open. In it, you see a cellphone 2.
|
42 |
+
> go to drawer 2
|
43 |
+
On the drawer 2, you see a keychain 1.
|
44 |
+
> go to drawer 3
|
45 |
+
Nothing happens.
|
46 |
+
> go to drawer 4
|
47 |
+
The drawer 4 is closed.
|
48 |
+
> open drawer 4
|
49 |
+
You open the drawer 4. The drawer 4 is open. In it, you see nothing.
|
50 |
+
> go to drawer 5
|
51 |
+
On the drawer 5, you see nothing.
|
52 |
+
> go to drawer 6
|
53 |
+
The drawer 6 is closed.
|
54 |
+
> open drawer 6
|
55 |
+
You open the drawer 6. The drawer 6 is open. In it, you see a book 2, and a pen 2.
|
56 |
+
> go to desk 1
|
57 |
+
On the desk 1, you see a alarmclock 1, a desklamp 1, a mug 1, and a pencil 1.
|
58 |
+
> think: Now I find a desklamp (1). Next, I need to use it.
|
59 |
+
OK.
|
60 |
+
> use desklamp 1
|
61 |
+
You turn on the desklamp 1.
|
62 |
+
>
|
63 |
+
Nothing happens.
|
64 |
+
> go to desk 2
|
65 |
+
On the desk 2, you see a bowl 2, a bowl 1, a cd 2, a cd 1, and a creditcard 3.
|
66 |
+
> think: Now I find a bowl (2). Next, I need to take it.
|
67 |
+
OK.
|
68 |
+
> take bowl 2 from desk 2
|
69 |
+
You pick up the bowl 2 from the desk 2.
|
70 |
+
>
|
71 |
+
Nothing happens.
|
72 |
+
>
|
73 |
+
Nothing happens.
|
74 |
+
STATUS: FAIL
|
75 |
+
New plan: In this environment, my plan was to find a bowl then find and use a desklamp. However, the task says to look at bowl under the desklamp. I should have looked for the desklamp then looked for the bowl. I noticed that the desklamp was found on desk 1. In the next trial, I will go to desk 1, find the lamp, then look for the bowl under the desklamp.
|
distillers/self_reflection.py
ADDED
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from deciders.utils import get_completion
|
2 |
+
|
3 |
+
from typing import List, Dict, Any
|
4 |
+
from loguru import logger
|
5 |
+
import random
|
6 |
+
import json
|
7 |
+
class RefletionGenerator():
|
8 |
+
def __init__(self,logfile="",args=None):
|
9 |
+
self.args = args
|
10 |
+
with open("./distillers/reflexion_few_shot_examples.txt", 'r') as f:
|
11 |
+
self.FEW_SHOT_EXAMPLES = f.read()
|
12 |
+
if logfile:
|
13 |
+
# logger.remove()
|
14 |
+
logger.add(logfile, colorize=True, enqueue=True, filter=lambda x: '[Reflexion Memory]' in x['message'])
|
15 |
+
|
16 |
+
def generate_from_file(self, file_path,max_step_num=200):
|
17 |
+
mem = []
|
18 |
+
with open(file_path, 'r') as infile:
|
19 |
+
data = json.load(infile)
|
20 |
+
for traj in data:
|
21 |
+
traj_text = traj[0]['game_description']
|
22 |
+
traj_text += traj[0]['goal_description']
|
23 |
+
for transition in traj[-max_step_num:]:
|
24 |
+
traj_text += transition['observation']
|
25 |
+
traj_text += f"Action: {transition['action']}"
|
26 |
+
reflection = self.generate(traj_text, mem, max_len_mem=5)
|
27 |
+
mem.append(reflection)
|
28 |
+
return mem
|
29 |
+
|
30 |
+
def _generate_reflection_query(self, traj, memory):
|
31 |
+
"""Allows the Agent to reflect upon a past experience."""
|
32 |
+
query: str = f"""You will be given the history of a past experience in which you were placed in an environment and given a task to complete. You were unsuccessful in completing the task. Do not summarize your environment, but rather think about the strategy and path you took to attempt to complete the task. Think step by step what mistakes you made leading the failure. Then devise a concise, new plan of action that accounts for your mistake with reference to specific actions that you should have taken. For example, if you tried A and B but forgot C, then you should reason that the forgetting C is the key mistake. After that you devise a plan to achieve C with environment-specific actions. You remind yourself the plan your will take in the next trail and Give your plan after "Plan". Here are two examples:
|
33 |
+
|
34 |
+
{self.FEW_SHOT_EXAMPLES}
|
35 |
+
|
36 |
+
{traj}"""
|
37 |
+
if len(memory) > 0:
|
38 |
+
query += '\n\nPlans from past attempts:\n'
|
39 |
+
for i, m in enumerate(memory):
|
40 |
+
query += f'Trial #{i}: {m}\n'
|
41 |
+
|
42 |
+
query += '\n\nNew plan:'
|
43 |
+
return query
|
44 |
+
|
45 |
+
def generate(self, traj, memory, max_len_mem=5):
|
46 |
+
if len(memory)> max_len_mem:
|
47 |
+
reflection_query = self._generate_reflection_query(traj, memory[-max_len_mem:])
|
48 |
+
else:
|
49 |
+
reflection_query = self._generate_reflection_query(traj, memory)
|
50 |
+
reflection = get_completion(reflection_query, engine=self.args.gpt_version)
|
51 |
+
logger.info(f'[Reflexion Memory]The reflexion prompt is: {reflection_query}.')
|
52 |
+
logger.info(f'[Reflexion Memory]The reflexion response is: {reflection}.')
|
53 |
+
return reflection
|
distillers/traj_prompt_summarizer.py
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
from deciders.utils import get_completion
|
3 |
+
import json
|
4 |
+
class TrajPromptSummarizer():
|
5 |
+
def __init__(self,args=None):
|
6 |
+
self.args = args
|
7 |
+
with open("./distillers/traj_summary_few_shot_examples.txt", 'r') as f:
|
8 |
+
self.FEW_SHOT_EXAMPLES = f.read()
|
9 |
+
|
10 |
+
def generate_from_file(self, file_path,max_step_num=200):
|
11 |
+
mem = []
|
12 |
+
with open(file_path, 'r') as infile:
|
13 |
+
data = json.load(infile)
|
14 |
+
for traj in data:
|
15 |
+
traj_text = traj[0]['game_description']
|
16 |
+
traj_text += traj[0]['goal_description']
|
17 |
+
for transition in traj[-max_step_num:]:
|
18 |
+
traj_text += transition['observation']
|
19 |
+
traj_text += f"> {transition['action']}"
|
20 |
+
traj_text += f"Your performance is: {transition['cum_reward']}"
|
21 |
+
reflection = self.generate(traj_text, mem, max_len_mem=5)
|
22 |
+
mem.append(reflection)
|
23 |
+
return mem
|
24 |
+
|
25 |
+
def _generate_summary_query(self, traj, memory):
|
26 |
+
"""Allows the Agent to reflect upon a past experience."""
|
27 |
+
query: str = f"""You will be given the history of a past experience in which you were placed in an environment and given a task to complete. Summarize your trajectory and reasoning the relation between your policy and the obtained result. Here are two examples:
|
28 |
+
|
29 |
+
{self.FEW_SHOT_EXAMPLES}
|
30 |
+
|
31 |
+
{traj}"""
|
32 |
+
if len(memory) > 0:
|
33 |
+
query += '\n\nPlans from past attempts:\n'
|
34 |
+
for i, m in enumerate(memory):
|
35 |
+
query += f'Trial #{i}: {m}\n'
|
36 |
+
|
37 |
+
query += '\n\nSummary:'
|
38 |
+
return query
|
39 |
+
|
40 |
+
def generate(self, traj, memory, max_len_mem=5):
|
41 |
+
if len(memory)> max_len_mem:
|
42 |
+
reflection_query = self._generate_summary_query(traj, memory[-max_len_mem:])
|
43 |
+
else:
|
44 |
+
reflection_query = self._generate_summary_query(traj, memory)
|
45 |
+
reflection = get_completion(reflection_query, engine=self.args.gpt_version)
|
46 |
+
return reflection
|
distillers/traj_summary_few_shot_examples.txt
ADDED
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
You are in the middle of a room. Looking quickly around you, you see a cabinet 10, a cabinet 9, a cabinet 8, a cabinet 7, a cabinet 6, a cabinet 5, a cabinet 4, a cabinet 3, a cabinet 2, a cabinet 1, a coffeemachine 1, a countertop 2, a countertop 1, a diningtable 2, a diningtable 1, a drawer 3, a drawer 2, a drawer 1, a fridge 1, a garbagecan 1, a microwave 1, a sinkbasin 1, a stoveburner 4, a stoveburner 3, a stoveburner 2, a stoveburner 1, and a toaster 1.
|
2 |
+
Your task is to: heat some mug and put it in coffeemachine.
|
3 |
+
> think: To solve the task, I need to find and take a mug, then heat it with stoveburner, then put it in coffeemachine.
|
4 |
+
Ok.
|
5 |
+
> look
|
6 |
+
You are in the middle of a room. Looking quickly around you, you see nothing.
|
7 |
+
> think: First I need to find a mug. A mug is more likely to appear in countertop (1-2), diningtable (1-2), drawer (1-3), fridge (1), sinkbasin (1), stoveburner (1-4), toaster (1). I can check one by one, starting with countertop 1.
|
8 |
+
Ok.
|
9 |
+
> look
|
10 |
+
You are in the middle of a room. Looking quickly around you, you see nothing.
|
11 |
+
> go to countertop 1
|
12 |
+
On the countertop 1, you see a butterknife 1, a cup 1, a fork 2, and a mug 1.
|
13 |
+
> think: Now I find a mug (1). Next, I need to take it.
|
14 |
+
Ok.
|
15 |
+
> take mug 1 from countertop 1
|
16 |
+
You pick up the mug 1 from the countertop 1.
|
17 |
+
> think: Now I take a mug (1). Next, I need to heat it with stoveburner (1-4).
|
18 |
+
Ok.
|
19 |
+
> go to stoveburner 1
|
20 |
+
On the stoveburner 1, you see a pan 2.
|
21 |
+
> examine stoveburner 1
|
22 |
+
On the stoveburner 1, you see a pan 2.
|
23 |
+
> examine stoveburner 1
|
24 |
+
On the stoveburner 1, you see a pan 2.
|
25 |
+
> examine stoveburner 1
|
26 |
+
On the stoveburner 1, you see a pan 2.
|
27 |
+
> examine stoveburner 1
|
28 |
+
On the stoveburner 1, you see a pan 2.
|
29 |
+
STATUS: FAIL
|
30 |
+
Summary: I continually examined stoveburner 1 instead of heating mug 1 with stoveburner 1. This makes me stuck in a loop and fail.
|
31 |
+
|
32 |
+
You are in the middle of a room. Looking quickly around you, you see a bed 1, a desk 2, a desk 1, a drawer 6, a drawer 5, a drawer 4, a drawer 3, a drawer 2, a drawer 1, a garbagecan 1, a laundryhamper 1, a safe 1, a shelf 6, a shelf 5, a shelf 4, a shelf 3, a shelf 2, and a shelf 1.
|
33 |
+
Your task is to: look at bowl under the desklamp.
|
34 |
+
> think: To solve the task, I need to find and take a bowl, then find and use a desklamp.
|
35 |
+
OK.
|
36 |
+
> think: First I need to find a bowl. A bowl is more likely to appear in drawer (1-6), desk (1-2), shelf (1-6), garbagecan (1), laundryhamper (1). I can check one by one, starting with drawer 1.
|
37 |
+
OK.
|
38 |
+
> go to drawer 1
|
39 |
+
The drawer 1 is closed.
|
40 |
+
> open drawer 1
|
41 |
+
You open the drawer 1. The drawer 1 is open. In it, you see a cellphone 2.
|
42 |
+
> go to drawer 2
|
43 |
+
On the drawer 2, you see a keychain 1.
|
44 |
+
> go to drawer 3
|
45 |
+
Nothing happens.
|
46 |
+
> go to drawer 4
|
47 |
+
The drawer 4 is closed.
|
48 |
+
> open drawer 4
|
49 |
+
You open the drawer 4. The drawer 4 is open. In it, you see nothing.
|
50 |
+
> go to drawer 5
|
51 |
+
On the drawer 5, you see nothing.
|
52 |
+
> go to drawer 6
|
53 |
+
The drawer 6 is closed.
|
54 |
+
> open drawer 6
|
55 |
+
You open the drawer 6. The drawer 6 is open. In it, you see a book 2, and a pen 2.
|
56 |
+
> go to desk 1
|
57 |
+
On the desk 1, you see a alarmclock 1, a desklamp 1, a mug 1, and a pencil 1.
|
58 |
+
> think: Now I find a desklamp (1). Next, I need to use it.
|
59 |
+
OK.
|
60 |
+
> use desklamp 1
|
61 |
+
You turn on the desklamp 1.
|
62 |
+
>
|
63 |
+
Nothing happens.
|
64 |
+
> go to desk 2
|
65 |
+
On the desk 2, you see a bowl 2, a bowl 1, a cd 2, a cd 1, and a creditcard 3.
|
66 |
+
> think: Now I find a bowl (2). Next, I need to take it.
|
67 |
+
OK.
|
68 |
+
> take bowl 2 from desk 2
|
69 |
+
You pick up the bowl 2 from the desk 2.
|
70 |
+
>
|
71 |
+
Nothing happens.
|
72 |
+
>
|
73 |
+
Nothing happens.
|
74 |
+
STATUS: FAIL
|
75 |
+
Summary: I try to find a bowl then find and use a desklamp. This is inconsistent to the task which require looking at
|
76 |
+
bowl under the desklamp. Thus I fail.
|
draw_overall_performance.py
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
import matplotlib.pyplot as plt
|
3 |
+
|
4 |
+
# Load the CSV data
|
5 |
+
data = pd.read_csv("performance_data.csv")
|
6 |
+
|
7 |
+
# Group games by type
|
8 |
+
game_types = {
|
9 |
+
"Classic Control": ["Acrobot-v1", "CartPole-v0", "MountainCar-v0"],
|
10 |
+
"Box 2D": ["LunarLander-v2"],
|
11 |
+
"Toy Text": ["Taxi-v3", "CliffWalking-v0", "Blackjack-v1"]
|
12 |
+
}
|
13 |
+
|
14 |
+
for game_type, games in game_types.items():
|
15 |
+
fig, axs = plt.subplots(1, len(games), figsize=(12 * len(games), 6))
|
16 |
+
fig.suptitle(f"Performance Plot: {game_type}", fontsize=28, fontname="Times New Roman")
|
17 |
+
|
18 |
+
if len(games) == 1:
|
19 |
+
axs = [axs]
|
20 |
+
|
21 |
+
handles, labels = [], []
|
22 |
+
|
23 |
+
for idx, game in enumerate(games):
|
24 |
+
# Filter data to get information for the current game (in the loop)
|
25 |
+
game_data = data[data["game"] == game]
|
26 |
+
|
27 |
+
axs[idx].set_title(game, fontsize=20, fontname="Times New Roman")
|
28 |
+
axs[idx].set_xlabel("Levels", fontsize=16, fontname="Times New Roman")
|
29 |
+
if idx == 0:
|
30 |
+
axs[idx].set_ylabel("Scores", fontsize=16, fontname="Times New Roman")
|
31 |
+
|
32 |
+
for index, row in game_data.iterrows():
|
33 |
+
decider_name = row["decider_name"]
|
34 |
+
levels = ["l1", "l2", "l3", "l4", "l5"]
|
35 |
+
scores = row[levels].values.tolist()
|
36 |
+
lines = axs[idx].plot(levels, scores, "-o", label=decider_name)
|
37 |
+
# Grab the handle and label for creating a global legend
|
38 |
+
handles.append(lines[0])
|
39 |
+
labels.append(decider_name)
|
40 |
+
|
41 |
+
# Eliminate duplicate labels and handles
|
42 |
+
unique_labels = []
|
43 |
+
unique_handles = []
|
44 |
+
for handle, label in zip(handles, labels):
|
45 |
+
if label not in unique_labels:
|
46 |
+
unique_labels.append(label)
|
47 |
+
unique_handles.append(handle)
|
48 |
+
|
49 |
+
# Add a legend at the bottom middle of the figure
|
50 |
+
fig.legend(
|
51 |
+
unique_handles,
|
52 |
+
unique_labels,
|
53 |
+
loc="lower center",
|
54 |
+
ncol=4, prop={'size': 18}
|
55 |
+
)
|
56 |
+
|
57 |
+
# Adjust layout to accommodate the legend and prevent cropping
|
58 |
+
|
59 |
+
plt.savefig("./vis/" + game_type + ".png", dpi=300)
|
environment.yml
ADDED
@@ -0,0 +1,193 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: llm-gym
|
2 |
+
channels:
|
3 |
+
- conda-forge
|
4 |
+
- defaults
|
5 |
+
dependencies:
|
6 |
+
- _libgcc_mutex=0.1=main
|
7 |
+
- _openmp_mutex=5.1=1_gnu
|
8 |
+
- aiosignal=1.2.0=pyhd3eb1b0_0
|
9 |
+
- asttokens=2.0.5=pyhd3eb1b0_0
|
10 |
+
- async-timeout=4.0.2=py38h06a4308_0
|
11 |
+
- attrs=22.1.0=py38h06a4308_0
|
12 |
+
- backcall=0.2.0=pyhd3eb1b0_0
|
13 |
+
- blas=1.0=mkl
|
14 |
+
- brotlipy=0.7.0=py38h27cfd23_1003
|
15 |
+
- ca-certificates=2023.08.22=h06a4308_0
|
16 |
+
- cached-property=1.5.2=py_0
|
17 |
+
- certifi=2023.7.22=py38h06a4308_0
|
18 |
+
- cffi=1.15.1=py38h5eee18b_3
|
19 |
+
- chardet=4.0.0=py38h06a4308_1003
|
20 |
+
- comm=0.1.2=py38h06a4308_0
|
21 |
+
- cryptography=39.0.1=py38h9ce1e76_2
|
22 |
+
- cudatoolkit=11.3.1=h2bc3f7f_2
|
23 |
+
- debugpy=1.5.1=py38h295c915_0
|
24 |
+
- executing=0.8.3=pyhd3eb1b0_0
|
25 |
+
- frozenlist=1.3.3=py38h5eee18b_0
|
26 |
+
- hdf5=1.10.6=h3ffc7dd_1
|
27 |
+
- idna=3.4=py38h06a4308_0
|
28 |
+
- importlib_metadata=6.0.0=hd3eb1b0_0
|
29 |
+
- intel-openmp=2023.1.0=hdb19cb5_46305
|
30 |
+
- ipykernel=6.19.2=py38hb070fc8_0
|
31 |
+
- ipython=8.12.0=py38h06a4308_0
|
32 |
+
- jedi=0.18.1=py38h06a4308_1
|
33 |
+
- jupyter_client=8.1.0=py38h06a4308_0
|
34 |
+
- jupyter_core=5.3.0=py38h06a4308_0
|
35 |
+
- ld_impl_linux-64=2.38=h1181459_1
|
36 |
+
- libffi=3.4.4=h6a678d5_0
|
37 |
+
- libgcc-ng=11.2.0=h1234567_1
|
38 |
+
- libgfortran-ng=11.2.0=h00389a5_1
|
39 |
+
- libgfortran5=11.2.0=h1234567_1
|
40 |
+
- libgomp=11.2.0=h1234567_1
|
41 |
+
- libllvm14=14.0.6=hdb19cb5_3
|
42 |
+
- libprotobuf=3.20.3=he621ea3_0
|
43 |
+
- libsodium=1.0.18=h7b6447c_0
|
44 |
+
- libstdcxx-ng=11.2.0=h1234567_1
|
45 |
+
- loguru=0.7.1=py38h578d9bd_0
|
46 |
+
- matplotlib-inline=0.1.6=py38h06a4308_0
|
47 |
+
- mkl=2023.1.0=h6d00ec8_46342
|
48 |
+
- mkl-service=2.4.0=py38h5eee18b_1
|
49 |
+
- mkl_fft=1.3.6=py38h417a72b_1
|
50 |
+
- mkl_random=1.2.2=py38h417a72b_1
|
51 |
+
- ncurses=6.4=h6a678d5_0
|
52 |
+
- nest-asyncio=1.5.6=py38h06a4308_0
|
53 |
+
- numpy-base=1.24.3=py38h060ed82_1
|
54 |
+
- openssl=3.0.10=h7f8727e_2
|
55 |
+
- packaging=23.0=py38h06a4308_0
|
56 |
+
- parso=0.8.3=pyhd3eb1b0_0
|
57 |
+
- pcre=8.45=h295c915_0
|
58 |
+
- pexpect=4.8.0=pyhd3eb1b0_3
|
59 |
+
- pickleshare=0.7.5=pyhd3eb1b0_1003
|
60 |
+
- pip=23.2.1=py38h06a4308_0
|
61 |
+
- platformdirs=2.5.2=py38h06a4308_0
|
62 |
+
- prompt-toolkit=3.0.36=py38h06a4308_0
|
63 |
+
- psutil=5.9.0=py38h5eee18b_0
|
64 |
+
- ptyprocess=0.7.0=pyhd3eb1b0_2
|
65 |
+
- pure_eval=0.2.2=pyhd3eb1b0_0
|
66 |
+
- pycparser=2.21=pyhd3eb1b0_0
|
67 |
+
- pygments=2.15.1=py38h06a4308_1
|
68 |
+
- pyopenssl=23.0.0=py38h06a4308_0
|
69 |
+
- pysocks=1.7.1=py38h06a4308_0
|
70 |
+
- python=3.8.16=h955ad1f_4
|
71 |
+
- python-dateutil=2.8.2=pyhd3eb1b0_0
|
72 |
+
- python_abi=3.8=2_cp38
|
73 |
+
- pyyaml=6.0=py38h0a891b7_4
|
74 |
+
- pyzmq=25.1.0=py38h6a678d5_0
|
75 |
+
- readline=8.2=h5eee18b_0
|
76 |
+
- setuptools=67.8.0=py38h06a4308_0
|
77 |
+
- six=1.16.0=pyhd3eb1b0_1
|
78 |
+
- sqlite=3.41.2=h5eee18b_0
|
79 |
+
- stack_data=0.2.0=pyhd3eb1b0_0
|
80 |
+
- tbb=2021.8.0=hdb19cb5_0
|
81 |
+
- tk=8.6.12=h1ccaba5_0
|
82 |
+
- tornado=6.2=py38h5eee18b_0
|
83 |
+
- traitlets=5.7.1=py38h06a4308_0
|
84 |
+
- typing_extensions=4.6.3=py38h06a4308_0
|
85 |
+
- wcwidth=0.2.5=pyhd3eb1b0_0
|
86 |
+
- wheel=0.38.4=py38h06a4308_0
|
87 |
+
- xz=5.4.2=h5eee18b_0
|
88 |
+
- yaml=0.2.5=h7b6447c_0
|
89 |
+
- zeromq=4.3.4=h2531618_0
|
90 |
+
- zlib=1.2.13=h5eee18b_0
|
91 |
+
- pip:
|
92 |
+
- absl-py==1.4.0
|
93 |
+
- aiohttp==3.8.4
|
94 |
+
- ale-py==0.8.1
|
95 |
+
- annotated-types==0.5.0
|
96 |
+
- appdirs==1.4.4
|
97 |
+
- beautifulsoup4==4.12.2
|
98 |
+
- box2d-py==2.3.5
|
99 |
+
- cachetools==5.3.1
|
100 |
+
- cchardet==2.1.7
|
101 |
+
- charset-normalizer==3.1.0
|
102 |
+
- click==8.1.3
|
103 |
+
- cloudpickle==2.2.1
|
104 |
+
- contourpy==1.1.0
|
105 |
+
- cycler==0.11.0
|
106 |
+
- cython==3.0.1
|
107 |
+
- dataclasses-json==0.5.14
|
108 |
+
- decorator==4.4.2
|
109 |
+
- docker-pycreds==0.4.0
|
110 |
+
- fasteners==0.18
|
111 |
+
- filelock==3.12.2
|
112 |
+
- fonttools==4.40.0
|
113 |
+
- fsspec==2023.6.0
|
114 |
+
- gitdb==4.0.10
|
115 |
+
- gitpython==3.1.31
|
116 |
+
- glfw==2.6.2
|
117 |
+
- google-auth==2.21.0
|
118 |
+
- google-auth-oauthlib==1.0.0
|
119 |
+
- greenlet==2.0.2
|
120 |
+
- grpcio==1.56.0
|
121 |
+
- gym==0.26.2
|
122 |
+
- gym-notices==0.0.8
|
123 |
+
- h5py==3.9.0
|
124 |
+
- huggingface-hub==0.15.1
|
125 |
+
- imageio==2.31.2
|
126 |
+
- imageio-ffmpeg==0.4.8
|
127 |
+
- importlib-metadata==6.6.0
|
128 |
+
- importlib-resources==5.12.0
|
129 |
+
- iniconfig==2.0.0
|
130 |
+
- kiwisolver==1.4.4
|
131 |
+
- langchain==0.0.284
|
132 |
+
- langsmith==0.0.33
|
133 |
+
- llvmlite==0.40.1
|
134 |
+
- lz4==4.3.2
|
135 |
+
- markdown==3.4.3
|
136 |
+
- markupsafe==2.1.1
|
137 |
+
- marshmallow==3.20.1
|
138 |
+
- matplotlib==3.7.1
|
139 |
+
- moviepy==1.0.3
|
140 |
+
- mujoco==2.2.0
|
141 |
+
- mujoco-py==2.1.2.14
|
142 |
+
- multidict==6.0.4
|
143 |
+
- numba==0.57.1
|
144 |
+
- numexpr==2.8.5
|
145 |
+
- numpy==1.24.4
|
146 |
+
- oauthlib==3.2.2
|
147 |
+
- openai==0.27.8
|
148 |
+
- opencv-python==4.8.0.76
|
149 |
+
- pathtools==0.1.2
|
150 |
+
- pillow==9.5.0
|
151 |
+
- pluggy==1.2.0
|
152 |
+
- proglog==0.1.10
|
153 |
+
- protobuf==3.19.6
|
154 |
+
- py==1.11.0
|
155 |
+
- pyasn1==0.5.0
|
156 |
+
- pyasn1-modules==0.3.0
|
157 |
+
- pydantic==2.3.0
|
158 |
+
- pydantic-core==2.6.3
|
159 |
+
- pygame==2.1.0
|
160 |
+
- pyopengl==3.1.7
|
161 |
+
- pyparsing==3.0.9
|
162 |
+
- pytest==7.0.1
|
163 |
+
- regex==2023.6.3
|
164 |
+
- requests==2.31.0
|
165 |
+
- requests-oauthlib==1.3.1
|
166 |
+
- rsa==4.9
|
167 |
+
- safetensors==0.3.1
|
168 |
+
- sentry-sdk==1.26.0
|
169 |
+
- setproctitle==1.3.2
|
170 |
+
- smmap==5.0.0
|
171 |
+
- soupsieve==2.4.1
|
172 |
+
- sqlalchemy==2.0.20
|
173 |
+
- swig==4.1.1
|
174 |
+
- tenacity==8.2.3
|
175 |
+
- tensorboard==2.14.0
|
176 |
+
- tensorboard-data-server==0.7.1
|
177 |
+
- tianshou==0.4.10
|
178 |
+
- tokenizers==0.13.3
|
179 |
+
# - torch==1.12.0+cu113
|
180 |
+
# - torchaudio==0.12.0+cu113
|
181 |
+
# - torchvision==0.13.0+cu113
|
182 |
+
- tqdm==4.65.0
|
183 |
+
- transformers==4.30.2
|
184 |
+
- typing==3.7.4.3
|
185 |
+
- typing-extensions==4.7.1
|
186 |
+
- typing-inspect==0.9.0
|
187 |
+
- urllib3
|
188 |
+
- v==1
|
189 |
+
- wandb==0.15.4
|
190 |
+
- werkzeug==2.3.6
|
191 |
+
- yarl==1.9.2
|
192 |
+
- zipp==3.15.0
|
193 |
+
- aquarel==0.0.5
|
envs/__init__.py
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .base_env import BaseEnv, SettableStateEnv
|
2 |
+
from .classic_control import cartpole_translator, cartpole_policies
|
3 |
+
from .classic_control import acrobot_translator, acrobot_policies
|
4 |
+
from .classic_control import mountaincar_translator, mountaincar_policies
|
5 |
+
from .classic_control import mountaincarContinuous_translator,mountaincarContinuous_policies
|
6 |
+
|
7 |
+
from .box2d import LunarLander_translator, LunarLander_policies
|
8 |
+
|
9 |
+
from .toy_text import blackjack_translator, blackjack_policies
|
10 |
+
from .toy_text import taxi_translator, taxi_policies
|
11 |
+
from .toy_text import cliffwalking_translator, cliffwalking_policies
|
12 |
+
from .toy_text import frozenlake_translator, frozenlake_policies
|
13 |
+
|
14 |
+
REGISTRY = {}
|
15 |
+
REGISTRY["sampling_wrapper"] = SettableStateEnv
|
16 |
+
REGISTRY["base_env"] = BaseEnv
|
17 |
+
REGISTRY["cart_init_translator"] = cartpole_translator.GameDescriber
|
18 |
+
REGISTRY["cart_basic_translator"] = cartpole_translator.BasicStateSequenceTranslator
|
19 |
+
REGISTRY["acrobot_init_translator"] = acrobot_translator.GameDescriber
|
20 |
+
REGISTRY["acrobot_basic_translator"] = acrobot_translator.BasicStateSequenceTranslator
|
21 |
+
REGISTRY["mountaincar_init_translator"] = mountaincar_translator.GameDescriber
|
22 |
+
REGISTRY["mountaincar_basic_translator"] = mountaincar_translator.BasicStateSequenceTranslator
|
23 |
+
|
24 |
+
REGISTRY["cart_policies"] = [cartpole_policies.dedicated_1_policy, cartpole_policies.dedicated_2_policy, cartpole_policies.pseudo_random_policy, cartpole_policies.real_random_policy]
|
25 |
+
REGISTRY["acrobot_policies"] = [acrobot_policies.dedicated_1_policy, acrobot_policies.dedicated_2_policy, acrobot_policies.dedicated_3_policy, acrobot_policies.pseudo_random_policy, acrobot_policies.real_random_policy]
|
26 |
+
REGISTRY["mountaincar_policies"] = [mountaincar_policies.dedicated_1_policy, mountaincar_policies.dedicated_2_policy, mountaincar_policies.dedicated_3_policy, mountaincar_policies.pseudo_random_policy, mountaincar_policies.real_random_policy]
|
27 |
+
|
28 |
+
REGISTRY["lunarLander_init_translator"] = LunarLander_translator.GameDescriber
|
29 |
+
REGISTRY["lunarLander_basic_translator"] = LunarLander_translator.BasicStateSequenceTranslator
|
30 |
+
REGISTRY["lunarLander_policies"] = [LunarLander_policies.dedicated_1_policy, LunarLander_policies.dedicated_2_policy, LunarLander_policies.dedicated_3_policy,LunarLander_policies.dedicated_4_policy, LunarLander_policies.pseudo_random_policy, LunarLander_policies.real_random_policy]
|
31 |
+
|
32 |
+
REGISTRY["blackjack_init_translator"] = blackjack_translator.GameDescriber
|
33 |
+
REGISTRY["blackjack_basic_translator"] = blackjack_translator.BasicStateSequenceTranslator
|
34 |
+
REGISTRY["blackjack_policies"] = [blackjack_policies.dedicated_1_policy, blackjack_policies.dedicated_2_policy, blackjack_policies.pseudo_random_policy, blackjack_policies.real_random_policy]
|
35 |
+
|
36 |
+
REGISTRY["taxi_init_translator"] = taxi_translator.GameDescriber
|
37 |
+
REGISTRY["taxi_basic_translator"] = taxi_translator.BasicStateSequenceTranslator
|
38 |
+
REGISTRY["taxi_policies"] = [taxi_policies.dedicated_1_policy, taxi_policies.dedicated_2_policy, taxi_policies.dedicated_3_policy, taxi_policies.dedicated_4_policy, taxi_policies.dedicated_5_policy, taxi_policies.dedicated_6_policy, taxi_policies.pseudo_random_policy, taxi_policies.real_random_policy]
|
39 |
+
|
40 |
+
REGISTRY["cliffwalking_init_translator"] = cliffwalking_translator.GameDescriber
|
41 |
+
REGISTRY["cliffwalking_basic_translator"] = cliffwalking_translator.BasicStateSequenceTranslator
|
42 |
+
REGISTRY["cliffwalking_policies"] = [cliffwalking_policies.dedicated_1_policy, cliffwalking_policies.dedicated_2_policy, cliffwalking_policies.dedicated_3_policy, cliffwalking_policies.dedicated_4_policy, cliffwalking_policies.pseudo_random_policy, cliffwalking_policies.real_random_policy]
|
43 |
+
|
44 |
+
REGISTRY["frozenlake_init_translator"] = frozenlake_translator.GameDescriber
|
45 |
+
REGISTRY["frozenlake_basic_translator"] = frozenlake_translator.BasicStateSequenceTranslator
|
46 |
+
REGISTRY["frozenlake_policies"] = [frozenlake_policies.dedicated_1_policy, frozenlake_policies.dedicated_2_policy, frozenlake_policies.dedicated_3_policy, frozenlake_policies.dedicated_4_policy, frozenlake_policies.pseudo_random_policy, frozenlake_policies.real_random_policy]
|
47 |
+
|
48 |
+
|
49 |
+
REGISTRY["mountaincarContinuous_init_translator"] = mountaincarContinuous_translator.GameDescriber
|
50 |
+
REGISTRY["mountaincarContinuous_basic_translator"] = mountaincarContinuous_translator.BasicStateSequenceTranslator
|
51 |
+
REGISTRY["mountaincarContinuous_policies"] = [mountaincarContinuous_policies.pseudo_random_policy, mountaincarContinuous_policies.real_random_policy]
|
envs/base_env.py
ADDED
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# This file contains functions for interacting with the CartPole environment
|
2 |
+
|
3 |
+
import gym
|
4 |
+
|
5 |
+
class SettableStateEnv(gym.Wrapper):
|
6 |
+
def __init__(self, env):
|
7 |
+
super().__init__(env)
|
8 |
+
self.env = env
|
9 |
+
|
10 |
+
def set_state(self, state):
|
11 |
+
self.env.state = state
|
12 |
+
self.env.steps_beyond_terminated = None
|
13 |
+
|
14 |
+
class BaseEnv(gym.Wrapper):
|
15 |
+
def __init__(self, env, translator):
|
16 |
+
super().__init__(env)
|
17 |
+
self.translator = translator
|
18 |
+
self.env_name = super().spec.id
|
19 |
+
self.transition_data = {}
|
20 |
+
self.game_description = self.get_game_description()
|
21 |
+
self.goal_description = self.get_goal_description()
|
22 |
+
self.action_description = self.get_action_description()
|
23 |
+
self.action_desc_dict = self.get_action_desc_dict()
|
24 |
+
self.reward_desc_dict = self.get_reward_desc_dict()
|
25 |
+
|
26 |
+
def reset(self, **kwargs):
|
27 |
+
state, _ = super().reset(**kwargs)
|
28 |
+
self.transition_data['state'] = state
|
29 |
+
self.translator.obtain(self.transition_data)
|
30 |
+
summary, future_summary = self.translator.translate()
|
31 |
+
info = {
|
32 |
+
'future_summary': future_summary
|
33 |
+
}
|
34 |
+
self.state = state
|
35 |
+
return summary, info
|
36 |
+
|
37 |
+
def step(self, action):
|
38 |
+
potential_next_state = self.get_potential_next_state(action)
|
39 |
+
state, reward, terminated, _, info = super().step(action)
|
40 |
+
self.transition_data['action'] = action
|
41 |
+
self.transition_data['next_state'] = state
|
42 |
+
self.transition_data['reward'] = reward
|
43 |
+
self.transition_data['terminated'] = terminated
|
44 |
+
self.translator.update(self.transition_data)
|
45 |
+
self.transition_data = {}
|
46 |
+
self.transition_data['state'] = state
|
47 |
+
self.translator.obtain(self.transition_data)
|
48 |
+
summary, future_summary = self.translator.translate()
|
49 |
+
info = {
|
50 |
+
'future_summary': future_summary,
|
51 |
+
'potential_state': potential_next_state
|
52 |
+
}
|
53 |
+
return summary, reward, terminated, _, info
|
54 |
+
|
55 |
+
|
56 |
+
def step_llm(self, action):
|
57 |
+
potential_next_state = self.get_potential_next_state(action)
|
58 |
+
if "Continuous" in self.env_name:
|
59 |
+
state, reward, terminated, _, info = super().step(action)
|
60 |
+
else:
|
61 |
+
state, reward, terminated, _, info = super().step(action-1)
|
62 |
+
self.transition_data['action'] = action
|
63 |
+
self.transition_data['next_state'] = state
|
64 |
+
self.transition_data['reward'] = reward
|
65 |
+
self.transition_data['terminated'] = terminated
|
66 |
+
self.translator.update(self.transition_data)
|
67 |
+
self.transition_data = {}
|
68 |
+
self.transition_data['state'] = state
|
69 |
+
self.translator.obtain(self.transition_data)
|
70 |
+
self.state = state
|
71 |
+
summary, future_summary = self.translator.translate()
|
72 |
+
info = {
|
73 |
+
'future_summary': future_summary,
|
74 |
+
'potential_state': potential_next_state,
|
75 |
+
}
|
76 |
+
return summary, reward, terminated, _, info
|
77 |
+
|
78 |
+
def get_terminate_state(self, episode_len, max_episode_len):
|
79 |
+
return self.translator.translate_terminate_state(self.state, episode_len, max_episode_len)
|
80 |
+
|
81 |
+
def get_game_description(self,):
|
82 |
+
return self.translator.describe_game()
|
83 |
+
|
84 |
+
def get_goal_description(self,):
|
85 |
+
return self.translator.describe_goal()
|
86 |
+
|
87 |
+
def get_action_description(self,):
|
88 |
+
return self.translator.describe_action()
|
89 |
+
|
90 |
+
def get_action_desc_dict(self,):
|
91 |
+
return self.translator.get_action_desc_dict()
|
92 |
+
|
93 |
+
def get_reward_desc_dict(self,):
|
94 |
+
return self.translator.get_reward_desc_dict()
|
95 |
+
|
96 |
+
def get_potential_next_state(self, action):
|
97 |
+
return self.translator.translate_potential_next_state(self.state, action)
|
envs/box2d/LunarLander_policies.py
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
def dedicated_1_policy(state, pre_action=1):
|
3 |
+
def get_description():
|
4 |
+
return "Always select action 1 which do nothing"
|
5 |
+
dedicated_1_policy.description = get_description()
|
6 |
+
return 1
|
7 |
+
|
8 |
+
def dedicated_2_policy(state, pre_action=1):
|
9 |
+
def get_description():
|
10 |
+
return "Always select action 2 which fire left engine"
|
11 |
+
dedicated_2_policy.description = get_description()
|
12 |
+
return 2
|
13 |
+
|
14 |
+
def dedicated_3_policy(state, pre_action=1):
|
15 |
+
def get_description():
|
16 |
+
return "Always select action 3 which fire main engine"
|
17 |
+
dedicated_3_policy.description = get_description()
|
18 |
+
return 3
|
19 |
+
|
20 |
+
def dedicated_4_policy(state, pre_action=1):
|
21 |
+
def get_description():
|
22 |
+
return "Always select action 4 which fire right engine"
|
23 |
+
dedicated_4_policy.description = get_description()
|
24 |
+
return 4
|
25 |
+
|
26 |
+
def pseudo_random_policy(state, pre_action):
|
27 |
+
def get_description():
|
28 |
+
return "Select action 1, 2, 3, 4 alternatively which do nothing, fire left engine, fire main engine, and fire right engine alternatively"
|
29 |
+
pseudo_random_policy.description = get_description()
|
30 |
+
return pre_action%4+1
|
31 |
+
|
32 |
+
def real_random_policy(state,pre_action=0):
|
33 |
+
def get_description():
|
34 |
+
return "Select action with a random policy"
|
35 |
+
real_random_policy.description = get_description()
|
36 |
+
return np.random.choice([1, 2, 3, 4])
|
envs/box2d/LunarLander_translator.py
ADDED
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# [Translator classes and functions for Lunar Lander environment]
|
2 |
+
|
3 |
+
class BasicLevelTranslator:
|
4 |
+
def __init__(self,):
|
5 |
+
pass
|
6 |
+
|
7 |
+
def translate(self, state):
|
8 |
+
x, y, x_dot, y_dot, angle, angular_velocity, left_leg_contact, right_leg_contact = state
|
9 |
+
left_contact_info = "in contact" if left_leg_contact else "not in contact"
|
10 |
+
right_contact_info = "in contact" if right_leg_contact else "not in contact"
|
11 |
+
return f"The lander is at position ({x:.2f}, {y:.2f}), the horizontal speed of movement is {x_dot:.2f}, " \
|
12 |
+
f"the vertical velocity speed of movement is {y_dot:.2f}. The angle is {angle:.2f} radians, and it's rotating at {angular_velocity:.2f} radians per second. The left leg is {left_contact_info} with ground. The right leg is {right_contact_info} with ground."
|
13 |
+
|
14 |
+
class GameDescriber:
|
15 |
+
def __init__(self, args):
|
16 |
+
self.is_only_local_obs = args.is_only_local_obs == 1
|
17 |
+
self.max_episode_len = args.max_episode_len
|
18 |
+
self.action_desc_dict = {
|
19 |
+
}
|
20 |
+
self.reward_desc_dict = {
|
21 |
+
}
|
22 |
+
|
23 |
+
def describe_goal(self):
|
24 |
+
return "The goal is to successfully land the lander on the landing pad which is at position (0, 0) while avoiding crash."
|
25 |
+
|
26 |
+
def translate_terminate_state(self, state, episode_len, max_episode_len):
|
27 |
+
return ""
|
28 |
+
|
29 |
+
def translate_potential_next_state(self, state, action):
|
30 |
+
return ""
|
31 |
+
|
32 |
+
def describe_game(self):
|
33 |
+
return "In the Lunar Lander game, you control a lander that is descending towards " \
|
34 |
+
"the landing pad. The goal is to successfully land the lander on the landing pad " \
|
35 |
+
"while avoiding crash. Please note that the lander is affected by gravity, and the lander starts at the " \
|
36 |
+
"top center of the viewport with a random initial force applied to its center of mass. " \
|
37 |
+
"Be careful to balance the engine to slow down your descent " \
|
38 |
+
"and land gently. If you land too quickly or crash into the landing pad, the game will " \
|
39 |
+
"end, and you will be punished."
|
40 |
+
|
41 |
+
def describe_action(self):
|
42 |
+
return "Your Next Move: \n Please choose an action. Type '1' to do noting, '2' to fire left engine and make lander move to right, '3' to fire main engine and make lander move to up, " \
|
43 |
+
"or '4' to fire right engine and make lander move to left. Ensure you only provide the action number from the valid action list, i.e., [1, 2, 3, 4]."
|
44 |
+
|
45 |
+
|
46 |
+
class BasicStateSequenceTranslator(BasicLevelTranslator):
|
47 |
+
def translate(self, infos, is_current=False):
|
48 |
+
descriptions = []
|
49 |
+
if is_current:
|
50 |
+
state_desc = BasicLevelTranslator().translate(infos[-1]['state'])
|
51 |
+
return state_desc
|
52 |
+
for i, info in enumerate(infos):
|
53 |
+
assert 'state' in info, "info should contain state information"
|
54 |
+
|
55 |
+
state_desc = BasicLevelTranslator().translate(info['state'])
|
56 |
+
if info['action'] == 1:
|
57 |
+
action_desc = f"Take Action: 'Do Noting'"
|
58 |
+
elif info['action'] == 2:
|
59 |
+
action_desc = f"Take Action: 'Fire left engine'"
|
60 |
+
elif info['action'] == 3:
|
61 |
+
action_desc = f"Take Action: 'Fire main engine'"
|
62 |
+
else:
|
63 |
+
action_desc = f"Take Action: 'Fire right engine'"
|
64 |
+
reward_desc = f"Result: Reward of {info['reward']}, "
|
65 |
+
next_state_desc = BasicLevelTranslator().translate(info['next_state'])
|
66 |
+
descriptions.append(f"{state_desc}.\n {action_desc} \n {reward_desc} \n Transit to {next_state_desc}")
|
67 |
+
return descriptions
|
envs/box2d/__init__.py
ADDED
File without changes
|
envs/box2d/few_shot_examples/lunarlander_l2.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
envs/box2d/few_shot_examples/lunarlander_l4.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
envs/classic_control/__init__.py
ADDED
File without changes
|
envs/classic_control/acrobot_policies.py
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
|
3 |
+
# https://colab.research.google.com/drive/1DdWsGi10232orUv-reY4wsTmT0VMoHaX?usp=sharing#scrollTo=4OfVmDKk7XvG
|
4 |
+
# LLMs bias on 0 so make the actions 1, 2 and 3 instead.
|
5 |
+
|
6 |
+
def dedicated_1_policy(state, pre_action=1):
|
7 |
+
def get_description():
|
8 |
+
return "Always select action 1"
|
9 |
+
dedicated_0_policy.description = get_description()
|
10 |
+
return 1
|
11 |
+
|
12 |
+
def dedicated_2_policy(state, pre_action=1):
|
13 |
+
def get_description():
|
14 |
+
return "Always select action 2"
|
15 |
+
dedicated_2_policy.description = get_description()
|
16 |
+
return 2
|
17 |
+
|
18 |
+
def dedicated_3_policy(state, pre_action=1):
|
19 |
+
def get_description():
|
20 |
+
return "Always select action 3"
|
21 |
+
dedicated_3_policy.description = get_description()
|
22 |
+
return 3
|
23 |
+
|
24 |
+
def pseudo_random_policy(state, pre_action):
|
25 |
+
def get_description():
|
26 |
+
return "Select action 1, 2, and 3 alternatively"
|
27 |
+
pseudo_random_policy.description = get_description()
|
28 |
+
return pre_action % 3 + 1
|
29 |
+
|
30 |
+
|
31 |
+
def real_random_policy(state, pre_action=1):
|
32 |
+
def get_description():
|
33 |
+
return "Select action with a random policy"
|
34 |
+
real_random_policy.description = get_description()
|
35 |
+
return np.random.choice([1, 2, 3])
|
36 |
+
|
envs/classic_control/acrobot_translator.py
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
|
3 |
+
class BasicLevelTranslator:
|
4 |
+
def __init__(self):
|
5 |
+
pass
|
6 |
+
|
7 |
+
def translate(self, state):
|
8 |
+
cos_theta1, sin_theta1, cos_theta2, sin_theta2, theta1_dot, theta2_dot = state
|
9 |
+
theta1_direction = "clockwise" if theta1_dot > 0 else "counterclockwise"
|
10 |
+
theta2_direction = "clockwise" if theta2_dot > 0 else "counterclockwise"
|
11 |
+
theta1 = math.atan(sin_theta1 / (cos_theta1+1e-6))
|
12 |
+
theta2 = math.atan(sin_theta2 / (cos_theta2+1e-6))
|
13 |
+
res = (f"Link1: angle theta1 {theta1:.2f} radians, rotating {abs(theta1_dot):.2f} radians per second {theta1_direction}. "
|
14 |
+
f"Link2: angle theta2 {theta2:.2f} radians relative to Link1, rotating {abs(theta2_dot):.2f} radians per second {theta2_direction}.")
|
15 |
+
return res
|
16 |
+
|
17 |
+
class GameDescriber:
|
18 |
+
def __init__(self, args):
|
19 |
+
self.is_only_local_obs = args.is_only_local_obs == 1
|
20 |
+
self.max_episode_len = args.max_episode_len
|
21 |
+
self.action_desc_dict = {
|
22 |
+
}
|
23 |
+
self.reward_desc_dict = {
|
24 |
+
}
|
25 |
+
|
26 |
+
def describe_goal(self):
|
27 |
+
return "The goal is to apply torque on the actuator to swing the free end of the linear chain above the target height, which is constructed as: -cos(theta1) - cos(theta2 + theta1) > 1.0."
|
28 |
+
|
29 |
+
def translate_terminate_state(self, state, episode_len, max_episode_len):
|
30 |
+
return ""
|
31 |
+
|
32 |
+
def translate_potential_next_state(self, state, action):
|
33 |
+
return ""
|
34 |
+
|
35 |
+
def describe_game(self):
|
36 |
+
return ('''In the Acrobot game, there are two links connected by two joints. The first link is connected to a base, and your goal is to swing the free end of the second link above the target height by applying torques on the actuated joint. The task ends if one of the following occurs: 1. The free end reaches the target height, which is constructed as: -cos(theta1) - cos(theta2 + theta1) > 1.0; or 2. Decision time is greater than 200.''')
|
37 |
+
|
38 |
+
# https://colab.research.google.com/drive/1DdWsGi10232orUv-reY4wsTmT0VMoHaX?usp=sharing#scrollTo=4OfVmDKk7XvG
|
39 |
+
# LLMs bias on 0 so make the actions 1, 2 and 3 instead.
|
40 |
+
def describe_action(self):
|
41 |
+
return ("Your Next Move: \\n Please choose an action. Type '1' to apply -1 torque, '2' to apply 0 torque, or '3' to apply 1 torque. "
|
42 |
+
"Ensure you provide the action number from the valid action list, i.e., [1, 2, 3].")
|
43 |
+
|
44 |
+
class BasicStateSequenceTranslator(BasicLevelTranslator):
|
45 |
+
def translate(self, infos, is_current=False):
|
46 |
+
descriptions = []
|
47 |
+
if is_current:
|
48 |
+
state_desc = BasicLevelTranslator().translate(infos[-1]['state'])
|
49 |
+
return state_desc
|
50 |
+
for i, info in enumerate(infos):
|
51 |
+
assert 'state' in info, "info should contain state information"
|
52 |
+
|
53 |
+
state_desc = BasicLevelTranslator().translate(info['state'])
|
54 |
+
action_desc = f"Take Action: Apply {info['action'] - 2} torque on the actuated joint."
|
55 |
+
reward_desc = f"Result: Reward of {info['reward']}."
|
56 |
+
next_state_desc = BasicLevelTranslator().translate(info['next_state'])
|
57 |
+
descriptions.append(f"{state_desc}.\n {action_desc} \n {reward_desc} \n Transit to {next_state_desc}")
|
58 |
+
return descriptions
|
envs/classic_control/cartpole_policies.py
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
def dedicated_1_policy(state, pre_action=1):
|
3 |
+
def get_description():
|
4 |
+
return "Always select action 1"
|
5 |
+
dedicated_1_policy.description = get_description()
|
6 |
+
return 1
|
7 |
+
|
8 |
+
def dedicated_2_policy(state, pre_action=1):
|
9 |
+
def get_description():
|
10 |
+
return "Always select action 2"
|
11 |
+
dedicated_2_policy.description = get_description()
|
12 |
+
return 2
|
13 |
+
|
14 |
+
def pseudo_random_policy(state, pre_action):
|
15 |
+
def get_description():
|
16 |
+
return "Select action 1 and 2 alternatively"
|
17 |
+
pseudo_random_policy.description = get_description()
|
18 |
+
return pre_action%2 +1
|
19 |
+
|
20 |
+
def real_random_policy(state,pre_action=1):
|
21 |
+
def get_description():
|
22 |
+
return "Select action with a random policy"
|
23 |
+
real_random_policy.description = get_description()
|
24 |
+
return np.random.choice([1, 2])
|
25 |
+
|
envs/classic_control/cartpole_translator.py
ADDED
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
class BasicLevelTranslator:
|
3 |
+
def __init__(self,):
|
4 |
+
pass
|
5 |
+
|
6 |
+
def translate(self, state):
|
7 |
+
cart_position, cart_velocity, pole_angle, pole_angular_velocity = state
|
8 |
+
cart_direction = "right" if cart_velocity > 0 else "left"
|
9 |
+
pole_direction = "right" if pole_angular_velocity > 0 else "left"
|
10 |
+
res = (f"The cart is positioned at {cart_position:.3f}, with a velocity of {abs(cart_velocity):.2f} towards the {cart_direction}. "
|
11 |
+
f"The pole is tilted at {abs(pole_angle):.2f} radians, rotating at {abs(pole_angular_velocity):.2f} radians per second towards the {pole_direction}.")
|
12 |
+
return res
|
13 |
+
|
14 |
+
class GameDescriber:
|
15 |
+
def __init__(self, args):
|
16 |
+
self.is_only_local_obs = args.is_only_local_obs == 1
|
17 |
+
self.max_episode_len = args.max_episode_len
|
18 |
+
self.action_desc_dict = {
|
19 |
+
}
|
20 |
+
self.reward_desc_dict = {
|
21 |
+
}
|
22 |
+
|
23 |
+
def describe_goal(self):
|
24 |
+
return "The goal is to keep the pole balanced upright for as long as possible."
|
25 |
+
|
26 |
+
def translate_terminate_state(self, state, episode_len, max_episode_len):
|
27 |
+
return ""
|
28 |
+
|
29 |
+
def translate_potential_next_state(self, state, action):
|
30 |
+
return ""
|
31 |
+
|
32 |
+
def describe_game(self):
|
33 |
+
return "In the CartPole game, you control a cart that moves along a horizontal track. There is a pole " \
|
34 |
+
"standing upright on the cart. The goal of the game is to keep the pole balanced upright by moving the " \
|
35 |
+
"cart left or right. The game ends if the pole tilts too far from the vertical position or if the cart " \
|
36 |
+
"moves too far from the center of the track. The longer you can keep the pole balanced, the higher " \
|
37 |
+
"your score.Note that when the Cart Position is out of the (-2.4, 2.4) zone or the Pole Angle is out " \
|
38 |
+
"of the zone (-.2095, .2095), the round ends and the game is lost. "
|
39 |
+
|
40 |
+
def describe_action(self):
|
41 |
+
return "Your Next Move: \n Please choose an action. Type '1' to push the cart to the left or '2' to push the cart to the right. Ensure you only provide the action number from the valid action list, i.e., [1, 2]."
|
42 |
+
|
43 |
+
class BasicStateSequenceTranslator(BasicLevelTranslator):
|
44 |
+
def translate(self, infos, is_current=False):
|
45 |
+
descriptions = []
|
46 |
+
if is_current:
|
47 |
+
state_desc = BasicLevelTranslator().translate(infos[-1]['state'])
|
48 |
+
return state_desc
|
49 |
+
for i, info in enumerate(infos):
|
50 |
+
assert 'state' in info, "info should contain state information"
|
51 |
+
|
52 |
+
state_desc = BasicLevelTranslator().translate(info['state'])
|
53 |
+
action_desc = f"Take Action: Push {'right' if info['action'] == 2 else 'left'} ({info['action']})."
|
54 |
+
reward_desc = f"Result: Reward of {info['reward']}, "
|
55 |
+
next_state_desc = BasicLevelTranslator().translate(info['next_state'])
|
56 |
+
descriptions.append(f"{state_desc}.\n {action_desc} \n {reward_desc} \n Transit to {next_state_desc}")
|
57 |
+
return descriptions
|
envs/classic_control/few_shot_examples/acrobot_l2.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
envs/classic_control/few_shot_examples/acrobot_l4.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
envs/classic_control/few_shot_examples/cartpole_l2.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
envs/classic_control/few_shot_examples/cartpole_l4.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
envs/classic_control/few_shot_examples/mountaincarContinuous_l2.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
envs/classic_control/few_shot_examples/mountaincarContinuous_l4.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
envs/classic_control/few_shot_examples/mountaincar_l2.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|