from functools import partial
from typing import Callable, Tuple, Text, Union
from collections import OrderedDict
import numpy as np
from gym import spaces
import torch as th
from torch import nn
from stable_baselines3.common.policies import ActorCriticPolicy
from stable_baselines3.common.utils import get_device
from stable_baselines3.common.type_aliases import Schedule
def create_mlp(head: Text, input_dim: int, hidden_units: Tuple[int, ...]) -> nn.Sequential:
    """Build a small MLP with Tanh activations between linear layers.

    Heads whose name starts with ``'policy'`` get a trailing ``nn.Flatten`` so
    that per-node (or per-pair) scores are returned as flat logits.
    """
    layers = OrderedDict()
    for i, units in enumerate(hidden_units):
        if i == 0:
            layers[f'{head}_linear_{i}'] = nn.Linear(input_dim, units)
        else:
            layers[f'{head}_linear_{i}'] = nn.Linear(hidden_units[i - 1], units)
        if i != len(hidden_units) - 1:
            layers[f'{head}_tanh_{i}'] = nn.Tanh()
    if head.startswith('policy'):
        layers[f'{head}_flatten'] = nn.Flatten()
    return nn.Sequential(layers)
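
# Illustrative sketch (assumed values, not part of the module's defaults): with
# head='policy-pair-facility', input_dim=16 and hidden_units=(32, 32, 1), create_mlp
# builds Linear(16, 32) -> Tanh -> Linear(32, 32) -> Tanh -> Linear(32, 1) -> Flatten,
# so a feature tensor of shape (batch_size, num_pairs, 16) maps to logits of shape
# (batch_size, num_pairs). A 'value' head gets no Flatten and, with the default
# hidden units ending in 1, maps (batch_size, value_feature_dim) to (batch_size, 1).
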
class MaskedFacilityLocationNetwork(nn.Module):
    """Actor-critic heads operating on pre-paired facility features.

    The actor scores each candidate facility pair and sets the logits of pairs
    excluded by ``dynamic_edge_mask`` to ``-inf``; the critic maps the value
    features to a scalar state value.
    """
def __init__(
self,
policy_feature_dim: int,
value_feature_dim: int,
policy_hidden_units: Tuple = (32, 32, 1),
value_hidden_units: Tuple = (32, 32, 1),
device: Union[th.device, Text] = "auto",
):
super().__init__()
device = get_device(device)
# Policy network
# self.old_facility_policy_net = create_mlp('policy-old-facility',
# policy_feature_dim,
# policy_hidden_units).to(device)
# self.new_facility_policy_net = create_mlp('policy-new-facility',
# policy_feature_dim,
# policy_hidden_units).to(device)
self.pair_facility_policy_net = create_mlp('policy-pair-facility',
policy_feature_dim,
policy_hidden_units).to(device)
# Value network
self.value_net = create_mlp('value',
value_feature_dim,
value_hidden_units).to(device)
    def forward(self,
                features: Tuple[th.Tensor, th.Tensor, th.Tensor]) -> Tuple[th.Tensor, th.Tensor]:
return self.forward_actor(features), self.forward_critic(features)
# def forward_actor(self, features: Tuple[th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor]) -> th.Tensor:
# state_policy_old_facility, state_policy_new_facility, _, old_facility_mask, new_facility_mask = features
# old_facility_logits = self.old_facility_policy_net(state_policy_old_facility) # (batch_size, node_range)
# old_facility_padding = th.full_like(old_facility_mask, -th.inf, dtype=th.float32)
# masked_old_facility_logits = th.where(old_facility_mask, old_facility_logits, old_facility_padding)
# new_facility_logits = self.new_facility_policy_net(state_policy_new_facility) # (batch_size, node_range)
# new_facility_padding = th.full_like(new_facility_mask, -th.inf, dtype=th.float32)
# masked_new_facility_logits = th.where(new_facility_mask, new_facility_logits, new_facility_padding)
# masked_old_new_facility_logits = th.cat([masked_old_facility_logits, masked_new_facility_logits], dim=1)
# return masked_old_new_facility_logits
def forward_actor(self, features: Tuple[th.Tensor, th.Tensor, th.Tensor]) -> th.Tensor:
state_policy_pair_facility, _, dynamic_edge_mask = features
pair_facility_logits = self.pair_facility_policy_net(state_policy_pair_facility)
pair_facility_padding = th.full_like(dynamic_edge_mask, -th.inf, dtype=th.float32)
masked_pair_facility_logits = th.where(dynamic_edge_mask, pair_facility_logits, pair_facility_padding)
return masked_pair_facility_logits
    def forward_critic(self, features: Tuple[th.Tensor, th.Tensor, th.Tensor]) -> th.Tensor:
_, state_value, _ = features
return self.value_net(state_value)
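
# Shape sketch for MaskedFacilityLocationNetwork (assumed shapes, for illustration only):
#   features = (state_policy_pair_facility,  # (batch_size, num_pairs, policy_feature_dim)
#               state_value,                 # (batch_size, value_feature_dim)
#               dynamic_edge_mask)           # (batch_size, num_pairs), bool
# forward_actor returns masked logits of shape (batch_size, num_pairs) with -inf at
# masked-out pairs; forward_critic returns a state value of shape (batch_size, 1)
# when value_hidden_units ends in 1 (the default).
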
class POPSTARMaskedFacilityLocationNetwork(nn.Module):
    """Actor-critic heads that score (old, new) facility swaps explicitly.

    The actor combines a per-node "loss" head for the old facility, a per-node
    "gain" head for the new facility, and a pairwise interaction head, then masks
    invalid swaps with ``-inf``; the critic maps the value features to a scalar.
    """
def __init__(
self,
policy_feature_dim: int,
value_feature_dim: int,
policy_hidden_units: Tuple = (32, 32, 1),
value_hidden_units: Tuple = (32, 32, 1),
device: Union[th.device, Text] = "auto",
):
super().__init__()
device = get_device(device)
# Policy network
self.old_facility_policy_net = create_mlp('policy-old-facility',
policy_feature_dim,
policy_hidden_units).to(device)
self.new_facility_policy_net = create_mlp('policy-new-facility',
policy_feature_dim,
policy_hidden_units).to(device)
self.old_new_facility_policy_net = create_mlp('policy-old-new-facility',
policy_feature_dim * 4,
policy_hidden_units).to(device)
# Value network
self.value_net = create_mlp('value',
value_feature_dim,
value_hidden_units).to(device)
def forward(self,
features: Tuple[th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor]) -> Tuple[th.Tensor, th.Tensor]:
return self.forward_actor(features), self.forward_critic(features)
def forward_actor(self, features: Tuple[th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor]) -> th.Tensor:
state_policy_old_facility, state_policy_new_facility, _, old_facility_mask, new_facility_mask = features
node_range = old_facility_mask.shape[1]
        # "Loss" head: one score per old (currently selected) facility; the old index
        # varies slowly across the flattened pair axis.
        loss = self.old_facility_policy_net(state_policy_old_facility)  # (batch_size, node_range)
        loss = loss.repeat_interleave(node_range, dim=1)
        # "Gain" head: one score per new (candidate) facility; the new index varies fast.
        gain = self.new_facility_policy_net(state_policy_new_facility)  # (batch_size, node_range)
        gain = gain.repeat(1, node_range)
state_policy_old_facility_expand = state_policy_old_facility.unsqueeze(2).expand(-1, -1, node_range, -1)
state_policy_new_facility_expand = state_policy_new_facility.unsqueeze(1).expand(-1, node_range, -1, -1)
state_policy_old_new_facility = th.cat(
[
state_policy_old_facility_expand,
state_policy_new_facility_expand,
state_policy_old_facility_expand - state_policy_new_facility_expand,
state_policy_old_facility_expand * state_policy_new_facility_expand
], dim=-1
)
        extra = self.old_new_facility_policy_net(state_policy_old_new_facility)  # (batch_size, node_range * node_range)
        # Swap score = gain from the new facility - loss from the old facility + pairwise interaction term.
        logits = gain - loss + extra
        # Flattened action index = old_index * node_range + new_index, matching the
        # repeat_interleave / repeat layout above.
        action_mask = th.logical_and(old_facility_mask.unsqueeze(2), new_facility_mask.unsqueeze(1)).flatten(start_dim=1)
        padding = th.full_like(action_mask, -th.inf, dtype=th.float32)
        masked_logits = th.where(action_mask, logits, padding)
        return masked_logits
def forward_critic(self, features: Tuple[th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor]) -> th.Tensor:
_, _, state_value, _, _ = features
return self.value_net(state_value)
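
# Shape sketch for POPSTARMaskedFacilityLocationNetwork (illustrative, assuming the
# default hidden_units ending in 1):
#   state_policy_old_facility, state_policy_new_facility: (batch_size, node_range, policy_feature_dim)
#   old_facility_mask, new_facility_mask: (batch_size, node_range), bool
#   state_value: (batch_size, value_feature_dim)
# forward_actor returns logits over all (old, new) swaps, shape
# (batch_size, node_range * node_range), with invalid pairs set to -inf;
# forward_critic returns a state value of shape (batch_size, 1).
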
class MaskedFacilityLocationActorCriticPolicy(ActorCriticPolicy):
    """ActorCriticPolicy whose action and value heads are identity mappings.

    The masked action logits and the state value are produced directly by the
    custom mlp_extractor built in ``_build_mlp_extractor``.
    """
def __init__(
self,
observation_space: spaces.Space,
action_space: spaces.Space,
lr_schedule: Callable[[float], float],
*args,
**kwargs,
):
        # Pop the custom arguments before they reach the base ActorCriticPolicy,
        # which does not accept them.
        self.policy_feature_dim = kwargs.pop('policy_feature_dim')
        self.value_feature_dim = kwargs.pop('value_feature_dim')
        self.policy_hidden_units = kwargs.pop('policy_hidden_units')
        self.value_hidden_units = kwargs.pop('value_hidden_units')
        self.popstar = kwargs.pop('popstar')
super().__init__(
observation_space,
action_space,
lr_schedule,
# Pass remaining arguments to base class
*args,
**kwargs,
)
    def _build(self, lr_schedule: Schedule) -> None:
        self._build_mlp_extractor()
        # The extractor already produces action logits and the state value,
        # so the default SB3 action/value heads are replaced with identity mappings.
        self.action_net = nn.Identity()
        self.value_net = nn.Identity()
        # Init weights: use orthogonal initialization.
if self.ortho_init:
# TODO: check for features_extractor
# Values from stable-baselines.
# features_extractor/mlp values are
# originally from openai/baselines (default gains/init_scales).
module_gains = {
self.features_extractor: np.sqrt(2),
self.mlp_extractor: np.sqrt(2),
}
# if not self.share_features_extractor:
# # Note(antonin): this is to keep SB3 results
# # consistent, see GH#1148
# del module_gains[self.features_extractor]
# module_gains[self.pi_features_extractor] = np.sqrt(2)
# module_gains[self.vf_features_extractor] = np.sqrt(2)
for module, gain in module_gains.items():
module.apply(partial(self.init_weights, gain=gain))
# Setup optimizer with initial learning rate
self.optimizer = self.optimizer_class(self.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs)
    def _build_mlp_extractor(self) -> None:
        # Use the pairwise-feature network by default, or the POPSTAR-style
        # gain/loss decomposition when `popstar` is set.
        if not self.popstar:
self.mlp_extractor = MaskedFacilityLocationNetwork(
self.policy_feature_dim,
self.value_feature_dim,
self.policy_hidden_units,
self.value_hidden_units,
self.device,
)
else:
self.mlp_extractor = POPSTARMaskedFacilityLocationNetwork(
self.policy_feature_dim,
self.value_feature_dim,
self.policy_hidden_units,
self.value_hidden_units,
self.device,
)
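
# Minimal usage sketch (illustrative only: `env` stands for a custom facility-location
# environment, not defined here; a matching custom features extractor, e.g. passed via
# `features_extractor_class` in policy_kwargs, must produce the feature tuples described above):
#
#     from stable_baselines3 import PPO
#     model = PPO(
#         MaskedFacilityLocationActorCriticPolicy,
#         env,
#         policy_kwargs=dict(
#             policy_feature_dim=16,
#             value_feature_dim=16,
#             policy_hidden_units=(32, 32, 1),
#             value_hidden_units=(32, 32, 1),
#             popstar=False,
#         ),
#     )
#     model.learn(total_timesteps=10_000)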