File size: 2,056 Bytes
a257639
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import warnings
from collections import OrderedDict
from typing import Text

import numpy as np
from stable_baselines3.common.vec_env import VecCheckNan


class DictVecCheckNan(VecCheckNan):
    """``VecCheckNan`` variant that also inspects dict/tuple/list values.

    Overrides ``VecCheckNan._check_val`` so that, besides plain ``np.ndarray``
    values, it descends into mappings (any ``dict``, including the
    ``OrderedDict`` used for dict observation spaces), tuples and lists, and
    checks every leaf for NaN (and, when ``self.check_inf`` is set, inf).
    Leaves are reported as ``"<name>-<key>"`` for mapping entries and
    ``"<name>-<index>"`` for tuple entries.
    """

    def _check_val(self, event: str, **kwargs) -> None:
        """Check every keyword value for NaN/inf and warn or raise if found.

        :param event: which wrapper call triggered the check; ``"reset"``,
            ``"step_wait"`` or ``"step_async"``. Only used to build the
            diagnostic message.
        :param kwargs: named values to inspect. Each top-level value must be an
            ``np.ndarray``, a list, a tuple or a ``dict``; containers are
            inspected recursively.
        :raises ValueError: if a top-level value has an unsupported type; if a
            problem was found and ``event`` is not one of the known names; or
            if a problem was found and ``self.raise_exception`` is set.
            Otherwise a ``UserWarning`` is emitted.
        """
        # if warn and warn once and have warned once: then stop checking
        if not self.raise_exception and self.warn_once and self._user_warned:
            return

        found = []
        for name, val in kwargs.items():
            self._collect_non_finite(name, val, found)

        if not found:
            return

        self._user_warned = True
        msg = ", ".join(f"found {type_val} in {name}" for name, type_val in found)
        msg += ".\r\nOriginated from the "

        if event == "reset":
            msg += "environment observation (at reset)"
        elif event == "step_wait":
            msg += f"environment, Last given value was: \r\n\taction={self._actions}"
        elif event == "step_async":
            msg += f"RL model, Last given value was: \r\n\tobservations={self._observations}"
        else:
            raise ValueError("Internal error.")

        if self.raise_exception:
            raise ValueError(msg)
        warnings.warn(msg, UserWarning)

    def _collect_non_finite(self, name: str, val: object, found: list, nested: bool = False) -> None:
        """Append ``(name, "inf")`` / ``(name, "nan")`` pairs for ``val`` to ``found``.

        :param name: label used in the report; extended with ``-<key>`` or
            ``-<index>`` when descending into containers.
        :param val: value to inspect.
        :param found: accumulator, mutated in place.
        :param nested: ``True`` when ``val`` sits inside a container. Nested
            leaves of any other type are handed to numpy as-is (which accepts
            numeric scalars and sequences); only top-level values are
            restricted to the supported container/array types.
        :raises ValueError: for an unsupported top-level type.
        """
        if isinstance(val, dict):
            for inner_name, inner_val in val.items():
                self._collect_non_finite(f"{name}-{inner_name}", inner_val, found, nested=True)
        elif isinstance(val, tuple):
            for idx, inner_val in enumerate(val):
                self._collect_non_finite(f"{name}-{idx}", inner_val, found, nested=True)
        elif nested or isinstance(val, (np.ndarray, list)):
            arr = np.asarray(val)
            has_nan = np.any(np.isnan(arr))
            has_inf = self.check_inf and np.any(np.isinf(arr))
            # Report inf before nan, matching the original message order.
            if has_inf:
                found.append((name, "inf"))
            if has_nan:
                found.append((name, "nan"))
        else:
            raise ValueError(f"Unsupported observation type {type(val)}.")