Spaces:
Runtime error
Runtime error
File size: 2,056 Bytes
a257639 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 |
import warnings
from collections import OrderedDict
from typing import Text
import numpy as np
from stable_baselines3.common.vec_env import VecCheckNan
class DictVecCheckNan(VecCheckNan):
    """``VecCheckNan`` variant that can inspect dict/tuple/list values.

    Overrides ``_check_val`` so that mapping values (``dict``, including
    ``OrderedDict``) and tuple values are flattened, and every leaf array is
    checked for NaN (and, when ``self.check_inf`` is set, inf). Leaves are
    reported as ``"<name>-<key>"`` (mappings) or ``"<name>-<index>"`` (tuples).

    NOTE(review): the attributes read here (``raise_exception``, ``warn_once``,
    ``check_inf``, ``_user_warned``, ``_actions``, ``_observations``) are
    assumed to be initialised by the base ``VecCheckNan`` -- confirm against
    the installed stable-baselines3 version.
    """

    def _array_issues(self, name: str, arr) -> list:
        """Return ``(name, "inf")`` / ``(name, "nan")`` pairs found in ``arr``.

        ``inf`` is reported before ``nan`` to keep the original message order.
        """
        issues = []
        if self.check_inf and np.any(np.isinf(arr)):
            issues.append((name, "inf"))
        if np.any(np.isnan(arr)):
            issues.append((name, "nan"))
        return issues

    def _collect_issues(self, name: str, val) -> list:
        """Flatten ``val`` into leaf arrays and gather their NaN/inf issues.

        Raises:
            ValueError: If ``val`` is not an ndarray, list, dict or tuple.
        """
        if isinstance(val, np.ndarray):
            return self._array_issues(name, val)
        if isinstance(val, list):
            return self._array_issues(name, np.asarray(val))
        # ``dict`` also covers ``OrderedDict`` (a subclass), which is what the
        # original implementation accepted exclusively.
        if isinstance(val, dict):
            return [
                issue
                for key, inner in val.items()
                for issue in self._array_issues(f"{name}-{key}", inner)
            ]
        if isinstance(val, tuple):
            return [
                issue
                for idx, inner in enumerate(val)
                for issue in self._array_issues(f"{name}-{idx}", inner)
            ]
        raise ValueError(f"Unsupported observation type {type(val)}.")

    def _check_val(self, event: str, **kwargs) -> None:
        """Warn about, or raise on, NaN/inf values found in ``kwargs``.

        Args:
            event: Which call triggered the check: ``"reset"``,
                ``"step_wait"`` or ``"step_async"``. It selects the origin
                hint appended to the message.
            **kwargs: Values to inspect, keyed by a descriptive name. Each
                value may be an ``np.ndarray``, a list (converted with
                ``np.asarray``), a ``dict`` of array-likes, or a tuple of
                array-likes.

        Raises:
            ValueError: If a value has an unsupported type; if invalid values
                were found and ``event`` is unrecognised; or if invalid values
                were found and ``self.raise_exception`` is set.
        """
        # Warn-only mode with warn_once: once a warning has been emitted,
        # skip the scan entirely.
        if not self.raise_exception and self.warn_once and self._user_warned:
            return

        found = []
        for name, val in kwargs.items():
            found.extend(self._collect_issues(name, val))

        if not found:
            return

        self._user_warned = True
        msg = ", ".join(f"found {type_val} in {where}" for where, type_val in found)
        msg += ".\r\nOriginated from the "
        if event == "reset":
            msg += "environment observation (at reset)"
        elif event == "step_wait":
            # The environment produced the bad value; show the last action sent to it.
            msg += f"environment, Last given value was: \r\n\taction={self._actions}"
        elif event == "step_async":
            # The model produced the bad value; show the last observation it saw.
            msg += f"RL model, Last given value was: \r\n\tobservations={self._observations}"
        else:
            raise ValueError("Internal error.")

        if self.raise_exception:
            raise ValueError(msg)
        warnings.warn(msg, UserWarning)
|