from functools import partial
from typing import Callable, Tuple, Text, Union
from collections import OrderedDict

import numpy as np
from gym import spaces
import torch as th
from torch import nn

from stable_baselines3.common.policies import ActorCriticPolicy
from stable_baselines3.common.utils import get_device
from stable_baselines3.common.type_aliases import Schedule


def create_mlp(head: Text, input_dim: int, hidden_units: Tuple) -> nn.Sequential:
    """Build a named MLP with Tanh activations between hidden layers.

    Policy heads (names starting with 'policy') end with a Flatten so that
    per-node scalar outputs collapse into a single logits vector per sample.
    """
    layers = OrderedDict()
    for i, units in enumerate(hidden_units):
        if i == 0:
            layers[f'{head}_linear_{i}'] = nn.Linear(input_dim, units)
        else:
            layers[f'{head}_linear_{i}'] = nn.Linear(hidden_units[i - 1], units)
        if i != len(hidden_units) - 1:
            layers[f'{head}_tanh_{i}'] = nn.Tanh()
    if head.startswith('policy'):
        layers[f'{head}_flatten'] = nn.Flatten()
    return nn.Sequential(layers)
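
# Illustrative sketch (not from the original source): for a hypothetical call
# create_mlp('policy-pair-facility', 4, (32, 32, 1)), the resulting Sequential is
#   Linear(4, 32) -> Tanh -> Linear(32, 32) -> Tanh -> Linear(32, 1) -> Flatten
# so an input of shape (batch, n, 4) is reduced to one scalar per node and then
# flattened to a (batch, n) logits vector.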


class MaskedFacilityLocationNetwork(nn.Module):
    def __init__(
        self,
        policy_feature_dim: int,
        value_feature_dim: int,
        policy_hidden_units: Tuple = (32, 32, 1),
        value_hidden_units: Tuple = (32, 32, 1),
        device: Union[th.device, Text] = "auto",
    ):
        super().__init__()
        device = get_device(device)

        # Policy network
        # self.old_facility_policy_net = create_mlp('policy-old-facility',
        #                                           policy_feature_dim,
        #                                           policy_hidden_units).to(device)
        # self.new_facility_policy_net = create_mlp('policy-new-facility',
        #                                           policy_feature_dim,
        #                                           policy_hidden_units).to(device)
        self.pair_facility_policy_net = create_mlp('policy-pair-facility',
                                                   policy_feature_dim,
                                                   policy_hidden_units).to(device)

        # Value network
        self.value_net = create_mlp('value',
                                    value_feature_dim,
                                    value_hidden_units).to(device)

    def forward(self,
                features: Tuple[th.Tensor, th.Tensor, th.Tensor]) -> Tuple[th.Tensor, th.Tensor]:
        return self.forward_actor(features), self.forward_critic(features)

    # def forward_actor(self, features: Tuple[th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor]) -> th.Tensor:
    #     state_policy_old_facility, state_policy_new_facility, _, old_facility_mask, new_facility_mask = features
    #     old_facility_logits = self.old_facility_policy_net(state_policy_old_facility)  # (batch_size, node_range)
    #     old_facility_padding = th.full_like(old_facility_mask, -th.inf, dtype=th.float32)
    #     masked_old_facility_logits = th.where(old_facility_mask, old_facility_logits, old_facility_padding)
    #     new_facility_logits = self.new_facility_policy_net(state_policy_new_facility)  # (batch_size, node_range)
    #     new_facility_padding = th.full_like(new_facility_mask, -th.inf, dtype=th.float32)
    #     masked_new_facility_logits = th.where(new_facility_mask, new_facility_logits, new_facility_padding)
    #     masked_old_new_facility_logits = th.cat([masked_old_facility_logits, masked_new_facility_logits], dim=1)
    #     return masked_old_new_facility_logits

    def forward_actor(self, features: Tuple[th.Tensor, th.Tensor, th.Tensor]) -> th.Tensor:
        state_policy_pair_facility, _, dynamic_edge_mask = features
        pair_facility_logits = self.pair_facility_policy_net(state_policy_pair_facility)
        # Mask out invalid facility pairs with -inf so they receive zero probability.
        pair_facility_padding = th.full_like(dynamic_edge_mask, -th.inf, dtype=th.float32)
        masked_pair_facility_logits = th.where(dynamic_edge_mask, pair_facility_logits, pair_facility_padding)
        return masked_pair_facility_logits

    def forward_critic(self, features: Tuple[th.Tensor, th.Tensor, th.Tensor]) -> th.Tensor:
        _, state_value, _ = features
        return self.value_net(state_value)
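
# Shape sketch for MaskedFacilityLocationNetwork (an assumption, not stated in the original):
# with batch size B and E candidate facility pairs, the feature tuple is expected to be
#   state_policy_pair_facility: (B, E, policy_feature_dim)
#   state_value:                (B, value_feature_dim)
#   dynamic_edge_mask:          (B, E), boolean
# so forward() returns masked pair logits of shape (B, E) and a value of shape (B, 1).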


class POPSTARMaskedFacilityLocationNetwork(nn.Module):
    def __init__(
        self,
        policy_feature_dim: int,
        value_feature_dim: int,
        policy_hidden_units: Tuple = (32, 32, 1),
        value_hidden_units: Tuple = (32, 32, 1),
        device: Union[th.device, Text] = "auto",
    ):
        super().__init__()
        device = get_device(device)

        # Policy network
        self.old_facility_policy_net = create_mlp('policy-old-facility',
                                                  policy_feature_dim,
                                                  policy_hidden_units).to(device)
        self.new_facility_policy_net = create_mlp('policy-new-facility',
                                                  policy_feature_dim,
                                                  policy_hidden_units).to(device)
        self.old_new_facility_policy_net = create_mlp('policy-old-new-facility',
                                                      policy_feature_dim * 4,
                                                      policy_hidden_units).to(device)

        # Value network
        self.value_net = create_mlp('value',
                                    value_feature_dim,
                                    value_hidden_units).to(device)

    def forward(self,
                features: Tuple[th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor]) -> Tuple[th.Tensor, th.Tensor]:
        return self.forward_actor(features), self.forward_critic(features)

    def forward_actor(self, features: Tuple[th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor]) -> th.Tensor:
        state_policy_old_facility, state_policy_new_facility, _, old_facility_mask, new_facility_mask = features
        node_range = old_facility_mask.shape[1]
        # Per-node scores: "loss" for closing an old facility, "gain" for opening a new one.
        loss = self.old_facility_policy_net(state_policy_old_facility)  # (batch_size, node_range)
        loss = loss.repeat_interleave(node_range, dim=1)
        gain = self.new_facility_policy_net(state_policy_new_facility)  # (batch_size, node_range)
        gain = gain.repeat(1, node_range)
        # Pairwise features for every (old, new) facility combination.
        state_policy_old_facility_expand = state_policy_old_facility.unsqueeze(2).expand(-1, -1, node_range, -1)
        state_policy_new_facility_expand = state_policy_new_facility.unsqueeze(1).expand(-1, node_range, -1, -1)
        state_policy_old_new_facility = th.cat(
            [
                state_policy_old_facility_expand,
                state_policy_new_facility_expand,
                state_policy_old_facility_expand - state_policy_new_facility_expand,
                state_policy_old_facility_expand * state_policy_new_facility_expand
            ], dim=-1
        )
        extra = self.old_new_facility_policy_net(state_policy_old_new_facility)  # (batch_size, node_range * node_range)
        logits = gain - loss + extra
        # A pair is valid only if both the old-facility and the new-facility choice are valid.
        action_mask = th.logical_and(old_facility_mask.unsqueeze(2), new_facility_mask.unsqueeze(1)).flatten(start_dim=1)
        padding = th.full_like(action_mask, -th.inf, dtype=th.float32)
        masked_logits = th.where(action_mask, logits, padding)
        return masked_logits

    def forward_critic(self, features: Tuple[th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor]) -> th.Tensor:
        _, _, state_value, _, _ = features
        return self.value_net(state_value)
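
# Indexing sketch for POPSTARMaskedFacilityLocationNetwork.forward_actor (illustrative only):
# with node_range = 3 and per-node scores loss = [l0, l1, l2], gain = [g0, g1, g2],
#   loss.repeat_interleave(3, dim=1) -> [l0, l0, l0, l1, l1, l1, l2, l2, l2]
#   gain.repeat(1, 3)                -> [g0, g1, g2, g0, g1, g2, g0, g1, g2]
# so the flattened index i * node_range + j scores closing old facility i and opening
# new facility j, matching the layout of the flattened action mask built above.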


class MaskedFacilityLocationActorCriticPolicy(ActorCriticPolicy):
    def __init__(
        self,
        observation_space: spaces.Space,
        action_space: spaces.Space,
        lr_schedule: Callable[[float], float],
        *args,
        **kwargs,
    ):
        self.policy_feature_dim = kwargs.pop('policy_feature_dim')
        self.value_feature_dim = kwargs.pop('value_feature_dim')
        self.policy_hidden_units = kwargs.pop('policy_hidden_units')
        self.value_hidden_units = kwargs.pop('value_hidden_units')
        self.popstar = kwargs.pop('popstar')
        super().__init__(
            observation_space,
            action_space,
            lr_schedule,
            # Pass remaining arguments to base class
            *args,
            **kwargs,
        )

    def _build(self, lr_schedule: Schedule) -> None:
        self._build_mlp_extractor()

        # The extractor already outputs logits and values, so the action/value heads are identity maps.
        self.action_net = nn.Identity()
        self.value_net = nn.Identity()

        # Init weights: use orthogonal initialization
        # with small initial weight for the output
        if self.ortho_init:
            # TODO: check for features_extractor
            # Values from stable-baselines.
            # features_extractor/mlp values are
            # originally from openai/baselines (default gains/init_scales).
            module_gains = {
                self.features_extractor: np.sqrt(2),
                self.mlp_extractor: np.sqrt(2),
            }
            # if not self.share_features_extractor:
            #     # Note(antonin): this is to keep SB3 results
            #     # consistent, see GH#1148
            #     del module_gains[self.features_extractor]
            #     module_gains[self.pi_features_extractor] = np.sqrt(2)
            #     module_gains[self.vf_features_extractor] = np.sqrt(2)
            for module, gain in module_gains.items():
                module.apply(partial(self.init_weights, gain=gain))

        # Setup optimizer with initial learning rate
        self.optimizer = self.optimizer_class(self.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs)

    def _build_mlp_extractor(self) -> None:
        if not self.popstar:
            self.mlp_extractor = MaskedFacilityLocationNetwork(
                self.policy_feature_dim,
                self.value_feature_dim,
                self.policy_hidden_units,
                self.value_hidden_units,
                self.device,
            )
        else:
            self.mlp_extractor = POPSTARMaskedFacilityLocationNetwork(
                self.policy_feature_dim,
                self.value_feature_dim,
                self.policy_hidden_units,
                self.value_hidden_units,
                self.device,
            )
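
# Minimal usage sketch (an assumption, not part of the original module): this policy is
# meant to be passed to an SB3 algorithm via its custom-policy mechanism, with the extra
# constructor arguments supplied through `policy_kwargs`. The dimensions below and the
# environment `env` are placeholders; a features extractor that produces the expected
# tuple of policy/value features and masks is assumed to be configured elsewhere.
#
# from stable_baselines3 import PPO
#
# model = PPO(
#     MaskedFacilityLocationActorCriticPolicy,
#     env,
#     policy_kwargs=dict(
#         policy_feature_dim=4,
#         value_feature_dim=8,
#         policy_hidden_units=(32, 32, 1),
#         value_hidden_units=(32, 32, 1),
#         popstar=False,
#     ),
#     verbose=1,
# )
# model.learn(total_timesteps=10_000)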