# PPOxFamily / ppof_ch7_code_p1.py
# TuTuHuss
# update(hus): update data from official server
# 7955c6f
from typing import List, Optional, Sequence, Tuple

import torch
import torch.nn as nn
import treetensor.torch as ttorch
class PPOFModel(nn.Module):
    """Shared-encoder actor-critic network for PPO with a Gaussian policy head.

    A convolutional encoder maps a channels-first observation of shape
    ``obs_shape`` to a flat feature vector. Two MLP heads read that vector:

    * the critic head outputs one value per sample, shape ``(B, 1)``;
    * the actor head outputs ``mu`` of shape ``(B, action_shape)``, and
      ``sigma = exp(log_sigma)``, where ``log_sigma`` is a learnable
      ``(1, action_shape)`` parameter broadcast to ``mu``'s shape. ``sigma``
      therefore does not depend on the observation.

    Args:
        obs_shape: Observation shape; ``obs_shape[0]`` is the number of input
            channels of the first convolution.
        action_shape: Dimension of the continuous action.
        encoder_hidden_size_list: Output channels of each convolution layer.
        actor_head_hidden_size: Width of each hidden layer in the actor head.
        actor_head_layer_num: Number of hidden layers in the actor head (may be 0).
        critic_head_hidden_size: Width of each hidden layer in the critic head.
        critic_head_layer_num: Number of hidden layers in the critic head (may be 0).
        activation: Module inserted after every convolution and hidden linear
            layer; ``None`` means no activation. The default instance is created
            once and shared by all models, which is harmless for a stateless
            module such as ``nn.ReLU``.
        encoder_kernel_size_list: Kernel size of each convolution layer.
        encoder_stride_list: Stride of each convolution layer.

    Raises:
        ValueError: if fewer kernel sizes or strides are given than there are
            convolution layers.
    """

    mode = ['compute_actor', 'compute_critic', 'compute_actor_critic']

    def __init__(
        self,
        obs_shape: Tuple[int, ...],
        action_shape: int,
        encoder_hidden_size_list: Sequence[int] = (128, 128, 64),
        actor_head_hidden_size: int = 64,
        actor_head_layer_num: int = 1,
        critic_head_hidden_size: int = 64,
        critic_head_layer_num: int = 1,
        activation: Optional[nn.Module] = nn.ReLU(),
        encoder_kernel_size_list: Sequence[int] = (8, 4, 3),
        encoder_stride_list: Sequence[int] = (4, 2, 1),
    ) -> None:
        super().__init__()
        self.obs_shape, self.action_shape = obs_shape, action_shape

        num_conv = len(encoder_hidden_size_list)
        if len(encoder_kernel_size_list) < num_conv or len(encoder_stride_list) < num_conv:
            raise ValueError(
                f"need a kernel size and a stride for each of the {num_conv} conv layers, "
                f"got {len(encoder_kernel_size_list)} kernel sizes and "
                f"{len(encoder_stride_list)} strides"
            )

        # encoder: Conv2d(+activation) stack followed by Flatten.
        # zip stops at the shortest list, so extra kernel sizes/strides are ignored.
        layers: List[nn.Module] = []
        input_size = obs_shape[0]
        for output_size, kernel_size, stride in zip(
            encoder_hidden_size_list, encoder_kernel_size_list, encoder_stride_list
        ):
            layers.append(nn.Conv2d(input_size, output_size, kernel_size, stride))
            if activation is not None:
                layers.append(activation)
            input_size = output_size
        layers.append(nn.Flatten())
        self.encoder = nn.Sequential(*layers)
        flatten_size = self.get_flatten_size()

        # critic: hidden MLP, then a scalar output. The output layer reads the
        # actual last width, which is flatten_size when there are no hidden layers.
        layers, critic_width = self._make_mlp_layers(
            flatten_size, critic_head_hidden_size, critic_head_layer_num, activation
        )
        layers.append(nn.Linear(critic_width, 1))
        self.critic = nn.Sequential(*layers)

        # actor: hidden MLP plus a separate mu projection and a learnable log_sigma.
        layers, actor_width = self._make_mlp_layers(
            flatten_size, actor_head_hidden_size, actor_head_layer_num, activation
        )
        self.actor = nn.Sequential(*layers)
        self.mu = nn.Linear(actor_width, action_shape)
        self.log_sigma = nn.Parameter(torch.zeros(1, action_shape))

        # init weights
        self.init_weights()

    @staticmethod
    def _make_mlp_layers(
        input_size: int,
        hidden_size: int,
        layer_num: int,
        activation: Optional[nn.Module],
    ) -> Tuple[List[nn.Module], int]:
        """Build ``layer_num`` Linear(+activation) layers.

        Returns the layer list and the width of the last layer's output, which
        is ``input_size`` when ``layer_num`` is 0.
        """
        layers: List[nn.Module] = []
        for _ in range(layer_num):
            layers.append(nn.Linear(input_size, hidden_size))
            if activation is not None:
                layers.append(activation)
            input_size = hidden_size
        return layers, input_size

    def init_weights(self) -> None:
        """Orthogonally initialise every conv/linear weight and zero every bias.

        Hidden layers use gain ``sqrt(2)`` (``nn.init.calculate_gain('relu')``).
        The critic output layer is then re-initialised with gain 1, and ``mu``
        with gain 0.01, so the initial policy mean is close to zero.
        ``log_sigma`` keeps its zero initialisation, so the initial ``sigma`` is 1.
        """
        hidden_gain = nn.init.calculate_gain('relu')
        for module in self.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                nn.init.orthogonal_(module.weight, gain=hidden_gain)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
        # The last critic layer is always the scalar value Linear (appended last in __init__).
        nn.init.orthogonal_(self.critic[-1].weight, gain=1.0)
        nn.init.orthogonal_(self.mu.weight, gain=0.01)

    def get_flatten_size(self) -> int:
        """Return the encoder's flattened output width for a single observation.

        A zero-filled dummy batch of one is pushed through ``self.encoder``.
        Zeros are used instead of random data so the probe does not consume
        the global RNG.
        """
        with torch.no_grad():
            output = self.encoder(torch.zeros(1, *self.obs_shape))
        return output.shape[1]

    def forward(self, inputs: "ttorch.Tensor", mode: str) -> "ttorch.Tensor":
        """Dispatch to the ``compute_*`` method named by ``mode`` (one of ``self.mode``)."""
        assert mode in self.mode, "not support forward mode: {}/{}".format(mode, self.mode)
        return getattr(self, mode)(inputs)

    def _actor_head(self, feature: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Map encoder features to the Gaussian parameters ``(mu, sigma)``."""
        mu = self.mu(self.actor(feature))
        # Adding zeros_like(mu) broadcasts the (1, A) parameter to mu's shape.
        log_sigma = self.log_sigma + torch.zeros_like(mu)
        return mu, torch.exp(log_sigma)

    def compute_actor(self, x: "ttorch.Tensor") -> "ttorch.Tensor":
        """Return a treetensor with keys ``mu`` and ``sigma``."""
        mu, sigma = self._actor_head(self.encoder(x))
        return ttorch.as_tensor({'mu': mu, 'sigma': sigma})

    def compute_critic(self, x: "ttorch.Tensor") -> "ttorch.Tensor":
        """Return the state value, shape ``(B, 1)``."""
        return self.critic(self.encoder(x))

    def compute_actor_critic(self, x: "ttorch.Tensor") -> "ttorch.Tensor":
        """Run the encoder once, then return ``{'logit': {'mu', 'sigma'}, 'value'}``."""
        feature = self.encoder(x)
        value = self.critic(feature)
        mu, sigma = self._actor_head(feature)
        return ttorch.as_tensor({'logit': {'mu': mu, 'sigma': sigma}, 'value': value})
def test_ppof_model() -> None:
    """Smoke-test every forward mode of ``PPOFModel`` on a random (3, 4, 84, 84) batch."""
    batch_size, action_dim = 3, 5
    net = PPOFModel((4, 84, 84), action_dim)
    print(net)
    obs = torch.randn(batch_size, 4, 84, 84)

    value = net(obs, mode='compute_critic')
    assert value.shape == (batch_size, 1)

    actor_out = net(obs, mode='compute_actor')
    for field in (actor_out.mu, actor_out.sigma):
        assert field.shape == (batch_size, action_dim)

    joint_out = net(obs, mode='compute_actor_critic')
    assert joint_out.value.shape == (batch_size, 1)
    for field in (joint_out.logit.mu, joint_out.logit.sigma):
        assert field.shape == (batch_size, action_dim)

    print('End...')
if __name__ == "__main__":
    # Run the smoke test only when executed as a script, not on import.
    test_ppof_model()