import argparse
import os
import random
import time
from distutils.util import strtobool
from functools import partial
from typing import Sequence

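# Cap XLA's GPU memory preallocation at 70% of the device so EnvPool and other CUDA users can coexist.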
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.7"

import envpool
import flax
import flax.linen as nn
import gym
import jax
import jax.numpy as jnp
import numpy as np
import optax
from flax.linen.initializers import constant, orthogonal
from flax.training.train_state import TrainState
from torch.utils.tensorboard import SummaryWriter


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--exp-name", type=str, default=os.path.splitext(os.path.basename(__file__))[0],
        help="the name of this experiment")
    parser.add_argument("--seed", type=int, default=1,
        help="seed of the experiment")
    parser.add_argument("--torch-deterministic", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
        help="if toggled, `torch.backends.cudnn.deterministic=False`")
    parser.add_argument("--cuda", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
        help="if toggled, cuda will be enabled by default")
    parser.add_argument("--track", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
        help="if toggled, this experiment will be tracked with Weights and Biases")
    parser.add_argument("--wandb-project-name", type=str, default="cleanRL",
        help="the wandb's project name")
    parser.add_argument("--wandb-entity", type=str, default=None,
        help="the entity (team) of wandb's project")
    parser.add_argument("--capture-video", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
        help="whether to capture videos of the agent performances (check out `videos` folder)")
    parser.add_argument("--save-model", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
        help="whether to save model into the `runs/{run_name}` folder")
    parser.add_argument("--upload-model", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
        help="whether to upload the saved model to huggingface")
    parser.add_argument("--hf-entity", type=str, default="",
        help="the user or org name of the model repository from the Hugging Face Hub")

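    # Algorithm specific arguments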
    parser.add_argument("--env-id", type=str, default="Pong-v5",
        help="the id of the environment")
    parser.add_argument("--total-timesteps", type=int, default=10000000,
        help="total timesteps of the experiments")
    parser.add_argument("--learning-rate", type=float, default=2.5e-4,
        help="the learning rate of the optimizer")
    parser.add_argument("--num-envs", type=int, default=8,
        help="the number of parallel game environments")
    parser.add_argument("--num-steps", type=int, default=128,
        help="the number of steps to run in each environment per policy rollout")
    parser.add_argument("--anneal-lr", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
        help="Toggle learning rate annealing for policy and value networks")
    parser.add_argument("--gamma", type=float, default=0.99,
        help="the discount factor gamma")
    parser.add_argument("--gae-lambda", type=float, default=0.95,
        help="the lambda for the generalized advantage estimation")
    parser.add_argument("--num-minibatches", type=int, default=4,
        help="the number of mini-batches")
    parser.add_argument("--update-epochs", type=int, default=4,
        help="the K epochs to update the policy")
    parser.add_argument("--norm-adv", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
        help="Toggles advantages normalization")
    parser.add_argument("--clip-coef", type=float, default=0.1,
        help="the surrogate clipping coefficient")
    parser.add_argument("--ent-coef", type=float, default=0.01,
        help="coefficient of the entropy")
    parser.add_argument("--vf-coef", type=float, default=0.5,
        help="coefficient of the value function")
    parser.add_argument("--max-grad-norm", type=float, default=0.5,
        help="the maximum norm for the gradient clipping")
    parser.add_argument("--target-kl", type=float, default=None,
        help="the target KL divergence threshold")
    args = parser.parse_args()
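    # sizes derived at runtime: one policy update consumes a batch of num_envs * num_steps transitions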
    args.batch_size = int(args.num_envs * args.num_steps)
    args.minibatch_size = int(args.batch_size // args.num_minibatches)
    args.num_updates = args.total_timesteps // args.batch_size
    return args


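# EnvPool creates a single, already-vectorized Atari environment (with episodic life and
# reward clipping); the extra attributes below make it look like a gym vector env.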
def make_env(env_id, seed, num_envs):
    def thunk():
        envs = envpool.make(
            env_id,
            env_type="gym",
            num_envs=num_envs,
            episodic_life=True,
            reward_clip=True,
            seed=seed,
        )
        envs.num_envs = num_envs
        envs.single_action_space = envs.action_space
        envs.single_observation_space = envs.observation_space
        envs.is_vector_env = True
        return envs

    return thunk


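# Shared Nature-CNN torso (Mnih et al., 2015): takes NCHW uint8 frames, converts them to
# NHWC, scales to [0, 1], and produces a 512-dimensional feature vector.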
class Network(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = jnp.transpose(x, (0, 2, 3, 1))
        x = x / 255.0
        x = nn.Conv(
            32,
            kernel_size=(8, 8),
            strides=(4, 4),
            padding="VALID",
            kernel_init=orthogonal(np.sqrt(2)),
            bias_init=constant(0.0),
        )(x)
        x = nn.relu(x)
        x = nn.Conv(
            64,
            kernel_size=(4, 4),
            strides=(2, 2),
            padding="VALID",
            kernel_init=orthogonal(np.sqrt(2)),
            bias_init=constant(0.0),
        )(x)
        x = nn.relu(x)
        x = nn.Conv(
            64,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding="VALID",
            kernel_init=orthogonal(np.sqrt(2)),
            bias_init=constant(0.0),
        )(x)
        x = nn.relu(x)
        x = x.reshape((x.shape[0], -1))
        x = nn.Dense(512, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0))(x)
        x = nn.relu(x)
        return x


class Critic(nn.Module):
    @nn.compact
    def __call__(self, x):
        return nn.Dense(1, kernel_init=orthogonal(1), bias_init=constant(0.0))(x)


class Actor(nn.Module):
    action_dim: Sequence[int]

    @nn.compact
    def __call__(self, x):
        return nn.Dense(self.action_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0))(x)


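# The parameter pytrees of the three modules are grouped into one dataclass so a single
# TrainState/optimizer can update them jointly.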
@flax.struct.dataclass
class AgentParams:
    network_params: flax.core.FrozenDict
    actor_params: flax.core.FrozenDict
    critic_params: flax.core.FrozenDict


@flax.struct.dataclass
class Storage:
    obs: jnp.ndarray
    actions: jnp.ndarray
    logprobs: jnp.ndarray
    dones: jnp.ndarray
    values: jnp.ndarray
    advantages: jnp.ndarray
    returns: jnp.ndarray
    rewards: jnp.ndarray


@flax.struct.dataclass
class EpisodeStatistics:
    episode_returns: jnp.ndarray
    episode_lengths: jnp.ndarray
    returned_episode_returns: jnp.ndarray
    returned_episode_lengths: jnp.ndarray


if __name__ == "__main__":
    args = parse_args()
    run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
    if args.track:
        import wandb

        wandb.init(
            project=args.wandb_project_name,
            entity=args.wandb_entity,
            sync_tensorboard=True,
            config=vars(args),
            name=run_name,
            monitor_gym=True,
            save_code=True,
        )
    writer = SummaryWriter(f"runs/{run_name}")
    writer.add_text(
        "hyperparameters",
        "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
    )

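    # seed python, numpy, and JAX so runs are reproducible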
    random.seed(args.seed)
    np.random.seed(args.seed)
    key = jax.random.PRNGKey(args.seed)
    key, network_key, actor_key, critic_key = jax.random.split(key, 4)

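    # env setup: one EnvPool vector env plus on-device episode statistics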
    envs = make_env(args.env_id, args.seed, args.num_envs)()
    episode_stats = EpisodeStatistics(
        episode_returns=jnp.zeros(args.num_envs, dtype=jnp.float32),
        episode_lengths=jnp.zeros(args.num_envs, dtype=jnp.int32),
        returned_episode_returns=jnp.zeros(args.num_envs, dtype=jnp.float32),
        returned_episode_lengths=jnp.zeros(args.num_envs, dtype=jnp.int32),
    )
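    # envs.xla() exposes EnvPool's XLA interface: a handle array plus recv/send/step functions
    # that can be called from inside jit-compiled code such as jax.lax.scan.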
    handle, recv, send, step_env = envs.xla()

    def step_env_wrapped(episode_stats, handle, action):
        handle, (next_obs, reward, next_done, info) = step_env(handle, action)
        new_episode_return = episode_stats.episode_returns + info["reward"]
        new_episode_length = episode_stats.episode_lengths + 1
        # zero the running return/length when an episode terminates or is truncated,
        # and record the finished episode's stats in the `returned_*` fields
        episode_stats = episode_stats.replace(
            episode_returns=(new_episode_return) * (1 - info["terminated"]) * (1 - info["TimeLimit.truncated"]),
            episode_lengths=(new_episode_length) * (1 - info["terminated"]) * (1 - info["TimeLimit.truncated"]),
            returned_episode_returns=jnp.where(
                info["terminated"] + info["TimeLimit.truncated"], new_episode_return, episode_stats.returned_episode_returns
            ),
            returned_episode_lengths=jnp.where(
                info["terminated"] + info["TimeLimit.truncated"], new_episode_length, episode_stats.returned_episode_lengths
            ),
        )
        return episode_stats, handle, (next_obs, reward, next_done, info)

    assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported"

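    # `count` is the number of optimizer steps taken; dividing by num_minibatches * update_epochs
    # recovers the policy-update index, so the learning rate decays linearly over num_updates.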
    def linear_schedule(count):
        frac = 1.0 - (count // (args.num_minibatches * args.update_epochs)) / args.num_updates
        return args.learning_rate * frac

    network = Network()
    actor = Actor(action_dim=envs.single_action_space.n)
    critic = Critic()
    network_params = network.init(network_key, np.array([envs.single_observation_space.sample()]))
    agent_state = TrainState.create(
        apply_fn=None,
        params=AgentParams(
            network_params,
            actor.init(actor_key, network.apply(network_params, np.array([envs.single_observation_space.sample()]))),
            critic.init(critic_key, network.apply(network_params, np.array([envs.single_observation_space.sample()]))),
        ),
        tx=optax.chain(
            optax.clip_by_global_norm(args.max_grad_norm),
            optax.inject_hyperparams(optax.adam)(
                learning_rate=linear_schedule if args.anneal_lr else args.learning_rate, eps=1e-5
            ),
        ),
    )
    network.apply = jax.jit(network.apply)
    actor.apply = jax.jit(actor.apply)
    critic.apply = jax.jit(critic.apply)

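    # Actions are sampled with the Gumbel-max trick: taking the argmax of logits plus Gumbel
    # noise is equivalent to drawing from the categorical distribution defined by the logits.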
    @jax.jit
    def get_action_and_value(
        agent_state: TrainState,
        next_obs: np.ndarray,
        key: jax.random.PRNGKey,
    ):
        """sample an action and return its log-probability, the value estimate, and the new RNG key"""
        hidden = network.apply(agent_state.params.network_params, next_obs)
        logits = actor.apply(agent_state.params.actor_params, hidden)
        key, subkey = jax.random.split(key)
        u = jax.random.uniform(subkey, shape=logits.shape)
        action = jnp.argmax(logits - jnp.log(-jnp.log(u)), axis=1)
        logprob = jax.nn.log_softmax(logits)[jnp.arange(action.shape[0]), action]
        value = critic.apply(agent_state.params.critic_params, hidden)
        return action, logprob, value.squeeze(1), key

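    # Recomputes the log-prob of stored actions, the entropy, and the values under the current
    # parameters; logits are shifted by their logsumexp so the entropy term is numerically stable.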
    @jax.jit
    def get_action_and_value2(
        params: flax.core.FrozenDict,
        x: np.ndarray,
        action: np.ndarray,
    ):
        """calculate value, logprob of supplied `action`, and entropy"""
        hidden = network.apply(params.network_params, x)
        logits = actor.apply(params.actor_params, hidden)
        logprob = jax.nn.log_softmax(logits)[jnp.arange(action.shape[0]), action]
        logits = logits - jax.scipy.special.logsumexp(logits, axis=-1, keepdims=True)
        logits = logits.clip(min=jnp.finfo(logits.dtype).min)
        p_log_p = logits * jax.nn.softmax(logits)
        entropy = -p_log_p.sum(-1)
        value = critic.apply(params.critic_params, hidden).squeeze()
        return logprob, entropy, value

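    # Generalized advantage estimation, computed backwards over the rollout:
    #   delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_{t+1}) - V(s_t)
    #   A_t     = delta_t + gamma * lambda * (1 - done_{t+1}) * A_{t+1}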
    def compute_gae_once(carry, inp, gamma, gae_lambda):
        advantages = carry
        nextdone, nextvalues, curvalues, reward = inp
        nextnonterminal = 1.0 - nextdone
        delta = reward + gamma * nextvalues * nextnonterminal - curvalues
        advantages = delta + gamma * gae_lambda * nextnonterminal * advantages
        return advantages, advantages

    compute_gae_once = partial(compute_gae_once, gamma=args.gamma, gae_lambda=args.gae_lambda)

    @jax.jit
    def compute_gae(
        agent_state: TrainState,
        next_obs: np.ndarray,
        next_done: np.ndarray,
        storage: Storage,
    ):
        next_value = critic.apply(
            agent_state.params.critic_params, network.apply(agent_state.params.network_params, next_obs)
        ).squeeze()
        advantages = jnp.zeros((args.num_envs,))
        dones = jnp.concatenate([storage.dones, next_done[None, :]], axis=0)
        values = jnp.concatenate([storage.values, next_value[None, :]], axis=0)
        _, advantages = jax.lax.scan(
            compute_gae_once, advantages, (dones[1:], values[1:], values[:-1], storage.rewards), reverse=True
        )
        storage = storage.replace(
            advantages=advantages,
            returns=advantages + storage.values,
        )
        return storage

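    # PPO loss: clipped surrogate policy objective + value MSE - entropy bonus.
    # approx_kl uses the (ratio - 1) - log(ratio) estimator, which is non-negative.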
    def ppo_loss(params, x, a, logp, mb_advantages, mb_returns):
        newlogprob, entropy, newvalue = get_action_and_value2(params, x, a)
        logratio = newlogprob - logp
        ratio = jnp.exp(logratio)
        approx_kl = ((ratio - 1) - logratio).mean()

        if args.norm_adv:
            mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8)

        # Policy loss
        pg_loss1 = -mb_advantages * ratio
        pg_loss2 = -mb_advantages * jnp.clip(ratio, 1 - args.clip_coef, 1 + args.clip_coef)
        pg_loss = jnp.maximum(pg_loss1, pg_loss2).mean()

        # Value loss
        v_loss = 0.5 * ((newvalue - mb_returns) ** 2).mean()

        entropy_loss = entropy.mean()
        loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef
        return loss, (pg_loss, v_loss, entropy_loss, jax.lax.stop_gradient(approx_kl))

    ppo_loss_grad_fn = jax.value_and_grad(ppo_loss, has_aux=True)

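    # One call performs update_epochs passes over the rollout; both the epoch loop and the
    # minibatch loop are jax.lax.scan's, and the data is reshuffled with a fresh key each epoch.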
    @jax.jit
    def update_ppo(
        agent_state: TrainState,
        storage: Storage,
        key: jax.random.PRNGKey,
    ):
        def update_epoch(carry, unused_inp):
            agent_state, key = carry
            key, subkey = jax.random.split(key)

            def flatten(x):
                # merge the (num_steps, num_envs) leading axes into one batch axis
                return x.reshape((-1,) + x.shape[2:])

            def convert_data(x: jnp.ndarray):
                # shuffle the batch and split it into num_minibatches minibatches
                x = jax.random.permutation(subkey, x)
                x = jnp.reshape(x, (args.num_minibatches, -1) + x.shape[1:])
                return x

            flatten_storage = jax.tree_map(flatten, storage)
            shuffled_storage = jax.tree_map(convert_data, flatten_storage)

            def update_minibatch(agent_state, minibatch):
                (loss, (pg_loss, v_loss, entropy_loss, approx_kl)), grads = ppo_loss_grad_fn(
                    agent_state.params,
                    minibatch.obs,
                    minibatch.actions,
                    minibatch.logprobs,
                    minibatch.advantages,
                    minibatch.returns,
                )
                agent_state = agent_state.apply_gradients(grads=grads)
                return agent_state, (loss, pg_loss, v_loss, entropy_loss, approx_kl, grads)

            agent_state, (loss, pg_loss, v_loss, entropy_loss, approx_kl, grads) = jax.lax.scan(
                update_minibatch, agent_state, shuffled_storage
            )
            return (agent_state, key), (loss, pg_loss, v_loss, entropy_loss, approx_kl, grads)

        (agent_state, key), (loss, pg_loss, v_loss, entropy_loss, approx_kl, grads) = jax.lax.scan(
            update_epoch, (agent_state, key), (), length=args.update_epochs
        )
        return agent_state, loss, pg_loss, v_loss, entropy_loss, approx_kl, key

    global_step = 0
    start_time = time.time()
    next_obs = envs.reset()
    next_done = jnp.zeros(args.num_envs, dtype=jax.numpy.bool_)

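    # A rollout is a jax.lax.scan over step_once: each step samples an action, advances the
    # EnvPool handle through its XLA interface, and records the transition in Storage.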
    def step_once(carry, step, env_step_fn):
        agent_state, episode_stats, obs, done, key, handle = carry
        action, logprob, value, key = get_action_and_value(agent_state, obs, key)

        episode_stats, handle, (next_obs, reward, next_done, _) = env_step_fn(episode_stats, handle, action)
        storage = Storage(
            obs=obs,
            actions=action,
            logprobs=logprob,
            dones=done,
            values=value,
            rewards=reward,
            returns=jnp.zeros_like(reward),
            advantages=jnp.zeros_like(reward),
        )
        return ((agent_state, episode_stats, next_obs, next_done, key, handle), storage)

    def rollout(agent_state, episode_stats, next_obs, next_done, key, handle, step_once_fn, max_steps):
        (agent_state, episode_stats, next_obs, next_done, key, handle), storage = jax.lax.scan(
            step_once_fn, (agent_state, episode_stats, next_obs, next_done, key, handle), (), max_steps
        )
        return agent_state, episode_stats, next_obs, next_done, storage, key, handle

    rollout = partial(rollout, step_once_fn=partial(step_once, env_step_fn=step_env_wrapped), max_steps=args.num_steps)

    for update in range(1, args.num_updates + 1):
        update_time_start = time.time()
        agent_state, episode_stats, next_obs, next_done, storage, key, handle = rollout(
            agent_state, episode_stats, next_obs, next_done, key, handle
        )
        global_step += args.num_steps * args.num_envs
        storage = compute_gae(agent_state, next_obs, next_done, storage)
        agent_state, loss, pg_loss, v_loss, entropy_loss, approx_kl, key = update_ppo(
            agent_state,
            storage,
            key,
        )
        avg_episodic_return = np.mean(jax.device_get(episode_stats.returned_episode_returns))
        print(f"global_step={global_step}, avg_episodic_return={avg_episodic_return}")

        # record rollout and loss metrics; losses are indexed [-1, -1] (last epoch, last minibatch)
        writer.add_scalar("charts/avg_episodic_return", avg_episodic_return, global_step)
        writer.add_scalar(
            "charts/avg_episodic_length", np.mean(jax.device_get(episode_stats.returned_episode_lengths)), global_step
        )
        writer.add_scalar("charts/learning_rate", agent_state.opt_state[1].hyperparams["learning_rate"].item(), global_step)
        writer.add_scalar("losses/value_loss", v_loss[-1, -1].item(), global_step)
        writer.add_scalar("losses/policy_loss", pg_loss[-1, -1].item(), global_step)
        writer.add_scalar("losses/entropy", entropy_loss[-1, -1].item(), global_step)
        writer.add_scalar("losses/approx_kl", approx_kl[-1, -1].item(), global_step)
        writer.add_scalar("losses/loss", loss[-1, -1].item(), global_step)
        print("SPS:", int(global_step / (time.time() - start_time)))
        writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)
        writer.add_scalar(
            "charts/SPS_update", int(args.num_envs * args.num_steps / (time.time() - update_time_start)), global_step
        )

    if args.save_model:
        model_path = f"runs/{run_name}/{args.exp_name}.cleanrl_model"
        with open(model_path, "wb") as f:
            f.write(
                flax.serialization.to_bytes(
                    [
                        vars(args),
                        [
                            agent_state.params.network_params,
                            agent_state.params.actor_params,
                            agent_state.params.critic_params,
                        ],
                    ]
                )
            )
        print(f"model saved to {model_path}")
        from cleanrl_utils.evals.ppo_envpool_jax_eval import evaluate

        episodic_returns = evaluate(
            model_path,
            make_env,
            args.env_id,
            eval_episodes=10,
            run_name=f"{run_name}-eval",
            Model=(Network, Actor, Critic),
        )
        for idx, episodic_return in enumerate(episodic_returns):
            writer.add_scalar("eval/episodic_return", episodic_return, idx)

        if args.upload_model:
            from cleanrl_utils.huggingface import push_to_hub

            repo_name = f"{args.env_id}-{args.exp_name}-seed{args.seed}"
            repo_id = f"{args.hf_entity}/{repo_name}" if args.hf_entity else repo_name
            push_to_hub(args, episodic_returns, repo_id, "PPO", f"runs/{run_name}", f"videos/{run_name}-eval")

    envs.close()
    writer.close()