randommm commited on
Commit
6e701f9
·
1 Parent(s): 3e23aec
facility_location/env/obs_extractor.py CHANGED
@@ -117,20 +117,21 @@ class ObsExtractor:
117
  return obs
118
 
119
  def _get_obs_graph(self) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
 
120
  facility = self._flc.get_current_solution().astype(np.float32)
121
  distance = self._flc.get_current_distance().astype(np.float32)
122
  distance = distance / np.max(distance)
123
  cost = self._flc.get_current_cost().astype(np.float32)
124
  cost = cost / np.max(cost)
125
  gain, loss = self._flc.get_gain_and_loss()
126
- gain = gain / np.max(gain)
127
- loss = loss / np.max(loss)
128
  dynamic_node_features = np.stack([
129
  facility,
130
  distance[:,0],
131
  distance[:,1],
132
  cost[:,0],
133
- cost[:,1],
134
  gain,
135
  loss,
136
  ], axis=-1)
 
117
  return obs
118
 
119
  def _get_obs_graph(self) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
120
+ EPS = 1e-8
121
  facility = self._flc.get_current_solution().astype(np.float32)
122
  distance = self._flc.get_current_distance().astype(np.float32)
123
  distance = distance / np.max(distance)
124
  cost = self._flc.get_current_cost().astype(np.float32)
125
  cost = cost / np.max(cost)
126
  gain, loss = self._flc.get_gain_and_loss()
127
+ gain = gain / (np.max(gain) + EPS)
128
+ loss = loss / (np.max(loss) + EPS)
129
  dynamic_node_features = np.stack([
130
  facility,
131
  distance[:,0],
132
  distance[:,1],
133
  cost[:,0],
134
+ cost[:,1],
135
  gain,
136
  loss,
137
  ], axis=-1)