File size: 4,150 Bytes
7955c6f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
from typing import Optional, Tuple, List
import torch
import torch.nn as nn
import treetensor.torch as ttorch


class PPOFModel(nn.Module):
    mode = ['compute_actor', 'compute_critic', 'compute_actor_critic']

    def __init__(
            self,
            obs_shape: Tuple[int],
            action_shape: int,
            encoder_hidden_size_list: List = [128, 128, 64],
            actor_head_hidden_size: int = 64,
            actor_head_layer_num: int = 1,
            critic_head_hidden_size: int = 64,
            critic_head_layer_num: int = 1,
            activation: Optional[nn.Module] = nn.ReLU(),
    ) -> None:
        super(PPOFModel, self).__init__()
        self.obs_shape, self.action_shape = obs_shape, action_shape

        # encoder
        layers = []
        input_size = obs_shape[0]
        kernel_size_list = [8, 4, 3]
        stride_list = [4, 2, 1]
        for i in range(len(encoder_hidden_size_list)):
            output_size = encoder_hidden_size_list[i]
            layers.append(nn.Conv2d(input_size, output_size, kernel_size_list[i], stride_list[i]))
            layers.append(activation)
            input_size = output_size
        layers.append(nn.Flatten())
        self.encoder = nn.Sequential(*layers)

        flatten_size = input_size = self.get_flatten_size()
        # critic
        layers = []
        for i in range(critic_head_layer_num):
            layers.append(nn.Linear(input_size, critic_head_hidden_size))
            layers.append(activation)
            input_size = critic_head_hidden_size
        layers.append(nn.Linear(critic_head_hidden_size, 1))
        self.critic = nn.Sequential(*layers)
        # actor
        layers = []
        input_size = flatten_size
        for i in range(actor_head_layer_num):
            layers.append(nn.Linear(input_size, actor_head_hidden_size))
            layers.append(activation)
            input_size = actor_head_hidden_size
        self.actor = nn.Sequential(*layers)
        self.mu = nn.Linear(actor_head_hidden_size, action_shape)
        self.log_sigma = nn.Parameter(torch.zeros(1, action_shape))

        # init weights
        self.init_weights()

    def init_weights(self) -> None:
        # You need to implement this function
        raise NotImplementedError

    def get_flatten_size(self) -> int:
        test_data = torch.randn(1, *self.obs_shape)
        with torch.no_grad():
            output = self.encoder(test_data)
        return output.shape[1]

    def forward(self, inputs: ttorch.Tensor, mode: str) -> ttorch.Tensor:
        assert mode in self.mode, "not support forward mode: {}/{}".format(mode, self.mode)
        return getattr(self, mode)(inputs)

    def compute_actor(self, x: ttorch.Tensor) -> ttorch.Tensor:
        x = self.encoder(x)
        x = self.actor(x)
        mu = self.mu(x)
        log_sigma = self.log_sigma + torch.zeros_like(mu)  # addition aims to broadcast shape
        sigma = torch.exp(log_sigma)
        return ttorch.as_tensor({'mu': mu, 'sigma': sigma})

    def compute_critic(self, x: ttorch.Tensor) -> ttorch.Tensor:
        x = self.encoder(x)
        value = self.critic(x)
        return value

    def compute_actor_critic(self, x: ttorch.Tensor) -> ttorch.Tensor:
        x = self.encoder(x)
        value = self.critic(x)
        x = self.actor(x)
        mu = self.mu(x)
        log_sigma = self.log_sigma + torch.zeros_like(mu)  # addition aims to broadcast shape
        sigma = torch.exp(log_sigma)
        return ttorch.as_tensor({'logit': {'mu': mu, 'sigma': sigma}, 'value': value})


def test_ppof_model() -> None:
    model = PPOFModel((4, 84, 84), 5)
    print(model)
    data = torch.randn(3, 4, 84, 84)
    output = model(data, mode='compute_critic')
    assert output.shape == (3, 1)
    output = model(data, mode='compute_actor')
    assert output.mu.shape == (3, 5)
    assert output.sigma.shape == (3, 5)
    output = model(data, mode='compute_actor_critic')
    assert output.value.shape == (3, 1)
    assert output.logit.mu.shape == (3, 5)
    assert output.logit.sigma.shape == (3, 5)
    print('End...')


if __name__ == "__main__":
    test_ppof_model()