sgoodfriend's picture
A2C playing CarRacing-v0 from https://github.com/sgoodfriend/rl-algo-impls/tree/983cb75e43e51cf4ef57f177194ab9a4a1a8808b
fff5a3d
from abc import ABC, abstractmethod
from typing import NamedTuple, Optional, Sequence, Tuple
import torch
import torch.nn as nn
from gym.spaces import Box, Discrete, Space
from rl_algo_impls.shared.actor import PiForward
# Immutable result record returned by ActorCriticNetwork.forward and
# ActorCriticNetwork.distribution_and_value (see their return annotations).
ACNForward = NamedTuple(
    "ACNForward",
    [
        # Output of the policy side, typed as PiForward from
        # rl_algo_impls.shared.actor.
        ("pi_forward", PiForward),
        # Tensor output; presumably the critic's value estimate -- TODO confirm.
        ("v", torch.Tensor),
    ],
)
class ActorCriticNetwork(nn.Module, ABC):
    """Abstract base class for networks that produce an ``ACNForward`` result.

    Concrete subclasses must implement ``forward``, ``distribution_and_value``,
    ``value`` and ``reset_noise``, and are expected to override the
    ``action_shape`` property.
    """

    @abstractmethod
    def forward(
        self,
        obs: torch.Tensor,
        action: torch.Tensor,
        action_masks: Optional[torch.Tensor] = None,
    ) -> ACNForward:
        """Run the network on ``obs`` together with an already-chosen ``action``.

        Args:
            obs: Observation tensor.
            action: Action tensor (presumably the actions to evaluate under the
                current policy -- confirm against implementations).
            action_masks: Optional tensor of action masks; ``None`` by default.

        Returns:
            An ``ACNForward`` holding ``pi_forward`` and ``v``.
        """
        ...

    @abstractmethod
    def distribution_and_value(
        self, obs: torch.Tensor, action_masks: Optional[torch.Tensor] = None
    ) -> ACNForward:
        """Run the network on ``obs`` without a supplied action.

        Args:
            obs: Observation tensor.
            action_masks: Optional tensor of action masks; ``None`` by default.

        Returns:
            An ``ACNForward`` holding ``pi_forward`` and ``v``.
        """
        ...

    @abstractmethod
    def value(self, obs: torch.Tensor) -> torch.Tensor:
        """Return a tensor computed from ``obs`` alone.

        Presumably the value estimate, skipping the policy side -- confirm
        against implementations.
        """
        ...

    @abstractmethod
    def reset_noise(self, batch_size: Optional[int] = None) -> None:
        """Reset the network's noise state.

        Args:
            batch_size: Optional size hint; ``None`` by default. Its exact
                meaning is implementation-defined (not visible here).
        """
        ...

    @property
    def action_shape(self) -> Tuple[int, ...]:
        """Shape of an action as a tuple of ints (exact semantics are up to
        the subclass).

        This property is intentionally not marked abstract, so subclasses that
        do not override it can still be instantiated. The base implementation
        raises instead of silently returning ``None``, which would violate the
        declared return type and only fail later at the point of use.

        Raises:
            NotImplementedError: If the subclass does not override this
                property.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must override the action_shape property"
        )
def default_hidden_sizes(obs_space: Space) -> Sequence[int]:
    """Pick default hidden-layer widths from the kind of observation space.

    Args:
        obs_space: The observation space to inspect.

    Returns:
        * ``[]`` for a ``Box`` whose shape has 3 dimensions (by default there
          are no hidden layers between the feature extractor and the output),
        * ``[64, 64]`` for a ``Box`` whose shape has 1 dimension,
        * ``[64]`` for a ``Discrete`` space.

    Raises:
        ValueError: For a ``Box`` of any other rank, or for any space that is
            neither ``Box`` nor ``Discrete``.
    """
    if isinstance(obs_space, Box):
        rank = len(obs_space.shape)  # type: ignore
        if rank == 3:
            return []
        if rank == 1:
            return [64, 64]
        # Any other Box rank falls through to the shared error below.
    elif isinstance(obs_space, Discrete):
        return [64]
    raise ValueError(f"Unsupported observation space: {obs_space}")