igorcheb committed
Commit 3207519
Parent: 920e52d

Create Agent_class.py

Files changed (1):
  1. Agent_class.py +49 -0
Agent_class.py ADDED
@@ -0,0 +1,49 @@
+ import torch
+
+ class ParameterisedPolicy(torch.nn.Module):
+     """
+     REINFORCE RL agent. Returns an action when ParameterisedPolicy.act(observation) is called.
+     observation is a gym state vector.
+     obs_len - length of the state vector
+     act_space_len - length of the action vector
+     """
+     def __init__(self, obs_len=8, act_space_len=2):
+         super().__init__()
+         self.deterministic = False
+         self.continuous = True
+         self.obs_len = obs_len
+         self.act_space_len = act_space_len
+         # Shared trunk: two fully connected layers with ReLU activations.
+         self.lin_1 = torch.nn.Linear(self.obs_len, 256)
+         self.rel_1 = torch.nn.ReLU()
+
+         self.lin_2 = torch.nn.Linear(256, 128)
+         self.rel_2 = torch.nn.ReLU()
+
+         # Two output heads: lin_3 produces the mean of the action
+         # distribution, lin_4 (through ELU) its standard deviation.
+         self.lin_3 = torch.nn.Linear(128, self.act_space_len)
+
+         self.lin_4 = torch.nn.Linear(128, self.act_space_len)
+         self.elu = torch.nn.ELU()
+
+     def forward(self, x):
+         x = self.lin_1(x)
+         x = self.rel_1(x)
+
+         x = self.lin_2(x)
+         x = self.rel_2(x)
+
+         mu = self.lin_3(x)
+
+         x = self.lin_4(x)
+         # ELU is bounded below by -1, so the offset keeps sigma strictly positive.
+         sigma = self.elu(x) + 1.000001
+
+         return mu, sigma
+
+     def act(self, observation):
+         # Sample an action from the Gaussian policy. The sample is detached,
+         # so this method is for rollouts only, not for gradient computation.
+         mus, sigmas = self.forward(torch.tensor(observation, dtype=torch.float32))
+         m = torch.distributions.normal.Normal(mus, sigmas)
+         action = m.sample().detach().numpy()
+
+         return action
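
For context, here is a minimal, untested sketch of how this policy could be driven and trained with REINFORCE. Everything below is an assumption rather than part of the commit: the environment (LunarLanderContinuous-v2, chosen to match the obs_len=8 / act_space_len=2 defaults), the classic gym API (reset() returning only the observation, step() returning a 4-tuple), and the hyperparameters. Because act() detaches its sample, the loop calls forward() directly so the log-probabilities stay differentiable.

import gym
import torch

from Agent_class import ParameterisedPolicy

# Hypothetical environment, matching the network's default input/output sizes.
env = gym.make("LunarLanderContinuous-v2")
agent = ParameterisedPolicy(obs_len=env.observation_space.shape[0],
                            act_space_len=env.action_space.shape[0])
optimizer = torch.optim.Adam(agent.parameters(), lr=1e-3)  # assumed learning rate

# Collect one episode, keeping differentiable log-probabilities.
log_probs, rewards = [], []
obs, done = env.reset(), False
while not done:
    mu, sigma = agent(torch.tensor(obs, dtype=torch.float32))
    dist = torch.distributions.normal.Normal(mu, sigma)
    action = dist.sample()
    log_probs.append(dist.log_prob(action).sum())  # sum over action dimensions
    obs, reward, done, _ = env.step(action.numpy())  # classic 4-tuple gym API
    rewards.append(reward)

# Discounted returns, accumulated backwards through the episode.
gamma, G, returns = 0.99, 0.0, []
for r in reversed(rewards):
    G = r + gamma * G
    returns.insert(0, G)

# REINFORCE policy-gradient step: minimise -log_prob * return.
loss = -(torch.stack(log_probs) * torch.tensor(returns)).sum()
optimizer.zero_grad()
loss.backward()
optimizer.step()

Plain evaluation rollouts can use agent.act(obs) directly, since the detached numpy sample it returns is exactly what env.step() expects.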