# sgoodfriend's picture
# A2C playing CarRacing-v0 from https://github.com/sgoodfriend/rl-algo-impls/tree/983cb75e43e51cf4ef57f177194ab9a4a1a8808b
# fff5a3d
from typing import Optional, Sequence, Tuple, Type
import numpy as np
import torch
import torch.nn as nn
from gym.spaces import MultiDiscrete, Space
from rl_algo_impls.shared.actor import pi_forward
from rl_algo_impls.shared.actor.gridnet import GridnetDistribution
from rl_algo_impls.shared.actor.gridnet_decoder import Transpose
from rl_algo_impls.shared.module.utils import layer_init
from rl_algo_impls.shared.policy.actor_critic_network.network import (
ACNForward,
ActorCriticNetwork,
default_hidden_sizes,
)
from rl_algo_impls.shared.policy.critic import CriticHead
from rl_algo_impls.shared.policy.policy import ACTIVATION
class UNetActorCriticNetwork(ActorCriticNetwork):
    """Actor-critic network with a U-Net style convolutional backbone.

    Five encoder stages (32, 64, 128, 256, 512 channels) each downsample via
    ``max_pool`` after the first. Four decoder stages upsample with stride-2
    transposed convolutions and are concatenated channel-wise (``dim=1``) with
    the encoder output of matching depth. The deepest encoder output also
    feeds the critic head.

    NOTE(review): each pooling stage uses kernel 3 / stride 2 / padding 1 and
    each transposed convolution doubles the spatial size, so the skip
    concatenations presumably require observation height and width to be
    multiples of 16 -- confirm against the environments in use.
    """

    def __init__(
        self,
        observation_space: Space,
        action_space: Space,
        action_plane_space: Space,
        v_hidden_sizes: Optional[Sequence[int]] = None,
        init_layers_orthogonal: bool = True,
        activation_fn: str = "tanh",
        cnn_layers_init_orthogonal: Optional[bool] = None,
    ) -> None:
        """Build the encoder, decoder, policy output and critic head.

        Args:
            observation_space: Space whose ``shape[0]`` is used as the input
                channel count and whose ``high``/``low`` bounds define the
                observation scaling divisor.
            action_space: Must be ``MultiDiscrete``; the length of its
                ``nvec`` is divided by that of ``action_plane_space`` to get
                ``map_size``.
            action_plane_space: Must be ``MultiDiscrete``; its ``nvec`` is
                stored as ``action_vec`` and its sum is the number of logit
                channels.
            v_hidden_sizes: Hidden sizes for the critic head. When ``None``,
                ``default_hidden_sizes(observation_space)`` is used.
            init_layers_orthogonal: Forwarded to ``CriticHead``.
            activation_fn: Key into ``ACTIVATION`` selecting the activation
                used after every (transposed) convolution and in the critic.
            cnn_layers_init_orthogonal: Forwarded to ``layer_init`` for every
                convolutional layer. ``None`` is treated as ``True``.
        """
        if cnn_layers_init_orthogonal is None:
            cnn_layers_init_orthogonal = True
        super().__init__()
        assert isinstance(action_space, MultiDiscrete)
        assert isinstance(action_plane_space, MultiDiscrete)
        # Divisor applied to observations in _preprocess.
        self.range_size = np.max(observation_space.high) - np.min(observation_space.low)  # type: ignore
        self.map_size = len(action_space.nvec) // len(action_plane_space.nvec)  # type: ignore
        self.action_vec = action_plane_space.nvec  # type: ignore

        activation = ACTIVATION[activation_fn]

        def conv_relu(
            in_channels: int, out_channels: int, kernel_size: int = 3, padding: int = 1
        ) -> nn.Module:
            # Despite the name, the activation is whatever `activation_fn` selected.
            return nn.Sequential(
                layer_init(
                    nn.Conv2d(
                        in_channels,
                        out_channels,
                        kernel_size=kernel_size,
                        padding=padding,
                    ),
                    cnn_layers_init_orthogonal,
                ),
                activation(),
            )

        def up_conv_relu(in_channels: int, out_channels: int) -> nn.Module:
            # Stride-2 transposed convolution followed by the activation.
            return nn.Sequential(
                layer_init(
                    nn.ConvTranspose2d(
                        in_channels,
                        out_channels,
                        kernel_size=3,
                        stride=2,
                        padding=1,
                        output_padding=1,
                    ),
                    cnn_layers_init_orthogonal,
                ),
                activation(),
            )

        in_channels = observation_space.shape[0]  # type: ignore
        # Encoder: layers are created in this order so parameter
        # initialization order and state_dict keys stay stable.
        self.enc1 = conv_relu(in_channels, 32)
        self.enc2 = nn.Sequential(max_pool(), conv_relu(32, 64))
        self.enc3 = nn.Sequential(max_pool(), conv_relu(64, 128))
        self.enc4 = nn.Sequential(max_pool(), conv_relu(128, 256))
        self.enc5 = nn.Sequential(
            max_pool(), conv_relu(256, 512, kernel_size=1, padding=0)
        )

        # Decoder: input channel counts of dec3/dec2/dec1/out are doubled
        # because each receives a channel-wise concatenation with a skip.
        self.dec4 = up_conv_relu(512, 256)
        self.dec3 = nn.Sequential(conv_relu(512, 256), up_conv_relu(256, 128))
        self.dec2 = nn.Sequential(conv_relu(256, 128), up_conv_relu(128, 64))
        self.dec1 = nn.Sequential(conv_relu(128, 64), up_conv_relu(64, 32))

        self.out = nn.Sequential(
            conv_relu(64, 32),
            layer_init(
                # Cast the NumPy integer to a plain int for the channel count.
                nn.Conv2d(32, int(self.action_vec.sum()), kernel_size=1, padding=0),
                cnn_layers_init_orthogonal,
                std=0.01,
            ),
            # Presumably moves the channel axis last -- confirm in Transpose.
            Transpose((0, 2, 3, 1)),
        )

        # Run one sampled observation through the encoder to size the critic.
        with torch.no_grad():
            sample_obs = torch.as_tensor(observation_space.sample())
            cnn_out = torch.flatten(self._encode(sample_obs)[-1], start_dim=1)

        v_hidden_sizes = (
            v_hidden_sizes
            if v_hidden_sizes is not None
            else default_hidden_sizes(observation_space)
        )
        self.critic_head = CriticHead(
            in_dim=cnn_out.shape[1:],
            hidden_sizes=v_hidden_sizes,
            activation=activation,
            init_layers_orthogonal=init_layers_orthogonal,
        )

    def _preprocess(self, obs: torch.Tensor) -> torch.Tensor:
        """Add a batch dimension to 3-D input and scale by ``range_size``."""
        if len(obs.shape) == 3:
            obs = obs.unsqueeze(0)
        return obs.float() / self.range_size

    def _encode(self, obs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
        """Preprocess ``obs`` and return all encoder outputs ``(e1, ..., e5)``.

        Shared by construction-time shape inference, the policy/value pass
        and the value-only pass so the encoder path is defined in one place.
        """
        e1 = self.enc1(self._preprocess(obs))
        e2 = self.enc2(e1)
        e3 = self.enc3(e2)
        e4 = self.enc4(e3)
        e5 = self.enc5(e4)
        return e1, e2, e3, e4, e5

    def forward(
        self,
        obs: torch.Tensor,
        action: torch.Tensor,
        action_masks: Optional[torch.Tensor] = None,
    ) -> ACNForward:
        """Evaluate the given ``action`` under the policy and compute value."""
        return self._distribution_and_value(
            obs, action=action, action_masks=action_masks
        )

    def distribution_and_value(
        self, obs: torch.Tensor, action_masks: Optional[torch.Tensor] = None
    ) -> ACNForward:
        """Compute the policy distribution and value without a given action."""
        return self._distribution_and_value(obs, action_masks=action_masks)

    def _distribution_and_value(
        self,
        obs: torch.Tensor,
        action: Optional[torch.Tensor] = None,
        action_masks: Optional[torch.Tensor] = None,
    ) -> ACNForward:
        """Run the full U-Net and build the ``ACNForward`` result.

        Args:
            obs: Observation tensor; passed through ``_preprocess``.
            action: Optional action forwarded to ``pi_forward``.
            action_masks: Required; forwarded to ``GridnetDistribution``.

        Raises:
            AssertionError: If ``action_masks`` is ``None``.
        """
        assert (
            action_masks is not None
        ), f"No mask case unhandled in {self.__class__.__name__}"
        e1, e2, e3, e4, e5 = self._encode(obs)
        v = self.critic_head(e5)

        # Decode, concatenating each upsampled map with its encoder skip.
        d4 = self.dec4(e5)
        d3 = self.dec3(torch.cat((d4, e4), dim=1))
        d2 = self.dec2(torch.cat((d3, e3), dim=1))
        d1 = self.dec1(torch.cat((d2, e2), dim=1))
        logits = self.out(torch.cat((d1, e1), dim=1))

        pi = GridnetDistribution(self.map_size, self.action_vec, logits, action_masks)
        return ACNForward(pi_forward(pi, action), v)

    def value(self, obs: torch.Tensor) -> torch.Tensor:
        """Return the critic output for ``obs`` using only the encoder."""
        return self.critic_head(self._encode(obs)[-1])

    def reset_noise(self, batch_size: Optional[int] = None) -> None:
        """No-op; this network keeps no noise state."""
        pass

    @property
    def action_shape(self) -> Tuple[int, ...]:
        """Shape of one action: ``(map_size, len(action_vec))``."""
        return (self.map_size, len(self.action_vec))
def max_pool() -> nn.MaxPool2d:
    """Create a fresh 2-D max-pooling layer: kernel 3, stride 2, padding 1."""
    pooling_layer = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
    return pooling_layer