from collections import OrderedDict
from typing import Tuple

from gym import spaces
import torch as th
from torch import nn

from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
from stable_baselines3.common.type_aliases import TensorDict


def mean_features(h: th.Tensor, mask: th.Tensor) -> th.Tensor:
    # Mean over dim 1 restricted to the unmasked entries; padded rows contribute nothing.
    float_mask = mask.float()
    mean_h = (h * float_mask.unsqueeze(-1)).sum(dim=1) / float_mask.sum(dim=1, keepdim=True)
    return mean_h


# Earlier node-based state computation, kept for reference.
# def compute_state(observations: TensorDict, h_nodes: th.Tensor):
#     node_mask = observations['node_mask'].bool()
#     mean_h_nodes = mean_features(h_nodes, node_mask)
#
#     old_facility_mask = observations['old_facility_mask'].bool()
#     h_old_facility = mean_features(h_nodes, old_facility_mask)
#     h_old_facility_repeat = h_old_facility.unsqueeze(1).expand(-1, h_nodes.shape[1], -1)
#     state_policy_old_facility = th.cat([
#         h_nodes,
#         h_old_facility_repeat,
#         h_nodes - h_old_facility_repeat,
#         h_nodes * h_old_facility_repeat], dim=-1)
#
#     new_facility_mask = observations['new_facility_mask'].bool()
#     h_new_facility = mean_features(h_nodes, new_facility_mask)
#     h_new_facility_repeat = h_new_facility.unsqueeze(1).expand(-1, h_nodes.shape[1], -1)
#     state_policy_new_facility = th.cat([
#         h_nodes,
#         h_new_facility_repeat,
#         h_nodes - h_new_facility_repeat,
#         h_nodes * h_new_facility_repeat], dim=-1)
#
#     state_value = th.cat([
#         mean_h_nodes,
#         h_old_facility,
#         h_new_facility], dim=-1)
#     return state_policy_old_facility, state_policy_new_facility, state_value, old_facility_mask, new_facility_mask


def compute_state(observations: TensorDict, h_edges: th.Tensor) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
    # Edge-based state: the policy sees one embedding per candidate facility pair
    # (dynamic edge), and the value head sees their masked mean.
    dynamic_edge_mask = observations['dynamic_edge_mask'].bool()
    mean_h_edges = mean_features(h_edges, dynamic_edge_mask)
    state_policy_facility_pair = h_edges
    state_value = mean_h_edges
    return state_policy_facility_pair, state_value, dynamic_edge_mask
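
# Illustrative sketch, not part of the extractor API: a quick check of the
# masked-mean semantics above with made-up shapes. Only rows whose mask entry
# is True contribute to the average, so padding never leaks into the state.
def _demo_masked_mean() -> None:
    h = th.arange(8, dtype=th.float32).reshape(1, 4, 2)  # (batch, items, features)
    mask = th.tensor([[True, True, False, False]])       # only the first two rows count
    out = mean_features(h, mask)
    assert th.allclose(out, h[:, :2].mean(dim=1))        # -> tensor([[1., 2.]])
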

class FacilityLocationMLPExtractor(BaseFeaturesExtractor):
    def __init__(
        self,
        observation_space: spaces.Dict,
        hidden_units: Tuple = (32, 32),
    ) -> None:
        super().__init__(observation_space, features_dim=1)
        self.node_mlp = self.create_mlp(observation_space.spaces['node_features'].shape[1], hidden_units)

    @staticmethod
    def create_mlp(input_dim: int, hidden_units: Tuple) -> nn.Sequential:
        layers = OrderedDict()
        for i, units in enumerate(hidden_units):
            if i == 0:
                layers[f'mlp-extractor-linear_{i}'] = nn.Linear(input_dim, units)
            else:
                layers[f'mlp-extractor-linear_{i}'] = nn.Linear(hidden_units[i - 1], units)
            layers[f'mlp-extractor-tanh_{i}'] = nn.Tanh()
        return nn.Sequential(layers)

    def forward(self, observations: TensorDict) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
        # NOTE: compute_state now consumes edge embeddings, while this extractor
        # produces node embeddings; it matches the node-based variant kept in the
        # comment above and may need updating for the edge-based state.
        node_features = observations['node_features']
        h_nodes = self.node_mlp(node_features)
        return compute_state(observations, h_nodes)

    @staticmethod
    def get_policy_feature_dim(node_dim: int) -> int:
        # Four-way concatenation in the node-based state computation.
        return node_dim * 4

    @staticmethod
    def get_value_feature_dim(node_dim: int) -> int:
        # Three-way concatenation in the node-based state computation.
        return node_dim * 3


class FacilityLocationGNNExtractor(BaseFeaturesExtractor):
    def __init__(
        self,
        observation_space: spaces.Dict,
        num_gnn_layers: int = 2,
        node_dim: int = 32,
    ) -> None:
        super().__init__(observation_space, features_dim=1)
        num_node_features = observation_space.spaces['node_features'].shape[1]
        self.node_encoder = self.create_node_encoder(num_node_features, node_dim)
        self.gnn_layers = self.create_gnn(num_gnn_layers, node_dim)
        self.single_gnn_layer = self.create_gnn(1, node_dim)[0]

    @staticmethod
    def create_node_encoder(num_node_features: int, node_dim: int) -> nn.Sequential:
        node_encoder = nn.Sequential(
            nn.Linear(num_node_features, node_dim),
            nn.Tanh())
        return node_encoder

    @staticmethod
    def create_gnn(num_gnn_layers: int, node_dim: int) -> nn.ModuleList:
        layers = nn.ModuleList()
        for _ in range(num_gnn_layers):
            gnn_layer = nn.Sequential(
                nn.Linear(node_dim, node_dim),
                nn.Tanh())
            layers.append(gnn_layer)
        return layers

    @staticmethod
    def scatter_count(h_edges, indices, edge_mask, max_num_nodes):
        # Sum edge embeddings into their endpoint nodes and count how many
        # unmasked edges touch each node, so callers can take a mean.
        batch_size = h_edges.shape[0]
        num_latents = h_edges.shape[2]
        h_nodes = th.zeros(batch_size, max_num_nodes, num_latents, device=h_edges.device)
        count_edge = th.zeros_like(h_nodes)
        count = th.broadcast_to(edge_mask.unsqueeze(-1), h_edges.shape).float()
        idx = indices.unsqueeze(-1).expand(-1, -1, num_latents)
        h_nodes = h_nodes.scatter_add_(1, idx, h_edges)
        count_edge = count_edge.scatter_add_(1, idx, count)
        return h_nodes, count_edge

    @staticmethod
    def gather_to_edges(h_nodes, edge_index, edge_mask, gnn_layer):
        # Transform node embeddings, then copy each edge's two endpoint
        # embeddings onto the edge; masked edges are zeroed out.
        h_nodes = gnn_layer(h_nodes)
        h_edges_12 = th.gather(h_nodes, 1, edge_index[:, :, 0].unsqueeze(-1).expand(-1, -1, h_nodes.size(-1)))
        h_edges_21 = th.gather(h_nodes, 1, edge_index[:, :, 1].unsqueeze(-1).expand(-1, -1, h_nodes.size(-1)))
        mask = th.broadcast_to(edge_mask.unsqueeze(-1), h_edges_12.shape)
        h_edges_12 = th.where(mask, h_edges_12, th.zeros_like(h_edges_12))
        h_edges_21 = th.where(mask, h_edges_21, th.zeros_like(h_edges_21))
        return h_edges_12, h_edges_21

    @classmethod
    def scatter_to_nodes(cls, h_edges, edge_index, edge_mask, node_mask):
        # Average the incoming edge messages per node; padded nodes divide by
        # one instead of zero to avoid NaNs.
        h_edges_12, h_edges_21 = h_edges
        max_num_nodes = node_mask.shape[1]
        h_nodes_1, count_1 = cls.scatter_count(h_edges_21, edge_index[:, :, 0], edge_mask, max_num_nodes)
        h_nodes_2, count_2 = cls.scatter_count(h_edges_12, edge_index[:, :, 1], edge_mask, max_num_nodes)
        h_nodes_sum = h_nodes_1 + h_nodes_2
        mask = th.broadcast_to(node_mask.unsqueeze(-1), h_nodes_sum.shape)
        count = count_1 + count_2
        count_padding = th.ones_like(count)
        count = th.where(mask, count, count_padding)
        h_nodes = h_nodes_sum / count
        return h_nodes

    def forward(self, observations: TensorDict) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
        node_features = observations['node_features']
        h_nodes = self.node_encoder(node_features)
        edge_static_index = observations['static_adjacency_list'].long()
        edge_dynamic_index = observations['dynamic_adjacency_list'].long()
        node_mask = observations['node_mask'].bool()
        static_edge_mask = observations['static_edge_mask'].bool()
        dynamic_edge_mask = observations['dynamic_edge_mask'].bool()
        # Message passing over the static graph, with residual connections.
        for gnn_layer in self.gnn_layers:
            h_edges = self.gather_to_edges(h_nodes, edge_static_index, static_edge_mask, gnn_layer)
            h_nodes_new = self.scatter_to_nodes(h_edges, edge_static_index, static_edge_mask, node_mask)
            h_nodes = h_nodes + h_nodes_new
        # Project the final node embeddings onto the dynamic (candidate-pair) edges.
        h_edges_12, h_edges_21 = self.gather_to_edges(h_nodes, edge_dynamic_index, dynamic_edge_mask, self.single_gnn_layer)
        h_edges = th.cat([h_edges_12, h_edges_21], dim=-1)
        return compute_state(observations, h_edges)

    @staticmethod
    def get_policy_feature_dim(node_dim: int) -> int:
        return node_dim * 2

    @staticmethod
    def get_value_feature_dim(node_dim: int) -> int:
        return node_dim * 2
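
# Illustrative sketch with hypothetical shapes, not part of the original module:
# one gather/scatter round trip on a 3-node path graph with edges (0, 1) and
# (1, 2), batch size 1, using nn.Identity as a stand-in for a GNN layer.
def _demo_message_passing() -> None:
    h_nodes = th.randn(1, 3, 4)                 # (batch, nodes, dim)
    edge_index = th.tensor([[[0, 1], [1, 2]]])  # (batch, edges, 2) endpoint indices
    edge_mask = th.ones(1, 2, dtype=th.bool)
    node_mask = th.ones(1, 3, dtype=th.bool)
    h_edges = FacilityLocationGNNExtractor.gather_to_edges(
        h_nodes, edge_index, edge_mask, nn.Identity())
    h_new = FacilityLocationGNNExtractor.scatter_to_nodes(
        h_edges, edge_index, edge_mask, node_mask)
    assert h_new.shape == h_nodes.shape         # one averaged message per node
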

class FacilityLocationAttentionGNNExtractor(FacilityLocationGNNExtractor):
    def __init__(
        self,
        observation_space: spaces.Dict,
        num_gnn_layers: int = 2,
        node_dim: int = 32,
    ) -> None:
        super().__init__(observation_space, num_gnn_layers, node_dim)
        # node_encoder, gnn_layers and single_gnn_layer come from the parent
        # constructor. num_heads equals node_dim, so each attention head has
        # dimension 1; batch_first matches the (batch, nodes, dim) layout.
        self.attention = nn.MultiheadAttention(node_dim, node_dim, batch_first=True)

    def forward(self, observations: TensorDict) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
        node_features = observations['node_features']
        h_nodes = self.node_encoder(node_features)
        edge_static_index = observations['static_adjacency_list'].long()
        edge_dynamic_index = observations['dynamic_adjacency_list'].long()
        node_mask = observations['node_mask'].bool()
        static_edge_mask = observations['static_edge_mask'].bool()
        dynamic_edge_mask = observations['dynamic_edge_mask'].bool()
        # Message passing over the static graph, interleaved with self-attention.
        for gnn_layer in self.gnn_layers:
            h_edges = self.gather_to_edges(h_nodes, edge_static_index, static_edge_mask, gnn_layer)
            h_nodes_new = self.scatter_to_nodes(h_edges, edge_static_index, static_edge_mask, node_mask)
            h_nodes = h_nodes + h_nodes_new
            h_nodes = self.attention(h_nodes, h_nodes, h_nodes)[0]
        # Mirror the parent extractor: compute_state expects embeddings on the
        # dynamic (candidate-pair) edges, not raw node embeddings.
        h_edges_12, h_edges_21 = self.gather_to_edges(h_nodes, edge_dynamic_index, dynamic_edge_mask, self.single_gnn_layer)
        h_edges = th.cat([h_edges_12, h_edges_21], dim=-1)
        return compute_state(observations, h_edges)
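
# Minimal smoke test under assumed observation shapes (4 nodes, 3 static edges,
# 2 dynamic candidate pairs); the real environment defines the actual Dict
# space and mask keys, so treat this only as a sketch of the expected layout.
if __name__ == '__main__':
    obs_space = spaces.Dict({
        'node_features': spaces.Box(low=-1.0, high=1.0, shape=(4, 5)),
    })
    extractor = FacilityLocationGNNExtractor(obs_space, num_gnn_layers=2, node_dim=8)
    observations = {
        'node_features': th.randn(1, 4, 5),
        'static_adjacency_list': th.tensor([[[0, 1], [1, 2], [2, 3]]]),
        'dynamic_adjacency_list': th.tensor([[[0, 2], [1, 3]]]),
        'node_mask': th.ones(1, 4),
        'static_edge_mask': th.ones(1, 3),
        'dynamic_edge_mask': th.ones(1, 2),
    }
    policy_features, value_features, mask = extractor(observations)
    # policy: one 2 * node_dim embedding per candidate pair; value: its masked mean.
    print(policy_features.shape, value_features.shape, mask.shape)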