qgallouedec committed
Commit 61c9b91
1 Parent(s): 30f1e9b

pushing model

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ rainbow_atari.cleanrl_model filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,79 @@
---
tags:
- BreakoutNoFrameskip-v4
- deep-reinforcement-learning
- reinforcement-learning
- custom-implementation
library_name: cleanrl
model-index:
- name: RAINBOW
  results:
  - task:
      type: reinforcement-learning
      name: reinforcement-learning
    dataset:
      name: BreakoutNoFrameskip-v4
      type: BreakoutNoFrameskip-v4
    metrics:
    - type: mean_reward
      value: 4.50 +/- 4.50
      name: mean_reward
      verified: false
---

# (CleanRL) **RAINBOW** Agent Playing **BreakoutNoFrameskip-v4**

This is a trained model of a RAINBOW agent playing BreakoutNoFrameskip-v4.
The model was trained by using [CleanRL](https://github.com/vwxyzjn/cleanrl) and the most up-to-date training code can be
found [here](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/rainbow_atari.py).

## Get Started

To use this model, please install the `cleanrl` package with the following command:

```
pip install "cleanrl[rainbow_atari]"
python -m cleanrl_utils.enjoy --exp-name rainbow_atari --env-id BreakoutNoFrameskip-v4
```

Please refer to the [documentation](https://docs.cleanrl.dev/get-started/zoo/) for more detail.
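
If you prefer to bypass the `cleanrl_utils.enjoy` entry point, note that `rainbow_atari.cleanrl_model` is a plain `torch.save` file holding a `model_weights` state dict and the training `args` (see the `--save-model` branch of `rainbow_atari.py`). The snippet below is a minimal sketch, not part of the official CleanRL zoo API; it assumes `rainbow_atari.py` and the checkpoint from this repository are in your working directory:

```python
# Minimal sketch: load the checkpoint from this repo directly with PyTorch.
# Assumes rainbow_atari.py (from this repository) is importable from the CWD.
import torch
import gymnasium as gym

from rainbow_atari import QNetwork, make_env

checkpoint = torch.load("rainbow_atari.cleanrl_model", map_location="cpu")
args = checkpoint["args"]  # the training configuration saved as a plain dict

# One evaluation env with the same Atari wrappers used during training.
envs = gym.vector.SyncVectorEnv(
    [make_env(args["env_id"], seed=1, idx=0, capture_video=False, run_name="eval")]
)

model = QNetwork(envs, n_atoms=args["n_atoms"], v_min=args["v_min"], v_max=args["v_max"])
model.load_state_dict(checkpoint["model_weights"])
model.eval()  # NoisyLinear layers fall back to their mean weights in eval mode

obs, _ = envs.reset(seed=1)
actions = model.get_action(torch.Tensor(obs)).cpu().numpy()
```

The `python -m cleanrl_utils.enjoy` command above does all of this for you, so the sketch is only useful for custom evaluation loops.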

## Command to reproduce the training

```bash
curl -OL https://huggingface.co/qgallouedec/BreakoutNoFrameskip-v4-rainbow_atari-seed1/raw/main/rainbow_atari.py
curl -OL https://huggingface.co/qgallouedec/BreakoutNoFrameskip-v4-rainbow_atari-seed1/raw/main/pyproject.toml
curl -OL https://huggingface.co/qgallouedec/BreakoutNoFrameskip-v4-rainbow_atari-seed1/raw/main/poetry.lock
poetry install --all-extras
python rainbow_atari.py --learning-starts 100 --total-timesteps 5000 --save-model --upload-model --hf-entity qgallouedec
```

# Hyperparameters
```python
{'batch_size': 32,
 'buffer_size': 1000000,
 'capture_video': False,
 'cuda': True,
 'env_id': 'BreakoutNoFrameskip-v4',
 'exp_name': 'rainbow_atari',
 'gamma': 0.99,
 'hf_entity': 'qgallouedec',
 'learning_rate': 0.00025,
 'learning_starts': 100,
 'n_atoms': 51,
 'num_envs': 1,
 'save_model': True,
 'seed': 1,
 'target_network_frequency': 10000,
 'torch_deterministic': True,
 'total_timesteps': 5000,
 'track': False,
 'train_frequency': 4,
 'upload_model': True,
 'v_max': 10,
 'v_min': -10,
 'wandb_entity': None,
 'wandb_project_name': 'cleanRL'}
```
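Note that the reproduction command above is a short smoke-test run (`--total-timesteps 5000`, `--learning-starts 100`); the defaults baked into `rainbow_atari.py` are 10,000,000 timesteps with an 80,000-step warm-up. To check which configuration produced a given checkpoint, you can read the same dictionary back out of the model file, since the hyperparameters shown above are exactly what the script stores under the `args` key:

```python
# Print the training configuration embedded in the checkpoint
# (assumes rainbow_atari.cleanrl_model has been downloaded from this repo).
import pprint
import torch

checkpoint = torch.load("rainbow_atari.cleanrl_model", map_location="cpu")
pprint.pprint(checkpoint["args"])
```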
events.out.tfevents.1700431636.MacBook-Pro-de-Quentin.local.52615.0 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bfa07a2b4c1146a00b269973643f3d22e86edac5852a0cc786fae5ed7930d365
3
+ size 11272
poetry.lock ADDED
The diff for this file is too large to render. See raw diff
 
pyproject.toml ADDED
@@ -0,0 +1,108 @@
1
+ [tool.poetry]
2
+ name = "cleanrl"
3
+ version = "1.1.0"
4
+ description = "High-quality single file implementation of Deep Reinforcement Learning algorithms with research-friendly features"
5
+ authors = ["Costa Huang <[email protected]>"]
6
+ packages = [
7
+ { include = "cleanrl" },
8
+ { include = "cleanrl_utils" },
9
+ ]
10
+ keywords = ["reinforcement", "machine", "learning", "research"]
11
+ license="MIT"
12
+ readme = "README.md"
13
+
14
+ [tool.poetry.dependencies]
15
+ python = ">=3.7.1,<3.11"
16
+ tensorboard = "^2.10.0"
17
+ wandb = "^0.13.11"
18
+ gym = "0.23.1"
19
+ torch = ">=1.12.1"
20
+ stable-baselines3 = "1.2.0"
21
+ gymnasium = ">=0.28.1"
22
+ moviepy = "^1.0.3"
23
+ pygame = "2.1.0"
24
+ huggingface-hub = "^0.11.1"
25
+ rich = "<12.0"
26
+ tenacity = "^8.2.2"
27
+
28
+ ale-py = {version = "0.7.4", optional = true}
29
+ AutoROM = {extras = ["accept-rom-license"], version = "^0.4.2", optional = true}
30
+ opencv-python = {version = "^4.6.0.66", optional = true}
31
+ procgen = {version = "^0.10.7", optional = true}
32
+ pytest = {version = "^7.1.3", optional = true}
33
+ mujoco = {version = "<=2.3.3", optional = true}
34
+ imageio = {version = "^2.14.1", optional = true}
35
+ free-mujoco-py = {version = "^2.1.6", optional = true}
36
+ mkdocs-material = {version = "^8.4.3", optional = true}
37
+ markdown-include = {version = "^0.7.0", optional = true}
38
+ openrlbenchmark = {version = "^0.1.1b4", optional = true}
39
+ jax = {version = "^0.3.17", optional = true}
40
+ jaxlib = {version = "^0.3.15", optional = true}
41
+ flax = {version = "^0.6.0", optional = true}
42
+ optuna = {version = "^3.0.1", optional = true}
43
+ optuna-dashboard = {version = "^0.7.2", optional = true}
44
+ envpool = {version = "^0.6.4", optional = true}
45
+ PettingZoo = {version = "1.18.1", optional = true}
46
+ SuperSuit = {version = "3.4.0", optional = true}
47
+ multi-agent-ale-py = {version = "0.1.11", optional = true}
48
+ boto3 = {version = "^1.24.70", optional = true}
49
+ awscli = {version = "^1.25.71", optional = true}
50
+ shimmy = {version = ">=1.0.0", extras = ["dm-control"], optional = true}
51
+
52
+ [tool.poetry.group.dev.dependencies]
53
+ pre-commit = "^2.20.0"
54
+
55
+
56
+ [tool.poetry.group.isaacgym]
57
+ optional = true
58
+ [tool.poetry.group.isaacgym.dependencies]
59
+ isaacgymenvs = {git = "https://github.com/vwxyzjn/IsaacGymEnvs.git", rev = "poetry", python = ">=3.7.1,<3.10"}
60
+ isaacgym = {path = "cleanrl/ppo_continuous_action_isaacgym/isaacgym", develop = true}
61
+
62
+
63
+ [build-system]
64
+ requires = ["poetry-core"]
65
+ build-backend = "poetry.core.masonry.api"
66
+
67
+ [tool.poetry.extras]
68
+ atari = ["ale-py", "AutoROM", "opencv-python"]
69
+ procgen = ["procgen"]
70
+ plot = ["pandas", "seaborn"]
71
+ pytest = ["pytest"]
72
+ mujoco = ["mujoco", "imageio"]
73
+ mujoco_py = ["free-mujoco-py"]
74
+ jax = ["jax", "jaxlib", "flax"]
75
+ docs = ["mkdocs-material", "markdown-include", "openrlbenchmark"]
76
+ envpool = ["envpool"]
77
+ optuna = ["optuna", "optuna-dashboard"]
78
+ pettingzoo = ["PettingZoo", "SuperSuit", "multi-agent-ale-py"]
79
+ cloud = ["boto3", "awscli"]
80
+ dm_control = ["shimmy", "mujoco"]
81
+
82
+ # dependencies for algorithm variant (useful when you want to run a specific algorithm)
83
+ dqn = []
84
+ dqn_atari = ["ale-py", "AutoROM", "opencv-python"]
85
+ dqn_jax = ["jax", "jaxlib", "flax"]
86
+ dqn_atari_jax = [
87
+ "ale-py", "AutoROM", "opencv-python", # atari
88
+ "jax", "jaxlib", "flax" # jax
89
+ ]
90
+ c51 = []
91
+ c51_atari = ["ale-py", "AutoROM", "opencv-python"]
92
+ c51_jax = ["jax", "jaxlib", "flax"]
93
+ c51_atari_jax = [
94
+ "ale-py", "AutoROM", "opencv-python", # atari
95
+ "jax", "jaxlib", "flax" # jax
96
+ ]
97
+ ppo_atari_envpool_xla_jax_scan = [
98
+ "ale-py", "AutoROM", "opencv-python", # atari
99
+ "jax", "jaxlib", "flax", # jax
100
+ "envpool", # envpool
101
+ ]
102
+ qdagger_dqn_atari_impalacnn = [
103
+ "ale-py", "AutoROM", "opencv-python"
104
+ ]
105
+ qdagger_dqn_atari_jax_impalacnn = [
106
+ "ale-py", "AutoROM", "opencv-python", # atari
107
+ "jax", "jaxlib", "flax", # jax
108
+ ]
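
Each extra above only pulls in the dependencies of the matching script, so a lighter alternative to `poetry install --all-extras` (used in the README) is to select a single group, e.g. `poetry install -E atari`; here `-E` is Poetry's standard flag for installing extras, and `atari` is one of the extra names defined in the table above.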
rainbow_atari.cleanrl_model ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:463a757c64a65defb5532bc934cc088ef6d03eb2c7b4a683779abb9a7481f0b3
3
+ size 40440331
rainbow_atari.py ADDED
@@ -0,0 +1,475 @@
1
+ # docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/c51/#c51_ataripy
2
+ import argparse
3
+ import math
4
+ import os
5
+ import random
6
+ import time
7
+ from collections import deque
8
+ from distutils.util import strtobool
9
+ from types import SimpleNamespace
10
+
11
+ import gymnasium as gym
12
+ import numpy as np
13
+ import torch
14
+ import torch.nn as nn
15
+ import torch.nn.functional as F
16
+ import torch.nn.init as init
17
+ import torch.optim as optim
18
+ from stable_baselines3.common.atari_wrappers import ClipRewardEnv, EpisodicLifeEnv, FireResetEnv, MaxAndSkipEnv, NoopResetEnv
19
+ from torch.utils.tensorboard import SummaryWriter
20
+
21
+
22
+ def parse_args():
23
+ # fmt: off
24
+ parser = argparse.ArgumentParser()
25
+ parser.add_argument("--exp-name", type=str, default=os.path.basename(__file__).rstrip(".py"),
26
+ help="the name of this experiment")
27
+ parser.add_argument("--seed", type=int, default=1,
28
+ help="seed of the experiment")
29
+ parser.add_argument("--torch-deterministic", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
30
+ help="if toggled, `torch.backends.cudnn.deterministic=False`")
31
+ parser.add_argument("--cuda", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
32
+ help="if toggled, cuda will be enabled by default")
33
+ parser.add_argument("--track", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
34
+ help="if toggled, this experiment will be tracked with Weights and Biases")
35
+ parser.add_argument("--wandb-project-name", type=str, default="cleanRL",
36
+ help="the wandb's project name")
37
+ parser.add_argument("--wandb-entity", type=str, default=None,
38
+ help="the entity (team) of wandb's project")
39
+ parser.add_argument("--capture-video", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
40
+ help="whether to capture videos of the agent performances (check out `videos` folder)")
41
+ parser.add_argument("--save-model", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
42
+ help="whether to save model into the `runs/{run_name}` folder")
43
+ parser.add_argument("--upload-model", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
44
+ help="whether to upload the saved model to huggingface")
45
+ parser.add_argument("--hf-entity", type=str, default="",
46
+ help="the user or org name of the model repository from the Hugging Face Hub")
47
+
48
+ # Algorithm specific arguments
49
+ parser.add_argument("--env-id", type=str, default="BreakoutNoFrameskip-v4",
50
+ help="the id of the environment")
51
+ parser.add_argument("--total-timesteps", type=int, default=10000000,
52
+ help="total timesteps of the experiments")
53
+ parser.add_argument("--learning-rate", type=float, default=2.5e-4,
54
+ help="the learning rate of the optimizer")
55
+ parser.add_argument("--num-envs", type=int, default=1,
56
+ help="the number of parallel game environments")
57
+ parser.add_argument("--n-atoms", type=int, default=51,
58
+ help="the number of atoms")
59
+ parser.add_argument("--v-min", type=float, default=-10,
60
+ help="the return lower bound")
61
+ parser.add_argument("--v-max", type=float, default=10,
62
+ help="the return upper bound")
63
+ parser.add_argument("--buffer-size", type=int, default=1000000,
64
+ help="the replay memory buffer size")
65
+ parser.add_argument("--gamma", type=float, default=0.99,
66
+ help="the discount factor gamma")
67
+ parser.add_argument("--target-network-frequency", type=int, default=10000,
68
+ help="the timesteps it takes to update the target network")
69
+ parser.add_argument("--batch-size", type=int, default=32,
70
+ help="the batch size of sample from the reply memory")
71
+ parser.add_argument("--learning-starts", type=int, default=80000,
72
+ help="timestep to start learning")
73
+ parser.add_argument("--train-frequency", type=int, default=4,
74
+ help="the frequency of training")
75
+ args = parser.parse_args()
76
+ # fmt: on
77
+ assert args.num_envs == 1, "vectorized envs are not supported at the moment"
78
+
79
+ return args
80
+
81
+
82
+ def make_env(env_id, seed, idx, capture_video, run_name):
83
+ def thunk():
84
+ if capture_video and idx == 0:
85
+ env = gym.make(env_id, render_mode="rgb_array")
86
+ env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")
87
+ else:
88
+ env = gym.make(env_id)
89
+ env = gym.wrappers.RecordEpisodeStatistics(env)
90
+
91
+ env = NoopResetEnv(env, noop_max=30)
92
+ env = MaxAndSkipEnv(env, skip=4)
93
+ env = EpisodicLifeEnv(env)
94
+ if "FIRE" in env.unwrapped.get_action_meanings():
95
+ env = FireResetEnv(env)
96
+ env = ClipRewardEnv(env)
97
+ env = gym.wrappers.ResizeObservation(env, (84, 84))
98
+ env = gym.wrappers.GrayScaleObservation(env)
99
+ env = gym.wrappers.FrameStack(env, 4)
100
+
101
+ env.action_space.seed(seed)
102
+ return env
103
+
104
+ return thunk
105
+
106
+
107
+ class SumTree:
108
+ def __init__(self, capacity):
109
+ self.capacity = capacity # Capacity of the sum tree (number of leaves)
110
+ self.tree = [0] * (2 * capacity) # Binary tree representation
111
+ self.max_priority = 1.0 # Initial max priority for new experiences
112
+
113
+ def update(self, index, priority=None):
114
+ if priority is None:
115
+ priority = self.max_priority
116
+ tree_idx = index + self.capacity
117
+ change = priority - self.tree[tree_idx]
118
+ self.tree[tree_idx] = priority
119
+ self._propagate(tree_idx, change)
120
+ self.max_priority = max(self.max_priority, priority)
121
+
122
+ def _propagate(self, idx, change):
123
+ parent = idx // 2
124
+ while parent != 0:
125
+ self.tree[parent] += change
126
+ parent = parent // 2
127
+
128
+ def total(self):
129
+ return self.tree[1] # The root of the tree holds the total sum
130
+
131
+ def get(self, s):
132
+ idx = 1
133
+ while idx < self.capacity: # Keep moving down the tree to find the index
134
+ left = 2 * idx
135
+ right = left + 1
136
+ if self.tree[left] >= s:
137
+ idx = left
138
+ else:
139
+ s -= self.tree[left]
140
+ idx = right
141
+ return idx - self.capacity
142
+
143
+
144
+ class PrioritizedReplayBuffer:
145
+ def __init__(self, size, device, alpha=0.5, beta_0=0.4, n_step=3, gamma=0.99):
146
+ self.size = size
147
+ self.device = device
148
+ self.alpha = alpha
149
+ self.beta_0 = beta_0
150
+ self.update_beta(0.0)
151
+ self.n_step = n_step
152
+ self.gamma = gamma
153
+
154
+ self.next_index = 0
155
+ self.sum_tree = SumTree(size)
156
+ self.observations = np.zeros((self.size, 4, 84, 84), dtype=np.uint8)
157
+ self.next_observations = np.zeros((self.size, 4, 84, 84), dtype=np.uint8)
158
+ self.actions = np.zeros((self.size, 1), dtype=np.int64)
159
+ self.rewards = np.zeros((self.size, 1), dtype=np.float32)
160
+ self.dones = np.zeros((self.size, 1), dtype=bool)
161
+
162
+ self.n_step_buffer = deque(maxlen=n_step)
163
+
164
+ def add(self, obs, next_obs, actions, rewards, dones, infos):
165
+ self.n_step_buffer.append((obs[0], next_obs[0], actions[0], rewards[0], dones[0], infos))
166
+
167
+ if len(self.n_step_buffer) < self.n_step and not dones[0]:
168
+ return
169
+
170
+ # Compute n-step return and the first state and action
171
+ rewards = [self.n_step_buffer[i][3] for i in range(len(self.n_step_buffer))]
172
+ n_step_return = sum([r * (self.gamma**i) for i, r in enumerate(rewards)])
173
+ obs, _, action, _, _, _ = self.n_step_buffer[0]
174
+ _, next_obs, _, _, done, _ = self.n_step_buffer[-1]
175
+
176
+ # Store the n-step transition
177
+ self.observations[self.next_index] = obs
178
+ self.next_observations[self.next_index] = next_obs
179
+ self.actions[self.next_index] = action
180
+ self.rewards[self.next_index] = n_step_return
181
+ self.dones[self.next_index] = done
182
+
183
+ # Get the max priority in the tree and set the new transition with max priority
184
+ self.sum_tree.update(self.next_index)
185
+ self.next_index = (self.next_index + 1) % self.size
186
+
187
+ if dones[0]:
188
+ self.n_step_buffer.clear()
189
+
190
+ def sample(self, batch_size):
191
+ segment = self.sum_tree.total() / batch_size
192
+ idxs = []
193
+ priorities = []
194
+ for i in range(batch_size):
195
+ a = segment * i
196
+ b = segment * (i + 1)
197
+ s = random.uniform(a, b)
198
+ idx = self.sum_tree.get(s)
199
+ idxs.append(idx)
200
+ leaf_idx = idx + self.size # Adjusting index to point to the leaf node
201
+ priorities.append(self.sum_tree.tree[leaf_idx])
202
+
203
+ priorities = torch.tensor(priorities, dtype=torch.float32, device=self.device).unsqueeze(1)
204
+ sampling_probabilities = priorities / self.sum_tree.total()
205
+ weights = (self.size * sampling_probabilities) ** (-self.beta)
206
+ weights /= weights.max() # Normalize for stability
207
+
208
+ data = SimpleNamespace(
209
+ observations=torch.from_numpy(self.observations[idxs]).to(self.device),
210
+ next_observations=torch.from_numpy(self.next_observations[idxs]).to(self.device),
211
+ actions=torch.from_numpy(self.actions[idxs]).to(self.device),
212
+ rewards=torch.from_numpy(self.rewards[idxs]).to(self.device),
213
+ dones=torch.from_numpy(self.dones[idxs]).to(self.device),
214
+ )
215
+ return data, idxs, weights
216
+
217
+ def update_priorities(self, idxs, errors):
218
+ for idx, error in zip(idxs, errors):
219
+ priority = (abs(error) + 1e-5) ** self.alpha
220
+ self.sum_tree.update(idx, priority)
221
+
222
+ def update_beta(self, fraction):
223
+ self.beta = (1.0 - self.beta_0) * fraction + self.beta_0
224
+
225
+
226
+ class NoisyLinear(nn.Module):
227
+ def __init__(self, in_features, out_features, std_init=0.1):
228
+ super().__init__()
229
+ self.in_features = in_features
230
+ self.out_features = out_features
231
+ self.std_init = std_init
232
+
233
+ self.weight_mu = nn.Parameter(torch.Tensor(out_features, in_features))
234
+ self.weight_sigma = nn.Parameter(torch.Tensor(out_features, in_features))
235
+ self.register_buffer("weight_epsilon", torch.Tensor(out_features, in_features))
236
+
237
+ self.bias_mu = nn.Parameter(torch.Tensor(out_features))
238
+ self.bias_sigma = nn.Parameter(torch.Tensor(out_features))
239
+ self.register_buffer("bias_epsilon", torch.Tensor(out_features))
240
+
241
+ self.reset_parameters()
242
+ self.reset_noise()
243
+
244
+ def reset_parameters(self):
245
+ init.kaiming_uniform_(self.weight_mu, a=math.sqrt(5))
246
+ init.constant_(self.weight_sigma, self.std_init / math.sqrt(self.in_features))
247
+ init.constant_(self.bias_mu, 0)
248
+ init.constant_(self.bias_sigma, self.std_init / math.sqrt(self.out_features))
249
+
250
+ def reset_noise(self):
251
+ epsilon_in = self._scale_noise(self.in_features)
252
+ epsilon_out = self._scale_noise(self.out_features)
253
+ self.weight_epsilon.copy_(epsilon_out.outer(epsilon_in))
254
+ self.bias_epsilon.copy_(epsilon_out)
255
+
256
+ def _scale_noise(self, size):
257
+ x = torch.randn(size, device=self.weight_mu.device)
258
+ return x.sign().mul_(x.abs().sqrt_())
259
+
260
+ def forward(self, input):
261
+ weight = self.weight_mu + self.weight_sigma * self.weight_epsilon if self.training else self.weight_mu
262
+ bias = self.bias_mu + self.bias_sigma * self.bias_epsilon if self.training else self.bias_mu
263
+ return F.linear(input, weight, bias)
264
+
265
+
266
+ # ALGO LOGIC: initialize agent here:
267
+ class QNetwork(nn.Module):
268
+ def __init__(self, env, n_atoms=101, v_min=-100, v_max=100):
269
+ super().__init__()
270
+ self.env = env
271
+ self.n_atoms = n_atoms
272
+ self.register_buffer("atoms", torch.linspace(v_min, v_max, steps=n_atoms))
273
+ self.n = env.single_action_space.n
274
+
275
+ self.shared_layers = nn.Sequential(
276
+ nn.Conv2d(4, 32, 8, stride=4),
277
+ nn.ReLU(),
278
+ nn.Conv2d(32, 64, 4, stride=2),
279
+ nn.ReLU(),
280
+ nn.Conv2d(64, 64, 3, stride=1),
281
+ nn.ReLU(),
282
+ nn.Flatten(),
283
+ )
284
+ self.value_stream = nn.Sequential(NoisyLinear(3136, 512), nn.ReLU(), NoisyLinear(512, n_atoms))
285
+ self.advantage_stream = nn.Sequential(NoisyLinear(3136, 512), nn.ReLU(), NoisyLinear(512, self.n * n_atoms))
286
+
287
+ def reset_noise(self):
288
+ for module in self.modules():
289
+ if isinstance(module, NoisyLinear):
290
+ module.reset_noise()
291
+
292
+ def get_action(self, obs):
293
+ q_values_distributions = self.get_distribution(obs)
294
+ q_values = (torch.softmax(q_values_distributions, dim=2) * self.atoms).sum(2)
295
+ return torch.argmax(q_values, 1)
296
+
297
+ def get_distribution(self, obs):
298
+ x = self.shared_layers(obs / 255.0)
299
+ value = self.value_stream(x).view(-1, 1, self.n_atoms)
300
+ advantages = self.advantage_stream(x).view(-1, self.n, self.n_atoms)
301
+ return value + (advantages - advantages.mean(dim=1, keepdim=True))
302
+
303
+
304
+ if __name__ == "__main__":
305
+ import stable_baselines3 as sb3
306
+
307
+ if sb3.__version__ < "2.0":
308
+ raise ValueError(
309
+ """Ongoing migration: run the following command to install the new dependencies:
310
+
311
+ poetry run pip install "stable_baselines3==2.0.0a1" "gymnasium[atari,accept-rom-license]==0.28.1" "ale-py==0.8.1"
312
+ """
313
+ )
314
+ args = parse_args()
315
+ run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
316
+ if args.track:
317
+ import wandb
318
+
319
+ wandb.init(
320
+ project=args.wandb_project_name,
321
+ entity=args.wandb_entity,
322
+ sync_tensorboard=True,
323
+ config=vars(args),
324
+ name=run_name,
325
+ monitor_gym=True,
326
+ save_code=True,
327
+ )
328
+ writer = SummaryWriter(f"runs/{run_name}")
329
+ writer.add_text(
330
+ "hyperparameters",
331
+ "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
332
+ )
333
+
334
+ # TRY NOT TO MODIFY: seeding
335
+ random.seed(args.seed)
336
+ np.random.seed(args.seed)
337
+ torch.manual_seed(args.seed)
338
+ torch.backends.cudnn.deterministic = args.torch_deterministic
339
+
340
+ device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")
341
+
342
+ # env setup
343
+ envs = gym.vector.SyncVectorEnv(
344
+ [make_env(args.env_id, args.seed + i, i, args.capture_video, run_name) for i in range(args.num_envs)]
345
+ )
346
+ assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported"
347
+
348
+ q_network = QNetwork(envs, n_atoms=args.n_atoms, v_min=args.v_min, v_max=args.v_max).to(device)
349
+ optimizer = optim.Adam(q_network.parameters(), lr=args.learning_rate, eps=0.01 / args.batch_size)
350
+ target_network = QNetwork(envs, n_atoms=args.n_atoms, v_min=args.v_min, v_max=args.v_max).to(device)
351
+ target_network.load_state_dict(q_network.state_dict())
352
+
353
+ rb = PrioritizedReplayBuffer(args.buffer_size, device)
354
+ start_time = time.time()
355
+
356
+ # TRY NOT TO MODIFY: start the game
357
+ obs, _ = envs.reset(seed=args.seed)
358
+ for global_step in range(args.total_timesteps):
359
+ # ALGO LOGIC: put action logic here
360
+ actions = q_network.get_action(torch.Tensor(obs).to(device))
361
+ actions = actions.cpu().numpy()
362
+
363
+ # TRY NOT TO MODIFY: execute the game and log data.
364
+ next_obs, rewards, terminations, truncations, infos = envs.step(actions)
365
+
366
+ # TRY NOT TO MODIFY: record rewards for plotting purposes
367
+ if "final_info" in infos:
368
+ for info in infos["final_info"]:
369
+ # Skip the envs that are not done
370
+ if "episode" not in info:
371
+ continue
372
+ print(f"global_step={global_step}, episodic_return={info['episode']['r']}")
373
+ writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
374
+ writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
375
+ break
376
+
377
+ # TRY NOT TO MODIFY: save data to replay buffer; handle `final_observation`
378
+ real_next_obs = next_obs.copy()
379
+ for idx, trunc in enumerate(truncations):
380
+ if trunc:
381
+ real_next_obs[idx] = infos["final_observation"][idx]
382
+ rb.add(obs, real_next_obs, actions, rewards, terminations, infos)
383
+
384
+ # TRY NOT TO MODIFY: CRUCIAL step easy to overlook
385
+ obs = next_obs
386
+
387
+ # ALGO LOGIC: training.
388
+ if global_step > args.learning_starts:
389
+ if global_step % args.train_frequency == 0:
390
+ data, idxs, weights = rb.sample(args.batch_size)
391
+
392
+ # Combine observations for a single network call
393
+ combined_obs = torch.cat([data.observations, data.next_observations], dim=0)
394
+ combined_dist = q_network.get_distribution(combined_obs)
395
+ dist, next_dist = combined_dist.split(len(data.observations), dim=0)
396
+
397
+ with torch.no_grad():
398
+ next_q_values = (torch.softmax(next_dist, dim=2) * q_network.atoms).sum(2)
399
+ next_actions = torch.argmax(next_q_values, 1)
400
+ target_next_dist = target_network.get_distribution(data.next_observations)
401
+ next_pmfs = torch.softmax(target_next_dist[torch.arange(len(data.next_observations)), next_actions], dim=1)
402
+ next_atoms = data.rewards + args.gamma * target_network.atoms * (1 - data.dones.float())
403
+ # projection
404
+ delta_z = target_network.atoms[1] - target_network.atoms[0]
405
+ tz = next_atoms.clamp(args.v_min, args.v_max)
406
+
407
+ b = (tz - args.v_min) / delta_z
408
+ l = b.floor().clamp(0, args.n_atoms - 1)
409
+ u = b.ceil().clamp(0, args.n_atoms - 1)
410
+ # (l == u).float() handles the case where bj is exactly an integer
411
+ # example bj = 1, then the upper ceiling should be uj= 2, and lj= 1
412
+ d_m_l = (u + (l == u).float() - b) * next_pmfs
413
+ d_m_u = (b - l) * next_pmfs
414
+ target_pmfs = torch.zeros_like(next_pmfs)
415
+ for i in range(target_pmfs.size(0)):
416
+ target_pmfs[i].index_add_(0, l[i].long(), d_m_l[i])
417
+ target_pmfs[i].index_add_(0, u[i].long(), d_m_u[i])
418
+
419
+ old_pmfs = torch.softmax(dist[torch.arange(len(data.observations)), data.actions.flatten()], dim=1)
420
+
421
+ expected_old_q = (old_pmfs.detach() * q_network.atoms).sum(-1)
422
+ expected_target_q = (target_pmfs * target_network.atoms).sum(-1)
423
+ td_error = expected_target_q - expected_old_q
424
+ rb.update_priorities(idxs, td_error.abs().cpu().numpy())
425
+ rb.update_beta(global_step / args.total_timesteps)
426
+
427
+ loss = (weights * -(target_pmfs * old_pmfs.clamp(min=1e-5, max=1 - 1e-5).log())).sum(-1).mean()
428
+
429
+ if global_step % 100 == 0:
430
+ writer.add_scalar("losses/loss", loss.item(), global_step)
431
+ writer.add_scalar("losses/q_values", expected_old_q.mean().item(), global_step)
432
+ print("SPS:", int(global_step / (time.time() - start_time)))
433
+ writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)
434
+
435
+ # optimize the model
436
+ optimizer.zero_grad()
437
+ loss.backward()
438
+ optimizer.step()
439
+ q_network.reset_noise()
440
+
441
+ # update target network
442
+ if global_step % args.target_network_frequency == 0:
443
+ target_network.load_state_dict(q_network.state_dict())
444
+
445
+ if args.save_model:
446
+ model_path = f"runs/{run_name}/{args.exp_name}.cleanrl_model"
447
+ model_data = {
448
+ "model_weights": q_network.state_dict(),
449
+ "args": vars(args),
450
+ }
451
+ torch.save(model_data, model_path)
452
+ print(f"model saved to {model_path}")
453
+ from cleanrl_utils.evals.rainbow_eval import evaluate
454
+
455
+ episodic_returns = evaluate(
456
+ model_path,
457
+ make_env,
458
+ args.env_id,
459
+ eval_episodes=10,
460
+ run_name=f"{run_name}-eval",
461
+ Model=QNetwork,
462
+ device=device,
463
+ )
464
+ for idx, episodic_return in enumerate(episodic_returns):
465
+ writer.add_scalar("eval/episodic_return", episodic_return, idx)
466
+
467
+ if args.upload_model:
468
+ from cleanrl_utils.huggingface import push_to_hub
469
+
470
+ repo_name = f"{args.env_id}-{args.exp_name}-seed{args.seed}"
471
+ repo_id = f"{args.hf_entity}/{repo_name}" if args.hf_entity else repo_name
472
+ push_to_hub(args, episodic_returns, repo_id, "RAINBOW", f"runs/{run_name}", f"videos/{run_name}-eval")
473
+
474
+ envs.close()
475
+ writer.close()
replay.mp4 ADDED
Binary file (22 kB). View file
 
videos/BreakoutNoFrameskip-v4__rainbow_atari__1__1700431636-eval/rl-video-episode-0.mp4 ADDED
Binary file (22 kB). View file
 
videos/BreakoutNoFrameskip-v4__rainbow_atari__1__1700431636-eval/rl-video-episode-1.mp4 ADDED
Binary file (69 kB). View file
 
videos/BreakoutNoFrameskip-v4__rainbow_atari__1__1700431636-eval/rl-video-episode-8.mp4 ADDED
Binary file (22 kB). View file