File size: 1,533 Bytes
fff5a3d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 |
from abc import ABC, abstractmethod
from typing import NamedTuple, Optional, Sequence, Tuple
import torch
import torch.nn as nn
from gym.spaces import Box, Discrete, Space
from rl_algo_impls.shared.actor import PiForward
class ACNForward(NamedTuple):
    """Output bundle of an actor-critic network forward pass."""
    pi_forward: PiForward  # actor (policy) forward result
    v: torch.Tensor  # value estimate (critic output)
class ActorCriticNetwork(nn.Module, ABC):
    """Abstract interface for a combined actor-critic (policy + value) network.

    Concrete subclasses must implement all abstract members below; each
    `ACNForward` result carries the policy forward output and the value
    estimate together.
    """

    @abstractmethod
    def forward(
        self,
        obs: torch.Tensor,
        action: torch.Tensor,
        action_masks: Optional[torch.Tensor] = None,
    ) -> ACNForward:
        """Run the network on `obs`, evaluating `action` under the policy.

        :param obs: batch of observations.
        :param action: batch of actions to evaluate.
        :param action_masks: optional mask restricting valid actions.
        :return: policy forward result and value estimate.
        """
        ...

    @abstractmethod
    def distribution_and_value(
        self, obs: torch.Tensor, action_masks: Optional[torch.Tensor] = None
    ) -> ACNForward:
        """Return the policy forward result (no given action) and value for `obs`."""
        ...

    @abstractmethod
    def value(self, obs: torch.Tensor) -> torch.Tensor:
        """Return only the critic's value estimate for `obs`."""
        ...

    @abstractmethod
    def reset_noise(self, batch_size: Optional[int] = None) -> None:
        """Resample any exploration noise held by the network (no-op if unused)."""
        ...

    @property
    @abstractmethod
    def action_shape(self) -> Tuple[int, ...]:
        """Shape of a single action emitted by the policy.

        Fix: the original declared this as a plain `@property` with a `...`
        body, so a subclass that forgot to override it would silently return
        `None`. Marking it `@abstractmethod` (stacked under `@property`, the
        standard ordering) makes such subclasses fail at instantiation,
        consistent with every other member of this ABC.
        """
        ...
def default_hidden_sizes(obs_space: Space) -> Sequence[int]:
    """Return default MLP hidden-layer sizes for an observation space.

    Image-like ``Box`` spaces (3-D shape) get no hidden layers — the feature
    extractor's output feeds the heads directly. Vector ``Box`` spaces (1-D
    shape) get two hidden layers of 64. ``Discrete`` spaces get a single
    hidden layer of 64.

    :raises ValueError: for any other space (including 2-D ``Box`` shapes).
    """
    if isinstance(obs_space, Discrete):
        return [64]
    if isinstance(obs_space, Box):
        ndim = len(obs_space.shape)  # type: ignore
        if ndim == 3:
            # Feature extractor output goes straight to the heads.
            return []
        if ndim == 1:
            return [64, 64]
    raise ValueError(f"Unsupported observation space: {obs_space}")
|