File size: 3,103 Bytes
76a55af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
from dataclasses import astuple
from typing import Optional

import gym
import numpy as np
from torch.utils.tensorboard.writer import SummaryWriter

from rl_algo_impls.runner.config import Config, EnvHyperparams
from rl_algo_impls.wrappers.action_mask_wrapper import MicrortsMaskWrapper
from rl_algo_impls.wrappers.episode_stats_writer import EpisodeStatsWriter
from rl_algo_impls.wrappers.hwc_to_chw_observation import HwcToChwObservation
from rl_algo_impls.wrappers.is_vector_env import IsVectorEnv
from rl_algo_impls.wrappers.microrts_stats_recorder import MicrortsStatsRecorder
from rl_algo_impls.wrappers.vectorable_wrapper import VecEnv


def make_microrts_env(
    config: Config,
    hparams: EnvHyperparams,
    training: bool = True,
    render: bool = False,
    normalize_load_path: Optional[str] = None,
    tb_writer: Optional[SummaryWriter] = None,
) -> VecEnv:
    """Build a wrapped MicroRTS grid-mode vector environment.

    Args:
        config: Run configuration; supplies the seed (via ``config.seed``),
            ``algo_hyperparams`` (for ``gamma``, default 0.99) and
            ``additional_keys_to_log``.
        hparams: Environment hyperparameters. Only the fields unpacked below
            as ``n_envs``, ``make_kwargs``, ``rolling_length`` and ``bots``
            are used here.
        training: Selects the training seed and, when True, additionally
            wraps the env in an ``EpisodeStatsWriter``.
        render: Not used by this function.
        normalize_load_path: Not used by this function.
        tb_writer: TensorBoard writer; must be provided when ``training``.

    Returns:
        ``MicroRTSGridModeVecEnvCompat`` wrapped (innermost first) by
        ``HwcToChwObservation``, ``IsVectorEnv``, ``MicrortsMaskWrapper``,
        ``gym.wrappers.RecordEpisodeStatistics``, ``MicrortsStatsRecorder``
        and, when training, ``EpisodeStatsWriter``.

    Raises:
        ValueError: If a bot name in ``bots`` is not an attribute of
            ``gym_microrts.microrts_ai``.
        AssertionError: If ``training`` is True and ``tb_writer`` is falsy.
    """
    # NOTE(review): gym_microrts itself is not referenced below; presumably
    # imported so a missing package fails fast here. Confirm before removing.
    import gym_microrts
    from gym_microrts import microrts_ai

    from rl_algo_impls.shared.vec_env.microrts_compat import (
        MicroRTSGridModeVecEnvCompat,
    )

    # Positional unpacking: this must stay in sync with the field order of
    # EnvHyperparams. astuple also deep-copies dict fields, so mutating
    # make_kwargs below does not modify hparams.
    (
        _,  # env_type
        n_envs,
        _,  # frame_stack
        make_kwargs,
        _,  # no_reward_timeout_steps
        _,  # no_reward_fire_steps
        _,  # vec_env_class
        _,  # normalize
        _,  # normalize_kwargs,
        rolling_length,
        _,  # train_record_video
        _,  # video_step_interval
        _,  # initial_steps_to_truncate
        _,  # clip_atari_rewards
        _,  # normalize_type
        _,  # mask_actions
        bots,
    ) = astuple(hparams)

    seed = config.seed(training=training)

    # Defaults: no self-play envs; every remaining env plays against a bot.
    make_kwargs = make_kwargs or {}
    make_kwargs.setdefault("num_selfplay_envs", 0)
    make_kwargs.setdefault(
        "num_bot_envs", n_envs - make_kwargs["num_selfplay_envs"]
    )
    if "reward_weight" in make_kwargs:
        make_kwargs["reward_weight"] = np.array(make_kwargs["reward_weight"])
    make_kwargs["ai2s"] = _select_bot_ais(
        microrts_ai, bots, make_kwargs["num_bot_envs"]
    )

    envs = MicroRTSGridModeVecEnvCompat(**make_kwargs)
    envs = HwcToChwObservation(envs)
    envs = IsVectorEnv(envs)
    envs = MicrortsMaskWrapper(envs)

    if seed is not None:
        envs.action_space.seed(seed)
        envs.observation_space.seed(seed)

    envs = gym.wrappers.RecordEpisodeStatistics(envs)
    envs = MicrortsStatsRecorder(envs, config.algo_hyperparams.get("gamma", 0.99))
    if training:
        assert tb_writer
        envs = EpisodeStatsWriter(
            envs,
            tb_writer,
            training=training,
            rolling_length=rolling_length,
            additional_keys_to_log=config.additional_keys_to_log,
        )

    return envs


def _select_bot_ais(microrts_ai, bots, num_bot_envs):
    """Return the opponent AI for each bot env, at most ``num_bot_envs`` long.

    With no ``bots`` mapping (None or empty), every bot env plays
    ``microrts_ai.randomAI``. Otherwise ``bots`` maps an attribute name of
    ``microrts_ai`` to the number of envs that should use it. Entries are
    consumed in iteration order and the result is capped at
    ``num_bot_envs``. A name is only resolved if its entry contributes at
    least one env.

    NOTE(review): if the counts in ``bots`` sum to less than
    ``num_bot_envs``, the returned list is shorter than ``num_bot_envs``.
    Confirm the env constructor tolerates that.

    Raises:
        ValueError: If a contributing name is not an attribute of
            ``microrts_ai``.
    """
    if not bots:
        return [microrts_ai.randomAI for _ in range(num_bot_envs)]

    ai2s = []
    for ai_name, count in bots.items():
        remaining = num_bot_envs - len(ai2s)
        if remaining <= 0:
            # Every bot env already has an opponent; later entries are ignored.
            break
        take = min(count, remaining)
        if take <= 0:
            continue
        ai = getattr(microrts_ai, ai_name, None)
        if ai is None:
            raise ValueError(f"{ai_name} not in microrts_ai")
        ai2s.extend([ai] * take)
    return ai2s