import torch


class ParameterisedPolicy(torch.nn.Module):
    """REINFORCE RL agent with a Gaussian policy head.

    A two-layer MLP trunk (obs_len -> 256 -> 128, ReLU activations) feeds two
    linear heads: one producing the per-dimension action mean ``mu`` and one
    producing the standard deviation ``sigma``. Calling
    ``ParameterisedPolicy.act(observation)`` samples an action from
    ``Normal(mu, sigma)``.

    Args:
        obs_len: length of the state vector (observation is a gym state
            vector).
        act_space_len: length of the action vector.
    """

    def __init__(self, obs_len=8, act_space_len=2):
        super().__init__()
        self.obs_len = obs_len
        self.act_space_len = act_space_len
        # Shared trunk.
        self.lin_1 = torch.nn.Linear(self.obs_len, 256)
        self.rel_1 = torch.nn.ReLU()
        self.lin_2 = torch.nn.Linear(256, 128)
        self.rel_2 = torch.nn.ReLU()
        # lin_3: mean head (unbounded). lin_4: pre-activation of the std head.
        self.lin_3 = torch.nn.Linear(128, self.act_space_len)
        self.lin_4 = torch.nn.Linear(128, self.act_space_len)
        self.elu = torch.nn.ELU()

    def forward(self, x):
        """Return ``(mu, sigma)`` of the Gaussian policy for observation(s) ``x``.

        Args:
            x: float tensor whose last dimension is ``obs_len``; leading
                (batch) dimensions are carried through by ``nn.Linear``.

        Returns:
            Tuple ``(mu, sigma)`` of tensors whose last dimension is
            ``act_space_len``.
        """
        x = self.rel_1(self.lin_1(x))
        x = self.rel_2(self.lin_2(x))
        mu = self.lin_3(x)
        # ELU(.) > -1, so adding 1.000001 keeps sigma strictly positive
        # (> 1e-6), which torch.distributions.Normal requires.
        sigma = self.elu(self.lin_4(x)) + 1.000001
        return mu, sigma

    def act(self, observation):
        """Sample an action for the given gym state vector.

        Args:
            observation: anything ``torch.as_tensor`` accepts (list, numpy
                array, or tensor) whose last dimension is ``obs_len``.

        Returns:
            numpy.ndarray (float32) sampled from ``Normal(mu, sigma)``, with
            last dimension ``act_space_len``.
        """
        # Build the input on the same device as the weights so act() keeps
        # working after model.to(<gpu>); as_tensor also avoids the copy and
        # UserWarning that torch.tensor() produces for tensor inputs.
        device = self.lin_1.weight.device
        obs = torch.as_tensor(observation, dtype=torch.float32, device=device)
        # Only a sample is returned, so no autograd graph is needed here.
        with torch.no_grad():
            mus, sigmas = self.forward(obs)
            action = torch.distributions.Normal(mus, sigmas).sample()
        # .cpu() is required before .numpy() when the model lives on a GPU.
        return action.cpu().numpy()