sgoodfriend's picture
PPO playing QbertNoFrameskip-v4 from https://github.com/sgoodfriend/rl-algo-impls/tree/2067e21d62fff5db60168687e7d9e89019a8bfc0
5c87b65
raw
history blame
1.07 kB
import optuna
from gym.spaces import Box
from typing import Any, Dict
from rl_algo_impls.wrappers.vectorable_wrapper import (
VecEnv,
single_action_space,
)
def sample_on_policy_hyperparams(
    trial: optuna.Trial, policy_hparams: Dict[str, Any], env: VecEnv
) -> Dict[str, Any]:
    """Fill ``policy_hparams`` with policy settings sampled from an Optuna trial.

    Always samples ``init_layers_orthogonal`` and ``activation_fn``. When the
    env's single action space is a ``gym.spaces.Box``, also samples
    ``log_std_init`` (float in [-5, 0.5]) and ``use_sde``. If ``use_sde`` ends
    up truthy, whether sampled here or already present in ``policy_hparams``,
    ``squash_output`` is sampled too; otherwise any existing ``squash_output``
    entry is removed.

    Args:
        trial: Optuna trial used to draw each hyperparameter.
        policy_hparams: Dict of policy hyperparameters. It is modified in place.
        env: Vectorized env whose single action space selects which
            parameters are sampled.

    Returns:
        The same ``policy_hparams`` dict object, after modification.
    """
    action_space = single_action_space(env)

    # Suggestions made for every action-space type, in the order Optuna sees them.
    shared_choices = (
        ("init_layers_orthogonal", [True, False]),
        ("activation_fn", ["tanh", "relu"]),
    )
    for param_name, options in shared_choices:
        policy_hparams[param_name] = trial.suggest_categorical(param_name, options)

    if isinstance(action_space, Box):
        policy_hparams["log_std_init"] = trial.suggest_float("log_std_init", -5, 0.5)
        policy_hparams["use_sde"] = trial.suggest_categorical(
            "use_sde", [False, True]
        )

    # use_sde may come from the caller's dict even when the Box branch
    # above did not run, hence the .get() default.
    sde_enabled = policy_hparams.get("use_sde", False)
    if not sde_enabled:
        # squash_output is kept only when use_sde is truthy; drop any stale value.
        policy_hparams.pop("squash_output", None)
    else:
        policy_hparams["squash_output"] = trial.suggest_categorical(
            "squash_output", [False, True]
        )

    return policy_hparams