# MFLP/facility_location/agent/features_extractor.py
# 苏泓源
# update
# a257639
from collections import OrderedDict
from typing import Tuple
from gym import spaces
import torch as th
from torch import nn
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
from stable_baselines3.common.type_aliases import TensorDict
import time
def mean_features(h: th.Tensor, mask: th.Tensor):
    """Average the feature vectors of ``h`` over the positions selected by ``mask``.

    Args:
        h: Feature tensor; ``dim=1`` is the axis that gets averaged and the
            last axis holds the features.
        mask: Selector aligned with the first two axes of ``h``. It is cast to
            float, so ``True`` / non-zero entries contribute to the mean.

    Returns:
        The masked mean with ``dim=1`` reduced away. Rows whose mask selects
        nothing yield zeros instead of NaN.
    """
    float_mask = mask.float()
    masked_sum = (h * float_mask.unsqueeze(-1)).sum(dim=1)
    count = float_mask.sum(dim=1, keepdim=True)
    # A row with nothing selected has count == 0, and 0 / 0 would produce NaN
    # that then spreads through every downstream layer and gradient. Only the
    # exact-zero case is replaced, so non-empty rows are computed as before.
    safe_count = th.where(count == 0, th.ones_like(count), count)
    return masked_sum / safe_count
# def compute_state(observations: TensorDict, h_nodes: th.Tensor):
# node_mask = observations['node_mask'].bool()
# mean_h_nodes = mean_features(h_nodes, node_mask)
# old_facility_mask = observations['old_facility_mask'].bool()
# h_old_facility = mean_features(h_nodes, old_facility_mask)
# h_old_facility_repeat = h_old_facility.unsqueeze(1).expand(-1, h_nodes.shape[1], -1)
# state_policy_old_facility = th.cat([
# h_nodes,
# h_old_facility_repeat,
# h_nodes - h_old_facility_repeat,
# h_nodes * h_old_facility_repeat], dim=-1)
# new_facility_mask = observations['new_facility_mask'].bool()
# h_new_facility = mean_features(h_nodes, new_facility_mask)
# h_new_facility_repeat = h_new_facility.unsqueeze(1).expand(-1, h_nodes.shape[1], -1)
# state_policy_new_facility = th.cat([
# h_nodes,
# h_new_facility_repeat,
# h_nodes - h_new_facility_repeat,
# state_value = th.cat([
# mean_h_nodes,
# h_old_facility,
# h_new_facility], dim=-1)
# return state_policy_old_facility, state_policy_new_facility, state_value, old_facility_mask, new_facility_mask
def compute_state(observations: TensorDict, h_edges: th.Tensor):
    """Build the policy input, value input and edge mask from edge embeddings.

    Args:
        observations: Observation dict; only ``'dynamic_edge_mask'`` is read.
        h_edges: Per-edge embeddings. They are returned unchanged as the
            policy input and mean-pooled over ``dim=1`` for the value input.

    Returns:
        A 3-tuple ``(state_policy_facility_pair, state_value,
        dynamic_edge_mask)``: the untouched ``h_edges``, their masked mean
        from ``mean_features``, and the mask converted to ``bool``.
    """
    edge_mask = observations['dynamic_edge_mask'].bool()
    # The value input is a single pooled vector per batch element.
    pooled = mean_features(h_edges, edge_mask)
    return h_edges, pooled, edge_mask
class FacilityLocationMLPExtractor(BaseFeaturesExtractor):
    """Features extractor that embeds every node with one shared MLP.

    NOTE(review): ``forward`` passes *node* embeddings to ``compute_state``,
    which pools its input with ``observations['dynamic_edge_mask']`` and
    returns three tensors. The return annotation still lists five tensors, and
    the 4x / 3x multipliers in the feature-dim helpers match the commented-out
    legacy ``compute_state`` kept in this file (four concatenated tensors for
    the policy, three for the value). This class looks out of sync with the
    current ``compute_state`` -- confirm before relying on it.
    """

    def __init__(
        self,
        observation_space: spaces.Dict,
        hidden_units: Tuple = (32, 32),
    ) -> None:
        # features_dim is fixed to 1 here; presumably a placeholder because
        # the real widths come from the get_*_feature_dim helpers -- confirm.
        super().__init__(observation_space, features_dim=1)
        num_node_features = observation_space.spaces['node_features'].shape[1]
        self.node_mlp = self.create_mlp(num_node_features, hidden_units)

    @staticmethod
    def create_mlp(input_dim: int, hidden_units: Tuple) -> nn.Sequential:
        """Stack ``Linear -> Tanh`` pairs, one pair per entry of ``hidden_units``.

        Args:
            input_dim: Width of the input to the first linear layer.
            hidden_units: Output width of each successive linear layer.

        Returns:
            An ``nn.Sequential`` whose sub-modules are named
            ``mlp-extractor-linear_{i}`` and ``mlp-extractor-tanh_{i}``.
        """
        modules = OrderedDict()
        in_features = input_dim
        for idx, out_features in enumerate(hidden_units):
            modules[f'mlp-extractor-linear_{idx}'] = nn.Linear(in_features, out_features)
            modules[f'mlp-extractor-tanh_{idx}'] = nn.Tanh()
            # The next layer consumes what this one produces.
            in_features = out_features
        return nn.Sequential(modules)

    def forward(self, observations: TensorDict) -> Tuple[th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor]:
        """Embed ``observations['node_features']`` and hand the result to ``compute_state``."""
        h_nodes = self.node_mlp(observations['node_features'])
        return compute_state(observations, h_nodes)

    @staticmethod
    def get_policy_feature_dim(node_dim: int) -> int:
        """Return the policy feature width, ``4 * node_dim``."""
        return 4 * node_dim

    @staticmethod
    def get_value_feature_dim(node_dim: int) -> int:
        """Return the value feature width, ``3 * node_dim``."""
        return 3 * node_dim
class FacilityLocationGNNExtractor(BaseFeaturesExtractor):
    """Features extractor that runs message passing over the static graph.

    Node features are encoded, refined by ``num_gnn_layers`` rounds of
    neighbour averaging over ``static_adjacency_list`` (with a residual
    connection), and finally read out on the edges of
    ``dynamic_adjacency_list`` by concatenating the embeddings of both
    endpoints.
    """

    def __init__(
        self,
        observation_space: spaces.Dict,
        num_gnn_layers: int = 2,
        node_dim: int = 32,
    ) -> None:
        # features_dim is fixed to 1 here; presumably a placeholder because
        # the real widths come from the get_*_feature_dim helpers -- confirm.
        super().__init__(observation_space, features_dim=1)
        num_node_features = observation_space.spaces['node_features'].shape[1]
        self.node_encoder = self.create_node_encoder(num_node_features, node_dim)
        self.gnn_layers = self.create_gnn(num_gnn_layers, node_dim)
        # Separate layer applied once before the dynamic-edge read-out.
        self.single_gnn_layer = self.create_gnn(1, node_dim)[0]

    @staticmethod
    def create_node_encoder(num_node_features: int, node_dim: int) -> nn.Sequential:
        """Return the ``Linear -> Tanh`` encoder from raw node features to ``node_dim``."""
        return nn.Sequential(
            nn.Linear(num_node_features, node_dim),
            nn.Tanh(),
        )

    @staticmethod
    def create_gnn(num_gnn_layers: int, node_dim: int) -> nn.ModuleList:
        """Return ``num_gnn_layers`` independent ``Linear -> Tanh`` node transforms."""
        return nn.ModuleList(
            nn.Sequential(nn.Linear(node_dim, node_dim), nn.Tanh())
            for _ in range(num_gnn_layers)
        )

    @staticmethod
    def scatter_count(h_edges, indices, edge_mask, max_num_nodes):
        """Sum edge embeddings into node slots and count the valid edges per slot.

        Args:
            h_edges: Edge embeddings, ``(batch, edges, latent)``. Every edge is
                summed, so masked edges must already be zero (as produced by
                ``gather_to_edges``).
            indices: Target node index of every edge, ``(batch, edges)``.
            edge_mask: Boolean validity of every edge, ``(batch, edges)``; only
                valid edges increase the count.
            max_num_nodes: Number of node slots in the output.

        Returns:
            ``(h_nodes, count_edge)``, both ``(batch, max_num_nodes, latent)``.
        """
        batch_size, _, num_latents = h_edges.shape
        # Allocate directly with the dtype and device of the edge embeddings
        # rather than building a float32 CPU tensor and moving it afterwards;
        # this also keeps scatter_add_ valid for non-float32 inputs.
        h_nodes = h_edges.new_zeros(batch_size, max_num_nodes, num_latents)
        count_edge = th.zeros_like(h_nodes)
        count = th.broadcast_to(edge_mask.unsqueeze(-1), h_edges.shape).to(h_edges.dtype)
        idx = indices.unsqueeze(-1).expand(-1, -1, num_latents)
        h_nodes = h_nodes.scatter_add_(1, idx, h_edges)
        count_edge = count_edge.scatter_add_(1, idx, count)
        return h_nodes, count_edge

    @staticmethod
    def gather_to_edges(h_nodes, edge_index, edge_mask, gnn_layer):
        """Transform node embeddings and copy them onto both ends of every edge.

        Args:
            h_nodes: Node embeddings, ``(batch, nodes, latent)``.
            edge_index: Endpoint indices, ``(batch, edges, 2)``.
            edge_mask: Boolean validity of every edge, ``(batch, edges)``.
            gnn_layer: Module applied to ``h_nodes`` before gathering.

        Returns:
            ``(h_edges_12, h_edges_21)``: the transformed embedding of endpoint
            0 and of endpoint 1 for every edge, zeroed where the edge is masked.
        """
        h_nodes = gnn_layer(h_nodes)
        latent_dim = h_nodes.size(-1)
        first_idx = edge_index[:, :, 0].unsqueeze(-1).expand(-1, -1, latent_dim)
        second_idx = edge_index[:, :, 1].unsqueeze(-1).expand(-1, -1, latent_dim)
        h_edges_12 = th.gather(h_nodes, 1, first_idx)
        h_edges_21 = th.gather(h_nodes, 1, second_idx)
        mask = th.broadcast_to(edge_mask.unsqueeze(-1), h_edges_12.shape)
        h_edges_12 = th.where(mask, h_edges_12, th.zeros_like(h_edges_12))
        h_edges_21 = th.where(mask, h_edges_21, th.zeros_like(h_edges_21))
        return h_edges_12, h_edges_21

    @classmethod
    def scatter_to_nodes(cls, h_edges, edge_index, edge_mask, node_mask):
        """Average, for every node, the embeddings of the nodes it shares an edge with.

        Args:
            h_edges: The ``(h_edges_12, h_edges_21)`` pair from ``gather_to_edges``.
            edge_index: Endpoint indices, ``(batch, edges, 2)``.
            edge_mask: Boolean validity of every edge, ``(batch, edges)``.
            node_mask: Boolean validity of every node, ``(batch, nodes)``.

        Returns:
            Mean neighbour embedding per node, ``(batch, nodes, latent)``.
            Nodes that receive no valid edge get zeros.
        """
        h_edges_12, h_edges_21 = h_edges
        max_num_nodes = node_mask.shape[1]
        # Each endpoint receives the embedding of the *other* endpoint.
        h_nodes_1, count_1 = cls.scatter_count(h_edges_21, edge_index[:, :, 0], edge_mask, max_num_nodes)
        h_nodes_2, count_2 = cls.scatter_count(h_edges_12, edge_index[:, :, 1], edge_mask, max_num_nodes)
        h_nodes_sum = h_nodes_1 + h_nodes_2
        mask = th.broadcast_to(node_mask.unsqueeze(-1), h_nodes_sum.shape)
        count = count_1 + count_2
        # Padding nodes divide by 1, as before.
        count = th.where(mask, count, th.ones_like(count))
        # A valid node without any valid incident edge has count == 0; dividing
        # would give 0 / 0 = NaN and corrupt the residual update in forward().
        # Counts are whole numbers, so the clamp only affects that zero case.
        count = count.clamp(min=1)
        return h_nodes_sum / count

    def forward(self, observations: TensorDict) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
        """Compute edge embeddings for the dynamic edges and build the state.

        Reads ``node_features``, ``static_adjacency_list``,
        ``dynamic_adjacency_list``, ``node_mask``, ``static_edge_mask`` and
        ``dynamic_edge_mask`` from ``observations``.

        Returns:
            The 3-tuple produced by ``compute_state``: policy features, value
            features and the dynamic edge mask.
        """
        h_nodes = self.node_encoder(observations['node_features'])
        edge_static_index = observations['static_adjacency_list'].long()
        edge_dynamic_index = observations['dynamic_adjacency_list'].long()
        node_mask = observations['node_mask'].bool()
        static_edge_mask = observations['static_edge_mask'].bool()
        dynamic_edge_mask = observations['dynamic_edge_mask'].bool()

        for gnn_layer in self.gnn_layers:
            h_edges = self.gather_to_edges(h_nodes, edge_static_index, static_edge_mask, gnn_layer)
            h_nodes_new = self.scatter_to_nodes(h_edges, edge_static_index, static_edge_mask, node_mask)
            # Residual connection around every message-passing round.
            h_nodes = h_nodes + h_nodes_new

        h_edges_12, h_edges_21 = self.gather_to_edges(
            h_nodes, edge_dynamic_index, dynamic_edge_mask, self.single_gnn_layer)
        # Both endpoint embeddings side by side -> feature width 2 * node_dim.
        h_edges = th.cat([h_edges_12, h_edges_21], dim=-1)
        return compute_state(observations, h_edges)

    @staticmethod
    def get_policy_feature_dim(node_dim: int) -> int:
        """Return the policy feature width: two concatenated endpoint embeddings."""
        return node_dim * 2

    @staticmethod
    def get_value_feature_dim(node_dim: int) -> int:
        """Return the value feature width: the pooled two-endpoint edge embedding."""
        return node_dim * 2
class FacilityLocationAttentionGNNExtractor(FacilityLocationGNNExtractor):
    """GNN extractor with a self-attention step over nodes before the read-out.

    Runs the same static-graph message passing as the parent class, lets every
    node attend to all valid nodes of its own graph, and then reads out edge
    embeddings on the dynamic edges exactly like the parent. The output
    therefore matches the ``2 * node_dim`` widths reported by the inherited
    ``get_policy_feature_dim`` / ``get_value_feature_dim``.
    """

    def __init__(
        self,
        observation_space: spaces.Dict,
        num_gnn_layers: int = 2,
        node_dim: int = 32,
    ) -> None:
        # The parent constructor already builds node_encoder, gnn_layers and
        # single_gnn_layer from these same arguments, so they are not rebuilt.
        super().__init__(observation_space, num_gnn_layers, node_dim)
        # NOTE(review): the second argument is num_heads, so every head has
        # width 1 -- confirm that this is intended.
        self.attention = nn.MultiheadAttention(node_dim, node_dim)

    def forward(self, observations: TensorDict) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
        """Compute attention-refined edge embeddings and build the state.

        Reads ``node_features``, ``static_adjacency_list``,
        ``dynamic_adjacency_list``, ``node_mask``, ``dynamic_edge_mask`` and
        ``static_edge_mask`` (falling back to ``edge_mask``) from
        ``observations``.

        Returns:
            The 3-tuple produced by ``compute_state``: policy features, value
            features and the dynamic edge mask.
        """
        h_nodes = self.node_encoder(observations['node_features'])
        edge_static_index = observations['static_adjacency_list'].long()
        edge_dynamic_index = observations['dynamic_adjacency_list'].long()
        node_mask = observations['node_mask'].bool()
        dynamic_edge_mask = observations['dynamic_edge_mask'].bool()
        # The parent class pairs the static adjacency list with
        # 'static_edge_mask'; this class used to read 'edge_mask', which is
        # kept as a fallback for observations that only provide that key.
        static_mask_key = 'static_edge_mask' if 'static_edge_mask' in observations else 'edge_mask'
        static_edge_mask = observations[static_mask_key].bool()

        for gnn_layer in self.gnn_layers:
            h_edges = self.gather_to_edges(h_nodes, edge_static_index, static_edge_mask, gnn_layer)
            h_nodes_new = self.scatter_to_nodes(h_edges, edge_static_index, static_edge_mask, node_mask)
            h_nodes = h_nodes + h_nodes_new

        # nn.MultiheadAttention expects (sequence, batch, embed) by default.
        # h_nodes is (batch, nodes, embed), so without the transpose attention
        # would mix the same node slot across different batch elements.
        h_seq = h_nodes.transpose(0, 1)
        # Keep padding nodes out of the keys. A row without any valid node is
        # left unmasked, because masking every key would yield NaN.
        padding_mask = ~node_mask & node_mask.any(dim=1, keepdim=True)
        h_seq = self.attention(h_seq, h_seq, h_seq, key_padding_mask=padding_mask)[0]
        h_nodes = h_seq.transpose(0, 1)

        # Read out on the dynamic edges like the parent class: compute_state
        # pools with 'dynamic_edge_mask', so it needs per-edge embeddings.
        h_edges_12, h_edges_21 = self.gather_to_edges(
            h_nodes, edge_dynamic_index, dynamic_edge_mask, self.single_gnn_layer)
        h_edges = th.cat([h_edges_12, h_edges_21], dim=-1)
        return compute_state(observations, h_edges)