Spaces:
Runtime error
Runtime error
import warnings | |
from collections import OrderedDict | |
from typing import Text | |
import numpy as np | |
from stable_baselines3.common.vec_env import VecCheckNan | |
class DictVecCheckNan(VecCheckNan):
    """``VecCheckNan`` variant whose value check also understands dict observations.

    Only ``_check_val`` is overridden here. Array (or list) values are scanned
    directly; dict values -- such as the ``OrderedDict`` produced for a Dict
    observation space, or a plain ``dict`` -- are scanned entry by entry, and
    each offending entry is reported as ``"<name>-<key>"``.

    NOTE(review): the attributes read below (``raise_exception``, ``warn_once``,
    ``check_inf``, ``_user_warned``, ``_actions``, ``_observations``) are not
    defined in this file; they are presumably maintained by ``VecCheckNan``.
    """

    def _check_val(self, event: str, **kwargs) -> None:
        """Check the given values for NaN (and optionally inf) entries.

        :param event: Which wrapper hook triggered the check: ``"reset"``,
            ``"step_wait"`` or ``"step_async"``. It only selects the text that
            describes where the invalid value originated.
        :param kwargs: Values to check, keyed by the name used in the report.
            Each value may be a NumPy array, a list (converted with
            ``np.asarray``), or a dict (``OrderedDict`` or plain ``dict``)
            whose entries are checked individually and reported as
            ``"<name>-<key>"``; nested dicts are walked recursively.
        :raises ValueError: If a top-level value has an unsupported type; if an
            invalid value was found and ``event`` is not a known hook; or if an
            invalid value was found and ``self.raise_exception`` is set.
            Otherwise a ``UserWarning`` is emitted when something was found.
        """
        # In warn-once mode (no exception) there is nothing left to report
        # after the first warning, so skip the scan entirely.
        if not self.raise_exception and self.warn_once and self._user_warned:
            return

        found = []  # (name, "inf" | "nan") pairs, in discovery order

        def check_array(check_name: Text, check_val) -> None:
            # asarray is a no-op for ndarrays and lets lists/scalars through too.
            values = np.asarray(check_val)
            has_nan = np.any(np.isnan(values))
            has_inf = self.check_inf and np.any(np.isinf(values))
            # For a single value, inf is reported before nan.
            if has_inf:
                found.append((check_name, "inf"))
            if has_nan:
                found.append((check_name, "nan"))

        def check_mapping(prefix: Text, mapping: dict) -> None:
            # Recurse into nested dicts so every leaf is checked and named
            # "<outer>-<inner>-...".
            for key, inner_val in mapping.items():
                inner_name = f"{prefix}-{key}"
                if isinstance(inner_val, dict):
                    check_mapping(inner_name, inner_val)
                else:
                    check_array(inner_name, inner_val)

        for name, val in kwargs.items():
            if isinstance(val, (np.ndarray, list)):
                check_array(name, val)
            elif isinstance(val, dict):  # OrderedDict is a dict subclass
                check_mapping(name, val)
            else:
                raise ValueError(f"Unsupported observation type {type(val)}.")

        if not found:
            return

        self._user_warned = True
        msg = ", ".join(f"found {type_val} in {where}" for where, type_val in found)
        msg += ".\r\nOriginated from the "
        if event == "reset":
            msg += "environment observation (at reset)"
        elif event == "step_wait":
            msg += f"environment, Last given value was: \r\n\taction={self._actions}"
        elif event == "step_async":
            msg += f"RL model, Last given value was: \r\n\tobservations={self._observations}"
        else:
            raise ValueError("Internal error.")

        if self.raise_exception:
            raise ValueError(msg)
        warnings.warn(msg, UserWarning)