|
from typing import List, Optional, Sequence, Tuple

import torch
import torch.nn as nn
import treetensor.torch as ttorch
|
|
|
|
|
class PPOFModel(nn.Module):
    """Convolutional actor-critic network for PPO with a diagonal Gaussian policy.

    A shared convolutional encoder maps an observation of shape ``obs_shape``
    (channels first) to a flat feature vector. Two MLP heads sit on top of it:

    * the critic head outputs one value per sample, shape ``(B, 1)``;
    * the actor head outputs the Gaussian mean ``mu``, shape ``(B, action_shape)``.
      The standard deviation ``sigma`` comes from a learnable, state-independent
      ``log_sigma`` parameter.
    """

    # Names of the methods that ``forward`` is allowed to dispatch to.
    mode = ['compute_actor', 'compute_critic', 'compute_actor_critic']

    def __init__(
            self,
            obs_shape: Tuple[int, ...],
            action_shape: int,
            encoder_hidden_size_list: Sequence[int] = (128, 128, 64),
            actor_head_hidden_size: int = 64,
            actor_head_layer_num: int = 1,
            critic_head_hidden_size: int = 64,
            critic_head_layer_num: int = 1,
            activation: Optional[nn.Module] = nn.ReLU(),
            kernel_size_list: Optional[Sequence[int]] = None,
            stride_list: Optional[Sequence[int]] = None,
    ) -> None:
        """Build the encoder, the critic head and the actor head, then initialize weights.

        Arguments:
            - obs_shape: Observation shape ``(C, H, W)``. ``C`` is the first conv's in-channels.
            - action_shape: Dimension of the continuous action (size of ``mu``/``sigma``).
            - encoder_hidden_size_list: Output channels of each conv layer in the encoder.
            - actor_head_hidden_size: Width of each hidden Linear layer in the actor head.
            - actor_head_layer_num: Number of hidden Linear layers in the actor head (may be 0).
            - critic_head_hidden_size: Width of each hidden Linear layer in the critic head.
            - critic_head_layer_num: Number of hidden Linear layers in the critic head (may be 0).
            - activation: Module inserted after every conv / hidden Linear layer. ``None``
              means no activation. The same instance is reused for every layer, so an
              activation with parameters (e.g. ``nn.PReLU``) shares them across layers.
            - kernel_size_list: Kernel size of each conv layer. Defaults to ``[8, 4, 3]``.
            - stride_list: Stride of each conv layer. Defaults to ``[4, 2, 1]``.

        Raises:
            - ValueError: If fewer kernel sizes or strides are given than encoder layers.
        """
        super().__init__()
        self.obs_shape, self.action_shape = obs_shape, action_shape

        # Shared encoder: Conv2d (+ activation) per entry of encoder_hidden_size_list, then Flatten.
        if kernel_size_list is None:
            kernel_size_list = [8, 4, 3]
        if stride_list is None:
            stride_list = [4, 2, 1]
        conv_num = len(encoder_hidden_size_list)
        if len(kernel_size_list) < conv_num or len(stride_list) < conv_num:
            raise ValueError(
                "need a kernel size and a stride for each of the {} encoder layers, got kernel_size_list={} "
                "and stride_list={}".format(conv_num, list(kernel_size_list), list(stride_list))
            )
        layers: List[nn.Module] = []
        in_channels = obs_shape[0]
        for out_channels, kernel_size, stride in zip(encoder_hidden_size_list, kernel_size_list, stride_list):
            layers.append(nn.Conv2d(in_channels, out_channels, kernel_size, stride))
            if activation is not None:
                layers.append(activation)
            in_channels = out_channels
        layers.append(nn.Flatten())
        self.encoder = nn.Sequential(*layers)

        # get_flatten_size() runs a dummy pass through self.encoder, so the encoder must exist first.
        flatten_size = self.get_flatten_size()

        # Critic head. The output projection uses the real width of the previous layer,
        # so critic_head_layer_num=0 (a linear critic on the encoder features) also works.
        critic_layers, critic_out_size = self._build_mlp(
            flatten_size, critic_head_hidden_size, critic_head_layer_num, activation
        )
        self.critic = nn.Sequential(*critic_layers, nn.Linear(critic_out_size, 1))

        # Actor head: hidden MLP, then a Linear layer producing the Gaussian mean.
        actor_layers, actor_out_size = self._build_mlp(
            flatten_size, actor_head_hidden_size, actor_head_layer_num, activation
        )
        self.actor = nn.Sequential(*actor_layers)
        self.mu = nn.Linear(actor_out_size, action_shape)
        # State-independent log standard deviation. It is broadcast over the batch when used.
        self.log_sigma = nn.Parameter(torch.zeros(1, action_shape))

        self.init_weights()

    @staticmethod
    def _build_mlp(
            input_size: int,
            hidden_size: int,
            layer_num: int,
            activation: Optional[nn.Module],
    ) -> Tuple[List[nn.Module], int]:
        """Return ``layer_num`` x (Linear [+ activation]) layers and the width of their output."""
        layers: List[nn.Module] = []
        for _ in range(layer_num):
            layers.append(nn.Linear(input_size, hidden_size))
            if activation is not None:
                layers.append(activation)
            input_size = hidden_size
        return layers, input_size

    def init_weights(self) -> None:
        """Apply the orthogonal initialization commonly used for PPO networks.

        Every ``Conv2d``/``Linear`` weight is initialized orthogonally with the ReLU
        gain ``sqrt(2)``, and every bias is set to zero. Two layers are then
        re-initialized:

        * ``mu`` gets a small gain (0.01), so the initial policy mean is close to zero;
        * the final critic layer gets gain 1.

        ``log_sigma`` keeps its zero initialization, i.e. ``sigma = 1`` at the start.
        """
        relu_gain = nn.init.calculate_gain('relu')
        for module in self.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                nn.init.orthogonal_(module.weight, gain=relu_gain)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
        nn.init.orthogonal_(self.mu.weight, gain=0.01)
        nn.init.orthogonal_(self.critic[-1].weight, gain=1.0)

    def get_flatten_size(self) -> int:
        """Return the width of the encoder's flattened output for one observation.

        The width is found by running a single dummy observation through ``self.encoder``.
        Zeros are used because only the output shape matters, and ``torch.randn`` would
        needlessly advance the global RNG during construction.
        """
        dummy_obs = torch.zeros(1, *self.obs_shape)
        with torch.no_grad():
            output = self.encoder(dummy_obs)
        return output.shape[1]

    def forward(self, inputs: ttorch.Tensor, mode: str) -> ttorch.Tensor:
        """Dispatch to ``compute_actor``, ``compute_critic`` or ``compute_actor_critic``.

        Raises:
            - AssertionError: If ``mode`` is not one of ``self.mode``.
        """
        assert mode in self.mode, "not support forward mode: {}/{}".format(mode, self.mode)
        return getattr(self, mode)(inputs)

    def _gaussian_params(self, feature: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Map encoder features to the mean and std of the diagonal Gaussian policy."""
        mu = self.mu(self.actor(feature))
        # Broadcast the (1, action_shape) log-std to mu's batch shape, then exponentiate.
        sigma = torch.exp(self.log_sigma + torch.zeros_like(mu))
        return mu, sigma

    def compute_actor(self, x: ttorch.Tensor) -> ttorch.Tensor:
        """Return the policy parameters ``{'mu': (B, A), 'sigma': (B, A)}`` for observations ``x``."""
        mu, sigma = self._gaussian_params(self.encoder(x))
        return ttorch.as_tensor({'mu': mu, 'sigma': sigma})

    def compute_critic(self, x: ttorch.Tensor) -> ttorch.Tensor:
        """Return the value estimate, shape ``(B, 1)``, for observations ``x``."""
        return self.critic(self.encoder(x))

    def compute_actor_critic(self, x: ttorch.Tensor) -> ttorch.Tensor:
        """Return ``{'logit': {'mu', 'sigma'}, 'value'}``, sharing one encoder pass between both heads."""
        feature = self.encoder(x)
        value = self.critic(feature)
        mu, sigma = self._gaussian_params(feature)
        return ttorch.as_tensor({'logit': {'mu': mu, 'sigma': sigma}, 'value': value})
|
|
|
|
|
def test_ppof_model() -> None:
    """Smoke-test every forward mode of PPOFModel on a batch of 4x84x84 observations."""
    batch_size, action_dim = 3, 5
    net = PPOFModel((4, 84, 84), action_dim)
    print(net)
    obs = torch.randn(batch_size, 4, 84, 84)

    # Critic only: one scalar value per sample.
    value = net(obs, mode='compute_critic')
    assert value.shape == (batch_size, 1)

    # Actor only: Gaussian mean and std, one entry per action dimension.
    policy = net(obs, mode='compute_actor')
    for field in ('mu', 'sigma'):
        assert getattr(policy, field).shape == (batch_size, action_dim)

    # Both heads in one call: policy nested under 'logit', value alongside it.
    joint = net(obs, mode='compute_actor_critic')
    assert joint.value.shape == (batch_size, 1)
    for field in ('mu', 'sigma'):
        assert getattr(joint.logit, field).shape == (batch_size, action_dim)

    print('End...')


# Running this file directly executes the smoke test.
if __name__ == "__main__":
    test_ppof_model()
|
|