File size: 1,274 Bytes
460072a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
from typing import Any, Dict, List

import numpy as np

from rl_algo_impls.wrappers.vectorable_wrapper import (
    VecEnvObs,
    VecEnvStepReturn,
    VecotarableWrapper,
)


class MicrortsStatsRecorder(VecotarableWrapper):
    """Vectorized-env wrapper that reports per-episode raw reward totals.

    On every step the ``info["raw_rewards"]`` entry of each sub-environment
    is buffered. When a sub-environment reports ``done``, its buffered
    entries are summed along axis 0 and written into that info dict under
    ``"microrts_stats"``, keyed by ``str(rf)`` for each ``rf`` in
    ``self.env.unwrapped.rfs``. The buffer for that sub-environment is then
    emptied.
    """

    def __init__(self, env, gamma: float) -> None:
        """Wrap *env* and create one empty reward buffer per sub-environment.

        Args:
            env: The environment being wrapped.
            gamma: Stored on the instance as ``self.gamma``; this class does
                not read it anywhere else.
        """
        super().__init__(env)
        self.gamma = gamma
        # One list of per-step raw reward entries for each sub-environment.
        self.raw_rewards = [list() for _ in range(self.num_envs)]

    def reset(self) -> VecEnvObs:
        """Reset through the parent wrapper and discard buffered rewards.

        Returns:
            Whatever the parent ``reset()`` returned.
        """
        initial_obs = super().reset()
        self.raw_rewards = [list() for _ in range(self.num_envs)]
        return initial_obs

    def step(self, actions: np.ndarray) -> VecEnvStepReturn:
        """Step the wrapped env and annotate ``infos`` for finished episodes.

        Args:
            actions: Passed through unchanged to ``self.env.step``.

        Returns:
            The ``(obs, rewards, dones, infos)`` tuple from the wrapped env;
            ``infos`` may have been mutated in place by ``_update_infos``.
        """
        observations, rewards, dones, infos = self.env.step(actions)
        self._update_infos(infos, dones)
        return observations, rewards, dones, infos

    def _update_infos(self, infos: List[Dict[str, Any]], dones: np.ndarray) -> None:
        """Buffer this step's raw rewards and publish totals for done envs.

        Args:
            infos: One info dict per sub-environment; each must contain a
                ``"raw_rewards"`` entry. Dicts of finished envs gain a
                ``"microrts_stats"`` key (mutated in place).
            dones: Per-sub-environment episode-finished flags.
        """
        # Phase 1: record the current step for every sub-environment before
        # any totals are computed, so the final step is part of the sum.
        for env_idx, info in enumerate(infos):
            buffer = self.raw_rewards[env_idx]
            buffer.append(info["raw_rewards"])

        # Phase 2: emit totals for sub-environments whose episode just ended.
        for env_idx, (info, done) in enumerate(zip(infos, dones)):
            if not done:
                continue
            # Sum over the step axis; presumably this leaves one total per
            # reward function -- the entry layout is not visible here.
            episode_totals = np.array(self.raw_rewards[env_idx]).sum(axis=0)
            reward_names = list(map(str, self.env.unwrapped.rfs))
            info["microrts_stats"] = {
                name: total for name, total in zip(reward_names, episode_totals)
            }
            # Start a fresh buffer for the next episode of this sub-env.
            self.raw_rewards[env_idx] = []