Spaces:
Runtime error
Runtime error
from collections import OrderedDict | |
from typing import Tuple | |
from gym import spaces | |
import torch as th | |
from torch import nn | |
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor | |
from stable_baselines3.common.type_aliases import TensorDict | |
import time | |
def mean_features(h: th.Tensor, mask: th.Tensor) -> th.Tensor:
    """Return the masked mean of ``h`` over dimension 1.

    Args:
        h: Feature tensor; presumably (batch, num_items, feature_dim) since the
            mask is broadcast against it via ``unsqueeze(-1)`` and the sum runs
            over dim 1 — TODO confirm with callers.
        mask: Selection mask over dimension 1 of ``h``; presumably
            (batch, num_items). It is converted to float, so bool or 0/1 masks
            both work.

    Returns:
        ``h`` averaged over the selected entries of dimension 1. Rows whose mask
        selects nothing yield zeros instead of NaN.
    """
    float_mask = mask.float()
    summed = (h * float_mask.unsqueeze(-1)).sum(dim=1)
    count = float_mask.sum(dim=1, keepdim=True)
    # A fully-masked row would otherwise compute 0 / 0 = NaN. Its numerator is
    # also 0, so dividing by 1 instead gives an all-zero mean.
    count = count.masked_fill(count == 0, 1.0)
    return summed / count
# def compute_state(observations: TensorDict, h_nodes: th.Tensor): | |
# node_mask = observations['node_mask'].bool() | |
# mean_h_nodes = mean_features(h_nodes, node_mask) | |
# old_facility_mask = observations['old_facility_mask'].bool() | |
# h_old_facility = mean_features(h_nodes, old_facility_mask) | |
# h_old_facility_repeat = h_old_facility.unsqueeze(1).expand(-1, h_nodes.shape[1], -1) | |
# state_policy_old_facility = th.cat([ | |
# h_nodes, | |
# h_old_facility_repeat, | |
# h_nodes - h_old_facility_repeat, | |
# h_nodes * h_old_facility_repeat], dim=-1) | |
# new_facility_mask = observations['new_facility_mask'].bool() | |
# h_new_facility = mean_features(h_nodes, new_facility_mask) | |
# h_new_facility_repeat = h_new_facility.unsqueeze(1).expand(-1, h_nodes.shape[1], -1) | |
# state_policy_new_facility = th.cat([ | |
# h_nodes, | |
# h_new_facility_repeat, | |
# h_nodes - h_new_facility_repeat, | |
# state_value = th.cat([ | |
# mean_h_nodes, | |
# h_old_facility, | |
# h_new_facility], dim=-1) | |
# return state_policy_old_facility, state_policy_new_facility, state_value, old_facility_mask, new_facility_mask | |
def compute_state(observations: TensorDict, h_edges: th.Tensor):
    """Turn per-edge embeddings into policy inputs, value inputs and a mask.

    Args:
        observations: Observation dict; only ``'dynamic_edge_mask'`` is read.
        h_edges: Edge embeddings, presumably (batch, num_dynamic_edges, dim)
            and aligned with ``dynamic_edge_mask`` — TODO confirm.

    Returns:
        A 3-tuple with these elements:
        - ``h_edges`` itself, used unchanged as the policy features.
        - The masked mean of ``h_edges`` over dim 1, used as the value features.
        - ``dynamic_edge_mask`` converted to bool.
    """
    edge_mask = observations['dynamic_edge_mask'].bool()
    value_features = mean_features(h_edges, edge_mask)
    return h_edges, value_features, edge_mask
class FacilityLocationMLPExtractor(BaseFeaturesExtractor):
    """Feature extractor that encodes every node independently with an MLP."""

    def __init__(
        self,
        observation_space: spaces.Dict,
        hidden_units: Tuple = (32, 32),
    ) -> None:
        # features_dim=1 is what is passed to the base class; forward returns
        # the tuple built by compute_state, not a single feature tensor.
        super().__init__(observation_space, features_dim=1)
        self.node_mlp = self.create_mlp(observation_space.spaces['node_features'].shape[1], hidden_units)

    @staticmethod
    def create_mlp(input_dim: int, hidden_units: Tuple) -> nn.Sequential:
        """Build a stack of ``Linear`` + ``Tanh`` pairs, one per entry of ``hidden_units``.

        Args:
            input_dim: Width of the input features.
            hidden_units: Output width of each successive linear layer.

        Returns:
            An ``nn.Sequential`` with named layers ``mlp-extractor-linear_{i}``
            and ``mlp-extractor-tanh_{i}``. An empty ``hidden_units`` gives an
            empty (identity) ``Sequential``.
        """
        layers = OrderedDict()
        in_dim = input_dim
        for i, units in enumerate(hidden_units):
            layers[f'mlp-extractor-linear_{i}'] = nn.Linear(in_dim, units)
            layers[f'mlp-extractor-tanh_{i}'] = nn.Tanh()
            in_dim = units
        return nn.Sequential(layers)

    def forward(self, observations: TensorDict) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
        """Encode ``observations['node_features']`` and pass the result to compute_state.

        NOTE(review): compute_state masks its input with
        ``observations['dynamic_edge_mask']``, so it expects per-edge features.
        This method passes per-node features instead. Unless the node and
        dynamic-edge counts coincide, the masking will not broadcast. Confirm
        before using this extractor.
        """
        node_features = observations['node_features']
        h_nodes = self.node_mlp(node_features)
        return compute_state(observations, h_nodes)

    @staticmethod
    def get_policy_feature_dim(node_dim: int) -> int:
        # NOTE(review): 4 * node_dim matches the 4-way concatenation in the
        # commented-out node-based compute_state, not the active edge-based one.
        return node_dim * 4

    @staticmethod
    def get_value_feature_dim(node_dim: int) -> int:
        # NOTE(review): 3 * node_dim matches the 3-way concatenation in the
        # commented-out node-based compute_state, not the active edge-based one.
        return node_dim * 3
class FacilityLocationGNNExtractor(BaseFeaturesExtractor):
    """Graph feature extractor producing one embedding per dynamic edge.

    The pipeline has three steps:
    1. Node features are encoded to ``node_dim``.
    2. The encodings are refined by ``num_gnn_layers`` residual rounds of
       neighbour-mean message passing over the static adjacency list.
    3. They are read out on both endpoints of every dynamic edge.

    ``forward`` returns the tuple produced by ``compute_state``, not a single
    feature tensor.
    """

    def __init__(
        self,
        observation_space: spaces.Dict,
        num_gnn_layers: int = 2,
        node_dim: int = 32,
    ) -> None:
        # features_dim=1 is passed to the base class. The per-edge output width
        # built in forward (2 * node_dim) is what get_policy_feature_dim /
        # get_value_feature_dim return.
        super().__init__(observation_space, features_dim=1)
        num_node_features = observation_space.spaces['node_features'].shape[1]
        self.node_encoder = self.create_node_encoder(num_node_features, node_dim)
        self.gnn_layers = self.create_gnn(num_gnn_layers, node_dim)
        # Separate layer applied once before reading out dynamic-edge features.
        self.single_gnn_layer = self.create_gnn(1, node_dim)[0]

    @staticmethod
    def create_node_encoder(num_node_features: int, node_dim: int) -> nn.Sequential:
        """Linear projection of raw node features to ``node_dim`` followed by tanh."""
        return nn.Sequential(
            nn.Linear(num_node_features, node_dim),
            nn.Tanh())

    @staticmethod
    def create_gnn(num_gnn_layers: int, node_dim: int) -> nn.ModuleList:
        """Build ``num_gnn_layers`` independent ``Linear(node_dim, node_dim) + Tanh`` layers."""
        layers = nn.ModuleList()
        for _ in range(num_gnn_layers):
            layers.append(nn.Sequential(
                nn.Linear(node_dim, node_dim),
                nn.Tanh()))
        return layers

    @staticmethod
    def scatter_count(h_edges, indices, edge_mask, max_num_nodes):
        """Sum edge features into nodes and count the contributing edges.

        Args:
            h_edges: (batch, num_edges, latent) edge features.
            indices: (batch, num_edges) target node index for each edge.
            edge_mask: (batch, num_edges) mask of valid edges.
            max_num_nodes: Size of the node dimension of the output.

        Returns:
            ``(h_nodes, count_edge)``, both (batch, max_num_nodes, latent). The
            first holds the per-node feature sums. The second holds the number
            of valid edges scattered to each node, repeated across the latent
            dimension.
        """
        batch_size, _, num_latents = h_edges.shape
        # Match the input's dtype and device so scatter_add_ works for any float type.
        h_nodes = h_edges.new_zeros(batch_size, max_num_nodes, num_latents)
        count_edge = th.zeros_like(h_nodes)
        count = th.broadcast_to(edge_mask.unsqueeze(-1), h_edges.shape).to(h_edges.dtype)
        idx = indices.unsqueeze(-1).expand(-1, -1, num_latents)
        h_nodes = h_nodes.scatter_add_(1, idx, h_edges)
        count_edge = count_edge.scatter_add_(1, idx, count)
        return h_nodes, count_edge

    @staticmethod
    def gather_to_edges(h_nodes, edge_index, edge_mask, gnn_layer):
        """Apply ``gnn_layer`` to the nodes, then gather both endpoints of every edge.

        Args:
            h_nodes: (batch, num_nodes, dim) node features.
            edge_index: (batch, num_edges, 2) endpoint indices per edge.
            edge_mask: (batch, num_edges) mask of valid edges.
            gnn_layer: Module applied to ``h_nodes`` before gathering.

        Returns:
            ``(h_edges_12, h_edges_21)``: the transformed features of endpoint 0
            and endpoint 1 of each edge. Both are zeroed where ``edge_mask`` is
            False.
        """
        h_nodes = gnn_layer(h_nodes)
        dim = h_nodes.size(-1)
        h_edges_12 = th.gather(h_nodes, 1, edge_index[:, :, 0].unsqueeze(-1).expand(-1, -1, dim))
        h_edges_21 = th.gather(h_nodes, 1, edge_index[:, :, 1].unsqueeze(-1).expand(-1, -1, dim))
        mask = th.broadcast_to(edge_mask.unsqueeze(-1), h_edges_12.shape)
        h_edges_12 = th.where(mask, h_edges_12, th.zeros_like(h_edges_12))
        h_edges_21 = th.where(mask, h_edges_21, th.zeros_like(h_edges_21))
        return h_edges_12, h_edges_21

    @classmethod
    def scatter_to_nodes(cls, h_edges, edge_index, edge_mask, node_mask):
        """Average each node's neighbour features over the given edges (both directions).

        Endpoint-1 features are sent to endpoint 0 and endpoint-0 features to
        endpoint 1, so every node receives the mean of its neighbours. Padded
        nodes (``node_mask`` False) keep their raw sum (divisor 1), as before.
        Real nodes with no incident edges now get zeros instead of NaN.
        """
        h_edges_12, h_edges_21 = h_edges
        max_num_nodes = node_mask.shape[1]
        h_nodes_1, count_1 = cls.scatter_count(h_edges_21, edge_index[:, :, 0], edge_mask, max_num_nodes)
        h_nodes_2, count_2 = cls.scatter_count(h_edges_12, edge_index[:, :, 1], edge_mask, max_num_nodes)
        h_nodes_sum = h_nodes_1 + h_nodes_2
        mask = th.broadcast_to(node_mask.unsqueeze(-1), h_nodes_sum.shape)
        count = th.where(mask, count_1 + count_2, th.ones_like(h_nodes_sum))
        # Isolated real nodes have count 0 and sum 0; avoid 0 / 0 = NaN.
        count = count.masked_fill(count == 0, 1.0)
        return h_nodes_sum / count

    def forward(self, observations: TensorDict) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
        """Embed the graph and return ``compute_state`` over dynamic-edge features.

        Each dynamic edge's feature is the concatenation of its two endpoint
        embeddings, giving a width of 2 * node_dim.
        """
        node_features = observations['node_features']
        h_nodes = self.node_encoder(node_features)
        edge_static_index = observations['static_adjacency_list'].long()
        edge_dynamic_index = observations['dynamic_adjacency_list'].long()
        node_mask = observations['node_mask'].bool()
        static_edge_mask = observations['static_edge_mask'].bool()
        dynamic_edge_mask = observations['dynamic_edge_mask'].bool()
        for gnn_layer in self.gnn_layers:
            h_edges = self.gather_to_edges(h_nodes, edge_static_index, static_edge_mask, gnn_layer)
            # Residual update with the neighbour mean over the static graph.
            h_nodes = h_nodes + self.scatter_to_nodes(h_edges, edge_static_index, static_edge_mask, node_mask)
        h_edges12, h_edges21 = self.gather_to_edges(
            h_nodes, edge_dynamic_index, dynamic_edge_mask, self.single_gnn_layer)
        h_edges = th.cat([h_edges12, h_edges21], dim=-1)
        return compute_state(observations, h_edges)

    @staticmethod
    def get_policy_feature_dim(node_dim: int) -> int:
        # Policy features are the two concatenated endpoint embeddings.
        return node_dim * 2

    @staticmethod
    def get_value_feature_dim(node_dim: int) -> int:
        # Value features are the mean over edges of the same 2 * node_dim embeddings.
        return node_dim * 2
class FacilityLocationAttentionGNNExtractor(FacilityLocationGNNExtractor):
    """GNN extractor with a self-attention pass over nodes before the edge readout.

    It uses the same static-graph message passing and dynamic-edge readout as
    the parent. Node embeddings go through multi-head self-attention between
    the two stages. The output width therefore matches the inherited
    2 * node_dim feature dims.
    """

    def __init__(
        self,
        observation_space: spaces.Dict,
        num_gnn_layers: int = 2,
        node_dim: int = 32,
    ) -> None:
        # The parent already builds node_encoder, gnn_layers and single_gnn_layer
        # with these exact arguments; re-creating them here was redundant.
        super().__init__(observation_space, num_gnn_layers, node_dim)
        # num_heads == node_dim, so each head has dimension 1 (kept as-is).
        # The module uses the default (seq, batch, embed) layout; forward transposes.
        self.attention = nn.MultiheadAttention(node_dim, node_dim)

    def forward(self, observations: TensorDict) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
        """Embed nodes, apply masked self-attention, then read out dynamic edges."""
        node_features = observations['node_features']
        h_nodes = self.node_encoder(node_features)
        edge_static_index = observations['static_adjacency_list'].long()
        edge_dynamic_index = observations['dynamic_adjacency_list'].long()
        node_mask = observations['node_mask'].bool()
        # Same key the parent pairs with 'static_adjacency_list' (was 'edge_mask').
        static_edge_mask = observations['static_edge_mask'].bool()
        dynamic_edge_mask = observations['dynamic_edge_mask'].bool()
        for gnn_layer in self.gnn_layers:
            h_edges = self.gather_to_edges(h_nodes, edge_static_index, static_edge_mask, gnn_layer)
            h_nodes = h_nodes + self.scatter_to_nodes(h_edges, edge_static_index, static_edge_mask, node_mask)
        # h_nodes is (batch, nodes, dim), while MultiheadAttention expects
        # (seq, batch, dim) by default. Without the transpose it would attend
        # across the batch. Padded nodes are excluded from the keys.
        h_seq = h_nodes.transpose(0, 1)
        h_seq = self.attention(h_seq, h_seq, h_seq, key_padding_mask=~node_mask, need_weights=False)[0]
        h_nodes = h_seq.transpose(0, 1)
        # compute_state expects per-dynamic-edge features, so read out edges
        # exactly like the parent does.
        h_edges12, h_edges21 = self.gather_to_edges(
            h_nodes, edge_dynamic_index, dynamic_edge_mask, self.single_gnn_layer)
        return compute_state(observations, th.cat([h_edges12, h_edges21], dim=-1))