File size: 1,474 Bytes
3207519
 
 
 
38a4bc3
3207519
a3988fe
 
3207519
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38a4bc3
 
 
3207519
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import torch

class ParameterisedPolicy(torch.nn.Module):
    """
    Gaussian policy network for a REINFORCE RL agent.

    The network maps a gym state vector to the mean (mu) and standard
    deviation (sigma) of an independent Normal distribution per action
    dimension. Returns an action when the
    ParameterisedPolicy.act(observation) method is used.

    obs_len - length of the state vector.
    act_space_len - length of the action vector.
    hidden_sizes - widths of the two shared hidden layers
        (defaults to (256, 128)).
    """

    def __init__(self, obs_len=8, act_space_len=2, hidden_sizes=(256, 128)):
        super().__init__()
        hidden_1, hidden_2 = hidden_sizes
        self.obs_len = obs_len
        self.act_space_len = act_space_len

        # Shared trunk. Attribute names are kept unchanged so previously
        # saved state_dict checkpoints still load.
        self.lin_1 = torch.nn.Linear(self.obs_len, hidden_1)
        self.rel_1 = torch.nn.ReLU()

        self.lin_2 = torch.nn.Linear(hidden_1, hidden_2)
        self.rel_2 = torch.nn.ReLU()

        # Head producing the action means.
        self.lin_3 = torch.nn.Linear(hidden_2, self.act_space_len)

        # Head producing the standard deviations (before the ELU + offset).
        self.lin_4 = torch.nn.Linear(hidden_2, self.act_space_len)
        self.elu = torch.nn.ELU()

    def forward(self, x):
        """
        Compute the Normal distribution parameters for observation(s) x.

        x - float tensor whose last dimension is obs_len (a single
            observation or a batch of them).
        Returns (mu, sigma), each with last dimension act_space_len;
        sigma is strictly positive.
        """
        x = self.rel_1(self.lin_1(x))
        x = self.rel_2(self.lin_2(x))

        mu = self.lin_3(x)

        # ELU(x) > -1, so ELU(x) + 1.000001 > 1e-6: sigma is always a valid,
        # strictly positive scale for torch.distributions.Normal.
        sigma = self.elu(self.lin_4(x)) + 1.000001

        return mu, sigma

    def act(self, observation):
        """
        Sample an action for the given gym state vector.

        observation - array-like (list, numpy array or torch tensor) whose
            last dimension is obs_len.
        Returns a numpy array of sampled actions with last dimension
        act_space_len.
        """
        device = self.lin_1.weight.device
        # as_tensor avoids the copy-construct warning torch.tensor() emits for
        # tensor inputs, and avoids a copy when no dtype/device change is needed.
        obs = torch.as_tensor(observation, dtype=torch.float32, device=device)

        # The action is returned as numpy and never back-propagated through,
        # so there is no point building an autograd graph here.
        with torch.no_grad():
            mus, sigmas = self.forward(obs)
            m = torch.distributions.normal.Normal(mus, sigmas)
            action = m.sample()

        # .cpu() lets the numpy conversion work when the model is on a GPU.
        return action.cpu().numpy()