from functools import partial
from typing import Callable, Tuple, Text, Union
from collections import OrderedDict

import numpy as np
from gym import spaces
import torch as th
from torch import nn
from stable_baselines3.common.policies import ActorCriticPolicy
from stable_baselines3.common.utils import get_device
from stable_baselines3.common.type_aliases import Schedule


def create_mlp(head: Text, input_dim: int, hidden_units: Tuple) -> nn.Sequential:
    """Build a named MLP with Tanh activations between hidden layers.

    Policy heads (names starting with 'policy') get a trailing Flatten so that
    per-node or per-edge scalar outputs of shape (batch, N, 1) become logits of
    shape (batch, N).
    """
    layers = OrderedDict()
    for i, units in enumerate(hidden_units):
        if i == 0:
            layers[f'{head}_linear_{i}'] = nn.Linear(input_dim, units)
        else:
            layers[f'{head}_linear_{i}'] = nn.Linear(hidden_units[i - 1], units)
        if i != len(hidden_units) - 1:
            layers[f'{head}_tanh_{i}'] = nn.Tanh()
    if head.startswith('policy'):
        layers[f'{head}_flatten'] = nn.Flatten()
    return nn.Sequential(layers)


class MaskedFacilityLocationNetwork(nn.Module):
    def __init__(
        self,
        policy_feature_dim: int,
        value_feature_dim: int,
        policy_hidden_units: Tuple = (32, 32, 1),
        value_hidden_units: Tuple = (32, 32, 1),
        device: Union[th.device, Text] = "auto",
    ):
        super().__init__()
        device = get_device(device)

        # Policy network
        # self.old_facility_policy_net = create_mlp('policy-old-facility',
        #                                           policy_feature_dim,
        #                                           policy_hidden_units).to(device)
        # self.new_facility_policy_net = create_mlp('policy-new-facility',
        #                                           policy_feature_dim,
        #                                           policy_hidden_units).to(device)
        self.pair_facility_policy_net = create_mlp('policy-pair-facility',
                                                   policy_feature_dim,
                                                   policy_hidden_units).to(device)

        # Value network
        self.value_net = create_mlp('value',
                                    value_feature_dim,
                                    value_hidden_units).to(device)

    def forward(self, features: Tuple[th.Tensor, th.Tensor, th.Tensor]) -> Tuple[th.Tensor, th.Tensor]:
        return self.forward_actor(features), self.forward_critic(features)

    # def forward_actor(self, features: Tuple[th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor]) -> th.Tensor:
    #     state_policy_old_facility, state_policy_new_facility, _, old_facility_mask, new_facility_mask = features
    #
    #     old_facility_logits = self.old_facility_policy_net(state_policy_old_facility)  # (batch_size, node_range)
    #     old_facility_padding = th.full_like(old_facility_mask, -th.inf, dtype=th.float32)
    #     masked_old_facility_logits = th.where(old_facility_mask, old_facility_logits, old_facility_padding)
    #
    #     new_facility_logits = self.new_facility_policy_net(state_policy_new_facility)  # (batch_size, node_range)
    #     new_facility_padding = th.full_like(new_facility_mask, -th.inf, dtype=th.float32)
    #     masked_new_facility_logits = th.where(new_facility_mask, new_facility_logits, new_facility_padding)
    #
    #     masked_old_new_facility_logits = th.cat([masked_old_facility_logits, masked_new_facility_logits], dim=1)
    #     return masked_old_new_facility_logits

    def forward_actor(self, features: Tuple[th.Tensor, th.Tensor, th.Tensor]) -> th.Tensor:
        state_policy_pair_facility, _, dynamic_edge_mask = features

        pair_facility_logits = self.pair_facility_policy_net(state_policy_pair_facility)
        # Invalid edges receive -inf logits so they get zero probability.
        pair_facility_padding = th.full_like(dynamic_edge_mask, -th.inf, dtype=th.float32)
        masked_pair_facility_logits = th.where(dynamic_edge_mask, pair_facility_logits, pair_facility_padding)
        return masked_pair_facility_logits

    def forward_critic(self, features: Tuple[th.Tensor, th.Tensor, th.Tensor]) -> th.Tensor:
        _, state_value, _ = features
        return self.value_net(state_value)
class POPSTARMaskedFacilityLocationNetwork(nn.Module):
    def __init__(
        self,
        policy_feature_dim: int,
        value_feature_dim: int,
        policy_hidden_units: Tuple = (32, 32, 1),
        value_hidden_units: Tuple = (32, 32, 1),
        device: Union[th.device, Text] = "auto",
    ):
        super().__init__()
        device = get_device(device)

        # Policy network
        self.old_facility_policy_net = create_mlp('policy-old-facility',
                                                  policy_feature_dim,
                                                  policy_hidden_units).to(device)
        self.new_facility_policy_net = create_mlp('policy-new-facility',
                                                  policy_feature_dim,
                                                  policy_hidden_units).to(device)
        self.old_new_facility_policy_net = create_mlp('policy-old-new-facility',
                                                      policy_feature_dim * 4,
                                                      policy_hidden_units).to(device)

        # Value network
        self.value_net = create_mlp('value',
                                    value_feature_dim,
                                    value_hidden_units).to(device)

    def forward(self, features: Tuple[th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor]) -> Tuple[th.Tensor, th.Tensor]:
        return self.forward_actor(features), self.forward_critic(features)

    def forward_actor(self, features: Tuple[th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor]) -> th.Tensor:
        state_policy_old_facility, state_policy_new_facility, _, old_facility_mask, new_facility_mask = features
        node_range = old_facility_mask.shape[1]

        # Loss of removing each old facility, repeated so that flattened index
        # (i * node_range + j) corresponds to the (old facility i, new facility j) swap.
        loss = self.old_facility_policy_net(state_policy_old_facility)  # (batch_size, node_range)
        loss = loss.repeat_interleave(node_range, dim=1)
        # Gain of adding each new facility, tiled to the same (old, new) layout.
        gain = self.new_facility_policy_net(state_policy_new_facility)  # (batch_size, node_range)
        gain = gain.repeat(1, node_range)

        # Pairwise interaction term over all (old, new) facility pairs.
        state_policy_old_facility_expand = state_policy_old_facility.unsqueeze(2).expand(-1, -1, node_range, -1)
        state_policy_new_facility_expand = state_policy_new_facility.unsqueeze(1).expand(-1, node_range, -1, -1)
        state_policy_old_new_facility = th.cat(
            [
                state_policy_old_facility_expand,
                state_policy_new_facility_expand,
                state_policy_old_facility_expand - state_policy_new_facility_expand,
                state_policy_old_facility_expand * state_policy_new_facility_expand
            ],
            dim=-1
        )
        extra = self.old_new_facility_policy_net(state_policy_old_new_facility)  # (batch_size, node_range * node_range)

        logits = gain - loss + extra

        # A swap is valid only if both its old facility and its new facility are valid.
        action_mask = th.logical_and(old_facility_mask.unsqueeze(2), new_facility_mask.unsqueeze(1)).flatten(start_dim=1)
        padding = th.full_like(action_mask, -th.inf, dtype=th.float32)
        masked_logits = th.where(action_mask, logits, padding)
        return masked_logits

    def forward_critic(self, features: Tuple[th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor]) -> th.Tensor:
        _, _, state_value, _, _ = features
        return self.value_net(state_value)


class MaskedFacilityLocationActorCriticPolicy(ActorCriticPolicy):
    def __init__(
        self,
        observation_space: spaces.Space,
        action_space: spaces.Space,
        lr_schedule: Callable[[float], float],
        *args,
        **kwargs,
    ):
        self.policy_feature_dim = kwargs.pop('policy_feature_dim')
        self.value_feature_dim = kwargs.pop('value_feature_dim')
        self.policy_hidden_units = kwargs.pop('policy_hidden_units')
        self.value_hidden_units = kwargs.pop('value_hidden_units')
        self.popstar = kwargs.pop('popstar')
        super().__init__(
            observation_space,
            action_space,
            lr_schedule,
            # Pass remaining arguments to base class
            *args,
            **kwargs,
        )

    def _build(self, lr_schedule: Schedule) -> None:
        self._build_mlp_extractor()

        # The mlp_extractor already produces action logits and values,
        # so the action/value heads are identity mappings.
        self.action_net = nn.Identity()
        self.value_net = nn.Identity()

        # Init weights: use orthogonal initialization
        # with small initial weight for the output
        if self.ortho_init:
            # TODO: check for features_extractor
            # Values from stable-baselines.
            # features_extractor/mlp values are
            # originally from openai/baselines (default gains/init_scales).
            module_gains = {
                self.features_extractor: np.sqrt(2),
                self.mlp_extractor: np.sqrt(2),
            }
            # if not self.share_features_extractor:
            #     # Note(antonin): this is to keep SB3 results
            #     # consistent, see GH#1148
            #     del module_gains[self.features_extractor]
            #     module_gains[self.pi_features_extractor] = np.sqrt(2)
            #     module_gains[self.vf_features_extractor] = np.sqrt(2)
            for module, gain in module_gains.items():
                module.apply(partial(self.init_weights, gain=gain))

        # Setup optimizer with initial learning rate
        self.optimizer = self.optimizer_class(self.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs)

    def _build_mlp_extractor(self) -> None:
        if not self.popstar:
            self.mlp_extractor = MaskedFacilityLocationNetwork(
                self.policy_feature_dim,
                self.value_feature_dim,
                self.policy_hidden_units,
                self.value_hidden_units,
                self.device,
            )
        else:
            self.mlp_extractor = POPSTARMaskedFacilityLocationNetwork(
                self.policy_feature_dim,
                self.value_feature_dim,
                self.policy_hidden_units,
                self.value_hidden_units,
                self.device,
            )
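

if __name__ == "__main__":
    # Minimal smoke test: a sketch with made-up shapes (not taken from the
    # training pipeline) that runs both extractors on random inputs and checks
    # the output shapes. `num_edges`, `node_range`, and the feature dims below
    # are arbitrary illustration values.
    batch, num_edges, node_range = 4, 16, 5
    policy_feature_dim, value_feature_dim = 8, 6

    pair_net = MaskedFacilityLocationNetwork(policy_feature_dim, value_feature_dim, device="cpu")
    state_policy_pair = th.randn(batch, num_edges, policy_feature_dim)
    state_value = th.randn(batch, value_feature_dim)
    edge_mask = th.rand(batch, num_edges) > 0.5
    logits, value = pair_net((state_policy_pair, state_value, edge_mask))
    print(logits.shape, value.shape)  # (4, 16), (4, 1); masked edges get -inf logits

    popstar_net = POPSTARMaskedFacilityLocationNetwork(policy_feature_dim, value_feature_dim, device="cpu")
    state_old = th.randn(batch, node_range, policy_feature_dim)
    state_new = th.randn(batch, node_range, policy_feature_dim)
    old_mask = th.rand(batch, node_range) > 0.5
    new_mask = th.rand(batch, node_range) > 0.5
    logits, value = popstar_net((state_old, state_new, state_value, old_mask, new_mask))
    print(logits.shape, value.shape)  # (4, 25), (4, 1); one logit per (old, new) pair

    # The policy class would typically be plugged into an SB3 algorithm via
    # policy_kwargs, e.g. (hypothetical `env` and a features extractor that
    # yields the tuple features above are assumed, shown only as an illustration):
    # model = PPO(
    #     MaskedFacilityLocationActorCriticPolicy,
    #     env,
    #     policy_kwargs=dict(
    #         policy_feature_dim=policy_feature_dim,
    #         value_feature_dim=value_feature_dim,
    #         policy_hidden_units=(32, 32, 1),
    #         value_hidden_units=(32, 32, 1),
    #         popstar=False,
    #     ),
    # )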