# Hugging Face Hub upload metadata: qgallouedec (HF staff), commit 61c9b91 — "pushing model".
# Rainbow agent for Atari built on CleanRL's C51 Atari implementation; docs and experiment results for that base algorithm can be found at https://docs.cleanrl.dev/rl-algorithms/c51/#c51_ataripy
import argparse
import math
import os
import random
import time
from collections import deque
from distutils.util import strtobool
from types import SimpleNamespace
import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import torch.optim as optim
from stable_baselines3.common.atari_wrappers import ClipRewardEnv, EpisodicLifeEnv, FireResetEnv, MaxAndSkipEnv, NoopResetEnv
from torch.utils.tensorboard import SummaryWriter
def parse_args():
    """Parse the command-line hyperparameters of the Rainbow Atari agent.

    Boolean flags accept ``y/yes/t/true/on/1`` or ``n/no/f/false/off/0``; giving
    the bare flag (no value) sets it to True.

    Returns:
        argparse.Namespace: one attribute per flag (dashes become underscores).

    Exits through ``parser.error`` (SystemExit, status 2) when ``--num-envs`` is
    not 1, because vectorized envs are not supported.
    """

    def str2bool(value):
        # Same accepted spellings as distutils.util.strtobool, which was removed in Python 3.12.
        lowered = value.lower()
        if lowered in ("y", "yes", "t", "true", "on", "1"):
            return True
        if lowered in ("n", "no", "f", "false", "off", "0"):
            return False
        raise ValueError(f"invalid truth value {value!r}")

    # fmt: off
    parser = argparse.ArgumentParser()
    # splitext drops only the ".py" suffix; str.rstrip(".py") would also eat trailing
    # "p"/"y" characters of the stem (e.g. "happy.py" -> "ha").
    parser.add_argument("--exp-name", type=str, default=os.path.splitext(os.path.basename(__file__))[0],
        help="the name of this experiment")
    parser.add_argument("--seed", type=int, default=1,
        help="seed of the experiment")
    parser.add_argument("--torch-deterministic", type=str2bool, default=True, nargs="?", const=True,
        help="whether to set `torch.backends.cudnn.deterministic=True` (pass `False` to disable)")
    parser.add_argument("--cuda", type=str2bool, default=True, nargs="?", const=True,
        help="if toggled, cuda will be enabled by default")
    parser.add_argument("--track", type=str2bool, default=False, nargs="?", const=True,
        help="if toggled, this experiment will be tracked with Weights and Biases")
    parser.add_argument("--wandb-project-name", type=str, default="cleanRL",
        help="the wandb's project name")
    parser.add_argument("--wandb-entity", type=str, default=None,
        help="the entity (team) of wandb's project")
    parser.add_argument("--capture-video", type=str2bool, default=False, nargs="?", const=True,
        help="whether to capture videos of the agent performances (check out `videos` folder)")
    parser.add_argument("--save-model", type=str2bool, default=False, nargs="?", const=True,
        help="whether to save model into the `runs/{run_name}` folder")
    parser.add_argument("--upload-model", type=str2bool, default=False, nargs="?", const=True,
        help="whether to upload the saved model to huggingface")
    parser.add_argument("--hf-entity", type=str, default="",
        help="the user or org name of the model repository from the Hugging Face Hub")

    # Algorithm specific arguments
    parser.add_argument("--env-id", type=str, default="BreakoutNoFrameskip-v4",
        help="the id of the environment")
    parser.add_argument("--total-timesteps", type=int, default=10000000,
        help="total timesteps of the experiments")
    parser.add_argument("--learning-rate", type=float, default=2.5e-4,
        help="the learning rate of the optimizer")
    parser.add_argument("--num-envs", type=int, default=1,
        help="the number of parallel game environments")
    parser.add_argument("--n-atoms", type=int, default=51,
        help="the number of atoms")
    parser.add_argument("--v-min", type=float, default=-10,
        help="the return lower bound")
    parser.add_argument("--v-max", type=float, default=10,
        help="the return upper bound")
    parser.add_argument("--buffer-size", type=int, default=1000000,
        help="the replay memory buffer size")
    parser.add_argument("--gamma", type=float, default=0.99,
        help="the discount factor gamma")
    parser.add_argument("--target-network-frequency", type=int, default=10000,
        help="the timesteps it takes to update the target network")
    parser.add_argument("--batch-size", type=int, default=32,
        help="the batch size of sample from the replay memory")
    parser.add_argument("--learning-starts", type=int, default=80000,
        help="timestep to start learning")
    parser.add_argument("--train-frequency", type=int, default=4,
        help="the frequency of training")
    args = parser.parse_args()
    # fmt: on
    # Validate with parser.error rather than assert: asserts vanish under `python -O`.
    if args.num_envs != 1:
        parser.error("vectorized envs are not supported at the moment")
    return args
def make_env(env_id, seed, idx, capture_video, run_name):
    """Return a zero-argument factory that builds one preprocessed Atari environment.

    Args:
        env_id: Gymnasium environment id passed to ``gym.make``.
        seed: seed applied to the environment's action space.
        idx: position of this environment in the vector; only index 0 may record video.
        capture_video: when true, environment 0 records episodes to ``videos/{run_name}``.
        run_name: run identifier used for the video folder.

    Returns:
        A callable taking no arguments that creates and wraps the environment
        (the form consumed by ``gym.vector.SyncVectorEnv``).
    """

    def thunk():
        record_video = capture_video and idx == 0
        if not record_video:
            env = gym.make(env_id)
        else:
            env = gym.make(env_id, render_mode="rgb_array")
            env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")

        # Episode statistics first, so logged returns/lengths see the raw game.
        env = gym.wrappers.RecordEpisodeStatistics(env)
        env = MaxAndSkipEnv(NoopResetEnv(env, noop_max=30), skip=4)
        env = EpisodicLifeEnv(env)
        if "FIRE" in env.unwrapped.get_action_meanings():
            env = FireResetEnv(env)
        env = ClipRewardEnv(env)

        # 84x84 grayscale frames, stacked 4 deep.
        resized = gym.wrappers.ResizeObservation(env, (84, 84))
        env = gym.wrappers.FrameStack(gym.wrappers.GrayScaleObservation(resized), 4)

        env.action_space.seed(seed)
        return env

    return thunk
class SumTree:
    """Binary sum tree over ``capacity`` leaf priorities.

    Leaves live at ``tree[capacity:2 * capacity]``; every internal node ``i``
    (``1 <= i < capacity``) holds ``tree[2 * i] + tree[2 * i + 1]``, so the root
    ``tree[1]`` is the sum of all priorities. Slot 0 is unused.
    """

    def __init__(self, capacity):
        self.capacity = capacity  # number of leaves
        self.tree = [0] * (2 * capacity)  # heap-ordered binary tree, root at index 1
        self.max_priority = 1.0  # priority given to entries added via update(index)

    def update(self, index, priority=None):
        """Set leaf ``index`` to ``priority`` (default: the largest priority seen so far)."""
        if priority is None:
            priority = self.max_priority
        tree_idx = index + self.capacity
        change = priority - self.tree[tree_idx]
        self.tree[tree_idx] = priority
        self._propagate(tree_idx, change)
        self.max_priority = max(self.max_priority, priority)

    def _propagate(self, idx, change):
        """Add ``change`` to every ancestor of node ``idx`` up to the root."""
        parent = idx // 2
        while parent != 0:
            self.tree[parent] += change
            parent = parent // 2

    def total(self):
        """Return the sum of all leaf priorities (held by the root)."""
        return self.tree[1]

    def get(self, s):
        """Return the leaf index whose cumulative-priority interval contains ``s``.

        Subtrees whose sum is not positive are never entered. With at least one
        positive priority the returned leaf therefore always has positive priority,
        even when ``s`` reaches or exceeds the internal sums because of
        floating-point drift. Without this guard the descent can land on an empty,
        zero-priority leaf, which produces infinite importance-sampling weights.
        """
        idx = 1
        while idx < self.capacity:
            left = 2 * idx
            right = left + 1
            if self.tree[right] <= 0 or (self.tree[left] > 0 and self.tree[left] >= s):
                idx = left
            else:
                s -= self.tree[left]
                idx = right
        return idx - self.capacity


class PrioritizedReplayBuffer:
    """Proportional prioritized replay of n-step transitions from a single environment.

    Args:
        size: capacity in stored transitions; the oldest entries are overwritten.
        device: torch device that sampled batches are moved to.
        alpha: exponent applied to |error| in ``update_priorities``.
        beta_0: initial importance-sampling exponent, interpolated to 1 by ``update_beta``.
        n_step: number of environment steps folded into one stored transition.
        gamma: discount used when summing the n-step rewards.
    """

    def __init__(self, size, device, alpha=0.5, beta_0=0.4, n_step=3, gamma=0.99):
        self.size = size
        self.device = device
        self.alpha = alpha
        self.beta_0 = beta_0
        self.update_beta(0.0)
        self.n_step = n_step
        self.gamma = gamma
        self.next_index = 0
        self.sum_tree = SumTree(size)
        self.observations = np.zeros((self.size, 4, 84, 84), dtype=np.uint8)
        self.next_observations = np.zeros((self.size, 4, 84, 84), dtype=np.uint8)
        self.actions = np.zeros((self.size, 1), dtype=np.int64)
        self.rewards = np.zeros((self.size, 1), dtype=np.float32)
        self.dones = np.zeros((self.size, 1), dtype=bool)
        # Sliding window of the most recent single-step transitions.
        self.n_step_buffer = deque(maxlen=n_step)

    def add(self, obs, next_obs, actions, rewards, dones, infos):
        """Record one environment step.

        Only entry 0 along the leading (environment) axis of each array is used.
        A full ``n_step`` window is written as one transition.

        - On termination (``dones[0]`` true), every queued step is flushed as a
          shorter-horizon transition with ``done=True``, so the final steps of the
          episode are not lost.
        - If the environment was reset without terminating (e.g. truncation), the
          queue is discarded so returns never span a reset.
        """
        self.n_step_buffer.append((obs[0], next_obs[0], actions[0], rewards[0], dones[0], infos))
        if dones[0]:
            # Terminal: nothing is bootstrapped, so every suffix of the queue is a valid transition.
            while self.n_step_buffer:
                self._store_window()
                self.n_step_buffer.popleft()
            return
        if len(self.n_step_buffer) == self.n_step:
            self._store_window()
        if self._env_was_reset(infos):
            # Truncated: shorter windows would need a gamma**k (k < n_step) bootstrap
            # discount that this buffer does not record, so they are dropped.
            self.n_step_buffer.clear()

    @staticmethod
    def _env_was_reset(infos):
        """Return True when the step's info reports an auto-reset of the environment."""
        # NOTE(review): relies on gymnasium vector envs adding "final_observation" to infos
        # whenever a sub-env finishes and is reset; with the single env supported here its
        # presence means env 0 was reset. Confirm if the vector-env API changes.
        return isinstance(infos, dict) and "final_observation" in infos

    def _store_window(self):
        """Write the transition spanning the whole queue (oldest step first) at max priority."""
        n_step_return = sum(step[3] * self.gamma**i for i, step in enumerate(self.n_step_buffer))
        obs, _, action, _, _, _ = self.n_step_buffer[0]
        _, next_obs, _, _, done, _ = self.n_step_buffer[-1]
        slot = self.next_index
        self.observations[slot] = obs
        self.next_observations[slot] = next_obs
        self.actions[slot] = action
        self.rewards[slot] = n_step_return
        self.dones[slot] = done
        # New transitions get the maximum priority seen so far.
        self.sum_tree.update(slot)
        self.next_index = (slot + 1) % self.size

    def sample(self, batch_size):
        """Draw ``batch_size`` transitions with stratified proportional sampling.

        Returns:
            (data, idxs, weights): ``data`` is a SimpleNamespace of tensors on
            ``device``; ``idxs`` are the sampled slots (for ``update_priorities``);
            ``weights`` is a (batch_size, 1) tensor of importance-sampling weights
            normalised so the largest is 1.

        Raises:
            ValueError: if the buffer holds no prioritized transitions.
        """
        total = self.sum_tree.total()
        if total <= 0:
            raise ValueError("cannot sample from an empty PrioritizedReplayBuffer")
        segment = total / batch_size
        idxs = []
        priorities = []
        for i in range(batch_size):
            # One draw per equal-mass segment of the cumulative priority.
            s = random.uniform(segment * i, segment * (i + 1))
            idx = self.sum_tree.get(s)
            idxs.append(idx)
            priorities.append(self.sum_tree.tree[idx + self.size])  # leaf node of slot idx
        priorities = torch.tensor(priorities, dtype=torch.float32, device=self.device).unsqueeze(1)
        sampling_probabilities = priorities / total
        weights = (self.size * sampling_probabilities) ** (-self.beta)
        weights /= weights.max()  # Normalize for stability
        data = SimpleNamespace(
            observations=torch.from_numpy(self.observations[idxs]).to(self.device),
            next_observations=torch.from_numpy(self.next_observations[idxs]).to(self.device),
            actions=torch.from_numpy(self.actions[idxs]).to(self.device),
            rewards=torch.from_numpy(self.rewards[idxs]).to(self.device),
            dones=torch.from_numpy(self.dones[idxs]).to(self.device),
        )
        return data, idxs, weights

    def update_priorities(self, idxs, errors):
        """Set each slot's priority to ``(|error| + 1e-5) ** alpha``."""
        for idx, error in zip(idxs, errors):
            priority = (abs(error) + 1e-5) ** self.alpha
            self.sum_tree.update(idx, priority)

    def update_beta(self, fraction):
        """Linearly interpolate beta from ``beta_0`` (fraction 0) to 1 (fraction 1)."""
        self.beta = (1.0 - self.beta_0) * fraction + self.beta_0
class NoisyLinear(nn.Module):
    """Linear layer with factorised Gaussian parameter noise (NoisyNet).

    In training mode the effective parameters are ``mu + sigma * epsilon``; in
    eval mode only the ``mu`` parameters are used. ``epsilon`` lives in
    non-trainable buffers and is redrawn by ``reset_noise``.

    Args:
        in_features: size of each input sample.
        out_features: size of each output sample.
        std_init: scale of the initial sigmas; weight sigmas start at
            ``std_init / sqrt(in_features)`` and bias sigmas at
            ``std_init / sqrt(out_features)``.
    """

    def __init__(self, in_features, out_features, std_init=0.1):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.std_init = std_init
        weight_shape = (out_features, in_features)
        # Registration order is kept stable: it fixes state_dict keys and parameter order.
        self.weight_mu = nn.Parameter(torch.Tensor(*weight_shape))
        self.weight_sigma = nn.Parameter(torch.Tensor(*weight_shape))
        self.register_buffer("weight_epsilon", torch.Tensor(*weight_shape))
        self.bias_mu = nn.Parameter(torch.Tensor(out_features))
        self.bias_sigma = nn.Parameter(torch.Tensor(out_features))
        self.register_buffer("bias_epsilon", torch.Tensor(out_features))
        self.reset_parameters()
        self.reset_noise()

    def reset_parameters(self):
        """Initialise the means (Kaiming-uniform weights, zero bias) and constant sigmas."""
        weight_sigma_value = self.std_init / math.sqrt(self.in_features)
        bias_sigma_value = self.std_init / math.sqrt(self.out_features)
        init.kaiming_uniform_(self.weight_mu, a=math.sqrt(5))
        init.constant_(self.weight_sigma, weight_sigma_value)
        init.constant_(self.bias_mu, 0)
        init.constant_(self.bias_sigma, bias_sigma_value)

    def reset_noise(self):
        """Redraw factorised noise: weight noise is the outer product of output and input noise."""
        noise_in, noise_out = self._scale_noise(self.in_features), self._scale_noise(self.out_features)
        self.weight_epsilon.copy_(torch.outer(noise_out, noise_in))
        self.bias_epsilon.copy_(noise_out)

    def _scale_noise(self, size):
        """Sample ``size`` standard normals and map them through f(x) = sign(x) * sqrt(|x|)."""
        noise = torch.randn(size, device=self.weight_mu.device)
        return torch.sign(noise) * torch.sqrt(torch.abs(noise))

    def forward(self, input):
        if not self.training:
            return F.linear(input, self.weight_mu, self.bias_mu)
        noisy_weight = self.weight_mu + self.weight_sigma * self.weight_epsilon
        noisy_bias = self.bias_mu + self.bias_sigma * self.bias_epsilon
        return F.linear(input, noisy_weight, noisy_bias)
# ALGO LOGIC: initialize agent here:
class QNetwork(nn.Module):
    """Dueling distributional Q-network with NoisyLinear heads.

    ``get_distribution`` returns per-action logits over ``n_atoms`` support points
    stored in the ``atoms`` buffer (``linspace(v_min, v_max, n_atoms)``); a softmax
    over the last axis gives each action's return distribution.

    Args:
        env: vector env whose ``single_action_space.n`` sets the number of actions.
        n_atoms: number of support points.
        v_min: lowest support value.
        v_max: highest support value.
    """

    def __init__(self, env, n_atoms=101, v_min=-100, v_max=100):
        super().__init__()
        self.env = env
        self.n_atoms = n_atoms
        self.register_buffer("atoms", torch.linspace(v_min, v_max, steps=n_atoms))
        self.n = env.single_action_space.n

        feature_dim, hidden_dim = 3136, 512  # 3136 = 64 * 7 * 7 for 4x84x84 inputs
        # Conv torso; ReLU/Flatten positions are kept so state_dict keys stay the same.
        self.shared_layers = nn.Sequential(
            nn.Conv2d(4, 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU(),
            nn.Flatten(),
        )

        def noisy_head(out_dim):
            return nn.Sequential(NoisyLinear(feature_dim, hidden_dim), nn.ReLU(), NoisyLinear(hidden_dim, out_dim))

        # Value head before advantage head: keeps parameter creation (and RNG use) order.
        self.value_stream = noisy_head(n_atoms)
        self.advantage_stream = noisy_head(self.n * n_atoms)

    def reset_noise(self):
        """Resample the noise of every NoisyLinear submodule."""
        noisy_layers = (m for m in self.modules() if isinstance(m, NoisyLinear))
        for layer in noisy_layers:
            layer.reset_noise()

    def get_action(self, obs):
        """Return, per batch row, the action with the highest expected return."""
        probs = torch.softmax(self.get_distribution(obs), dim=2)
        expected_returns = torch.sum(probs * self.atoms, dim=2)
        return expected_returns.argmax(dim=1)

    def get_distribution(self, obs):
        """Return logits of shape (batch, n_actions, n_atoms); pixels are scaled by 1/255."""
        features = self.shared_layers(obs / 255.0)
        value = self.value_stream(features).view(-1, 1, self.n_atoms)
        advantages = self.advantage_stream(features).view(-1, self.n, self.n_atoms)
        # Dueling aggregation: advantages are centred across actions.
        centered = advantages - advantages.mean(dim=1, keepdim=True)
        return value + centered
if __name__ == "__main__":
    import stable_baselines3 as sb3

    if sb3.__version__ < "2.0":
        raise ValueError(
            """Ongoing migration: run the following command to install the new dependencies:
poetry run pip install "stable_baselines3==2.0.0a1" "gymnasium[atari,accept-rom-license]==0.28.1" "ale-py==0.8.1"
"""
        )
    args = parse_args()
    run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
    if args.track:
        import wandb

        wandb.init(
            project=args.wandb_project_name,
            entity=args.wandb_entity,
            sync_tensorboard=True,
            config=vars(args),
            name=run_name,
            monitor_gym=True,
            save_code=True,
        )
    writer = SummaryWriter(f"runs/{run_name}")
    writer.add_text(
        "hyperparameters",
        "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
    )

    # TRY NOT TO MODIFY: seeding
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = args.torch_deterministic

    device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")

    # env setup
    envs = gym.vector.SyncVectorEnv(
        [make_env(args.env_id, args.seed + i, i, args.capture_video, run_name) for i in range(args.num_envs)]
    )
    assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported"

    q_network = QNetwork(envs, n_atoms=args.n_atoms, v_min=args.v_min, v_max=args.v_max).to(device)
    optimizer = optim.Adam(q_network.parameters(), lr=args.learning_rate, eps=0.01 / args.batch_size)
    target_network = QNetwork(envs, n_atoms=args.n_atoms, v_min=args.v_min, v_max=args.v_max).to(device)
    target_network.load_state_dict(q_network.state_dict())

    # The buffer sums n_step discounted rewards per transition; give it the configured gamma
    # so that return matches the bootstrap discount below.
    rb = PrioritizedReplayBuffer(args.buffer_size, device, gamma=args.gamma)
    # The bootstrapped state is rb.n_step environment steps after the stored observation,
    # so its value must be discounted by gamma ** n_step rather than by a single gamma.
    bootstrap_discount = rb.gamma**rb.n_step
    start_time = time.time()

    # TRY NOT TO MODIFY: start the game
    obs, _ = envs.reset(seed=args.seed)
    for global_step in range(args.total_timesteps):
        # ALGO LOGIC: put action logic here
        # Exploration comes from the NoisyLinear noise; no gradients are needed to act.
        with torch.no_grad():
            actions = q_network.get_action(torch.Tensor(obs).to(device))
        actions = actions.cpu().numpy()

        # TRY NOT TO MODIFY: execute the game and log data.
        next_obs, rewards, terminations, truncations, infos = envs.step(actions)

        # TRY NOT TO MODIFY: record rewards for plotting purposes
        if "final_info" in infos:
            for info in infos["final_info"]:
                # Skip the envs that are not done (their entry may be None)
                if not info or "episode" not in info:
                    continue
                print(f"global_step={global_step}, episodic_return={info['episode']['r']}")
                writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
                writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
                break

        # TRY NOT TO MODIFY: save data to reply buffer; handle `final_observation`
        real_next_obs = next_obs.copy()
        for idx, trunc in enumerate(truncations):
            if trunc:
                real_next_obs[idx] = infos["final_observation"][idx]
        rb.add(obs, real_next_obs, actions, rewards, terminations, infos)

        # TRY NOT TO MODIFY: CRUCIAL step easy to overlook
        obs = next_obs

        # ALGO LOGIC: training.
        if global_step > args.learning_starts:
            if global_step % args.train_frequency == 0:
                data, idxs, weights = rb.sample(args.batch_size)

                # Combine observations for a single network call
                combined_obs = torch.cat([data.observations, data.next_observations], dim=0)
                combined_dist = q_network.get_distribution(combined_obs)
                dist, next_dist = combined_dist.split(len(data.observations), dim=0)

                with torch.no_grad():
                    # Double DQN: the online network picks the next action, the target network evaluates it.
                    next_q_values = (torch.softmax(next_dist, dim=2) * q_network.atoms).sum(2)
                    next_actions = torch.argmax(next_q_values, 1)
                    target_next_dist = target_network.get_distribution(data.next_observations)
                    next_pmfs = torch.softmax(target_next_dist[torch.arange(len(data.next_observations)), next_actions], dim=1)
                    next_atoms = data.rewards + bootstrap_discount * target_network.atoms * (1 - data.dones.float())

                    # projection: split each shifted atom's probability between its lower (l)
                    # and upper (u) neighbouring support points in proportion to distance
                    delta_z = target_network.atoms[1] - target_network.atoms[0]
                    tz = next_atoms.clamp(args.v_min, args.v_max)
                    b = (tz - args.v_min) / delta_z
                    l = b.floor().clamp(0, args.n_atoms - 1)
                    u = b.ceil().clamp(0, args.n_atoms - 1)
                    # When b is exactly an integer, l == u; adding (l == u) gives all of the
                    # mass to l (d_m_u is then 0) instead of losing it.
                    d_m_l = (u + (l == u).float() - b) * next_pmfs
                    d_m_u = (b - l) * next_pmfs
                    target_pmfs = torch.zeros_like(next_pmfs)
                    # Scatter every row at once by offsetting atom indices into the flattened tensor.
                    row_offsets = (
                        torch.arange(target_pmfs.size(0), device=target_pmfs.device) * target_pmfs.size(1)
                    ).unsqueeze(1)
                    flat_target = target_pmfs.view(-1)
                    flat_target.index_add_(0, (l.long() + row_offsets).reshape(-1), d_m_l.reshape(-1))
                    flat_target.index_add_(0, (u.long() + row_offsets).reshape(-1), d_m_u.reshape(-1))

                old_pmfs = torch.softmax(dist[torch.arange(len(data.observations)), data.actions.flatten()], dim=1)

                # Priorities come from the difference in expected Q between target and prediction.
                expected_old_q = (old_pmfs.detach() * q_network.atoms).sum(-1)
                expected_target_q = (target_pmfs * target_network.atoms).sum(-1)
                td_error = expected_target_q - expected_old_q
                rb.update_priorities(idxs, td_error.abs().cpu().numpy())
                rb.update_beta(global_step / args.total_timesteps)

                # Importance-weighted cross-entropy between projected target and predicted distribution.
                loss = (weights * -(target_pmfs * old_pmfs.clamp(min=1e-5, max=1 - 1e-5).log())).sum(-1).mean()

                if global_step % 100 == 0:
                    writer.add_scalar("losses/loss", loss.item(), global_step)
                    writer.add_scalar("losses/q_values", expected_old_q.mean().item(), global_step)
                    print("SPS:", int(global_step / (time.time() - start_time)))
                    writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)

                # optimize the model
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                # Fresh noise for the next acting/learning steps. The target network gets its
                # own independent sample; otherwise it keeps the noise copied at the last sync.
                q_network.reset_noise()
                target_network.reset_noise()

            # update target network
            if global_step % args.target_network_frequency == 0:
                target_network.load_state_dict(q_network.state_dict())

    if args.save_model:
        model_path = f"runs/{run_name}/{args.exp_name}.cleanrl_model"
        model_data = {
            "model_weights": q_network.state_dict(),
            "args": vars(args),
        }
        torch.save(model_data, model_path)
        print(f"model saved to {model_path}")
        from cleanrl_utils.evals.rainbow_eval import evaluate

        episodic_returns = evaluate(
            model_path,
            make_env,
            args.env_id,
            eval_episodes=10,
            run_name=f"{run_name}-eval",
            Model=QNetwork,
            device=device,
        )
        for idx, episodic_return in enumerate(episodic_returns):
            writer.add_scalar("eval/episodic_return", episodic_return, idx)

        if args.upload_model:
            from cleanrl_utils.huggingface import push_to_hub

            repo_name = f"{args.env_id}-{args.exp_name}-seed{args.seed}"
            repo_id = f"{args.hf_entity}/{repo_name}" if args.hf_entity else repo_name
            push_to_hub(args, episodic_returns, repo_id, "RAINBOW", f"runs/{run_name}", f"videos/{run_name}-eval")

    envs.close()
    writer.close()