import warnings
from collections import OrderedDict
from typing import Text

import numpy as np
from stable_baselines3.common.vec_env import VecCheckNan


class DictVecCheckNan(VecCheckNan):
    """``VecCheckNan`` variant that also inspects container-valued inputs.

    ``_check_val`` accepts ``np.ndarray`` values directly and walks ``dict``
    (including ``OrderedDict``) and ``tuple`` values recursively, checking every
    leaf for NaN and, when ``self.check_inf`` is set, for +/-inf.
    """

    def _scan_array(self, name: Text, value) -> list:
        """Return ``(name, kind)`` problems found in one array-like *value*.

        *kind* is ``"inf"`` (only checked when ``self.check_inf`` is truthy)
        and/or ``"nan"``. inf is listed before nan so messages keep the same
        ordering as before.
        """
        problems = []
        if self.check_inf and np.any(np.isinf(value)):
            problems.append((name, "inf"))
        if np.any(np.isnan(value)):
            problems.append((name, "nan"))
        return problems

    def _scan_nested(self, name: Text, value) -> list:
        """Recursively scan dict/tuple containers; any other value is scanned as an array.

        Inner entries are reported as ``"<outer>-<key>"`` for dicts and
        ``"<outer>-<index>"`` for tuples.
        """
        if isinstance(value, dict):
            return [
                problem
                for key, inner in value.items()
                for problem in self._scan_nested(f"{name}-{key}", inner)
            ]
        if isinstance(value, tuple):
            return [
                problem
                for index, inner in enumerate(value)
                for problem in self._scan_nested(f"{name}-{index}", inner)
            ]
        # Leaf: np.isnan / np.isinf also accept lists and scalars, as before.
        return self._scan_array(name, value)

    def _check_val(self, event: str, **kwargs) -> None:
        """Check every keyword value for NaN/inf and warn or raise if any is found.

        Args:
            event: which hook triggered the check; one of ``"reset"``,
                ``"step_wait"`` or ``"step_async"``. It selects the
                "Originated from" part of the message.
            **kwargs: named values to check. Each must be an ``np.ndarray``,
                a ``dict`` (e.g. an ``OrderedDict`` of arrays) or a ``tuple``.

        Raises:
            ValueError: if a top-level value has an unsupported type, if
                *event* is unknown while problems were found, or if problems
                were found and ``self.raise_exception`` is set.
        """
        # If we only warn, only warn once, and already warned: stop checking.
        if not self.raise_exception and self.warn_once and self._user_warned:
            return

        found = []
        for name, val in kwargs.items():
            if isinstance(val, np.ndarray):
                found.extend(self._scan_array(name, val))
            elif isinstance(val, (dict, tuple)):
                # dict covers OrderedDict (a subclass) as well as plain dicts.
                found.extend(self._scan_nested(name, val))
            else:
                raise ValueError(f"Unsupported observation type {type(val)}.")

        if not found:
            return

        self._user_warned = True
        msg = ", ".join(f"found {kind} in {name}" for name, kind in found)
        msg += ".\r\nOriginated from the "

        if event == "reset":
            msg += "environment observation (at reset)"
        elif event == "step_wait":
            msg += f"environment, Last given value was: \r\n\taction={self._actions}"
        elif event == "step_async":
            msg += f"RL model, Last given value was: \r\n\tobservations={self._observations}"
        else:
            raise ValueError("Internal error.")

        if self.raise_exception:
            raise ValueError(msg)
        warnings.warn(msg, UserWarning)