igorcheb committed on
Commit
38a4bc3
1 Parent(s): 71e1160

Update Agent_class.py

Browse files
Files changed (1) hide show
  1. Agent_class.py +4 -4
Agent_class.py CHANGED
@@ -2,7 +2,7 @@ import torch
2
 
3
  class ParameterisedPolicy(torch.nn.Module):
4
  """
5
- REINFORCE RL agent class. Returns action when the ParameterisedPolicy.act(observation) is used.
6
  observation is a gym state vector.
7
  obs_len - length of the state vector
8
  act_space_len - length of the action vector
@@ -10,8 +10,6 @@ class ParameterisedPolicy(torch.nn.Module):
10
  """
11
  def __init__(self, obs_len=8, act_space_len=2):
12
  super().__init__()
13
- self.deterministic = False
14
- self.continuous = True
15
  self.obs_len = obs_len
16
  self.act_space_len = act_space_len
17
  self.lin_1 = torch.nn.Linear(self.obs_len, 256)
@@ -41,7 +39,9 @@ class ParameterisedPolicy(torch.nn.Module):
41
  return mu, sigma
42
 
43
  def act(self, observation):
44
-
 
 
45
  (mus, sigmas) = self.forward(torch.tensor(observation, dtype=torch.float32))
46
  m = torch.distributions.normal.Normal(mus, sigmas)
47
  action = m.sample().detach().numpy()
 
2
 
3
  class ParameterisedPolicy(torch.nn.Module):
4
  """
5
+ REINFORCE RL agent class. Returns action when the ParameterisedPolicy.act(observation) method is used.
6
  observation is a gym state vector.
7
  obs_len - length of the state vector
8
  act_space_len - length of the action vector
 
10
  """
11
  def __init__(self, obs_len=8, act_space_len=2):
12
  super().__init__()
 
 
13
  self.obs_len = obs_len
14
  self.act_space_len = act_space_len
15
  self.lin_1 = torch.nn.Linear(self.obs_len, 256)
 
39
  return mu, sigma
40
 
41
  def act(self, observation):
42
+ """
43
+ Method returns action when gym state vector is passed.
44
+ """
45
  (mus, sigmas) = self.forward(torch.tensor(observation, dtype=torch.float32))
46
  m = torch.distributions.normal.Normal(mus, sigmas)
47
  action = m.sample().detach().numpy()