from typing import Dict, Tuple, Text

import numpy as np

from facility_location.env.facility_location_client import FacilityLocationClient
from facility_location.utils.config import Config


class ObsExtractor:
    """Turn the state of a ``FacilityLocationClient`` into padded graph observations.

    Node layout used by every array produced here:

    * row ``0`` is a "virtual" node with a fixed feature vector,
    * rows ``1 .. n`` hold the ``n`` real nodes of the current instance,
    * the remaining rows up to ``node_range`` are zero padding.

    Each node feature vector is the concatenation of 7 dynamic features
    (refreshed on every ``get_obs`` call) and 5 static features (computed once
    per ``reset``). Edge lists are padded to ``edge_range`` rows; padding rows
    point at node index ``node_range - 1`` and are flagged ``False`` in the
    matching edge mask.
    """

    # Added to the gain/loss maxima so that an all-zero vector does not divide by zero.
    _EPS = 1e-8

    def __init__(self, cfg: Config, flc: FacilityLocationClient, node_range: int, edge_range: int):
        """Store the collaborators and allocate the fixed-size buffers.

        Args:
            cfg: Configuration object; only stored on ``self.cfg`` here.
            flc: Client queried for the instance and the current solution state.
            node_range: Number of rows of the padded node arrays.
            edge_range: Number of rows of the padded edge arrays.
        """
        self.cfg = cfg
        self._flc = flc
        self._node_range = node_range
        self._edge_range = edge_range

        # Order matters: the node buffer is sized from the virtual node feature.
        self._construct_virtual_node_feature()
        self._construct_node_features()
        self._construct_action_mask()

    def _construct_virtual_node_feature(self) -> None:
        """Build the constant feature vector assigned to the virtual node (row 0)."""
        self._virtual_dynamic_node_feature = np.array([
            0,  # facility
            0,  # distance_min
            0,  # distance_sub_min
            0,  # cost_min
            0,  # cost_sub_min
            0,  # gain
            0,  # loss
        ], dtype=np.float32)
        self._virtual_static_node_feature = np.array([
            0.5,  # x
            0.5,  # y
            1,    # demand
            0,    # avg_distance
            0,    # avg_cost
        ], dtype=np.float32)
        # Dynamic features come first; the column slices used elsewhere rely on this.
        self._virtual_node_feature = np.concatenate([
            self._virtual_dynamic_node_feature,
            self._virtual_static_node_feature,
        ], axis=-1)

    def _construct_node_features(self) -> None:
        """Allocate the ``(node_range, node_dim)`` float32 node feature buffer."""
        self._node_features = np.zeros(
            (self._node_range, self._virtual_node_feature.size), dtype=np.float32)

    def _construct_action_mask(self) -> None:
        """Allocate the two boolean action-mask buffers, initially all ``False``."""
        self._old_facility_mask = np.full(self._node_range, False)
        self._new_facility_mask = np.full(self._node_range, False)

    def get_node_dim(self) -> int:
        """Return the length of one node feature vector (dynamic + static)."""
        return self._virtual_node_feature.size

    def reset(self) -> None:
        """Re-read the instance from the client and reset all cached buffers."""
        self._compute_static_obs()
        self._reset_node_features()
        self._reset_action_mask()

    @staticmethod
    def _normalize_by_max(values: np.ndarray) -> np.ndarray:
        """Divide ``values`` by their maximum, leaving them unscaled when it is 0.

        A zero maximum would otherwise produce ``0 / 0 = NaN`` entries. For any
        non-zero maximum the result is exactly ``values / np.max(values)``.
        """
        values = np.asarray(values)
        peak = np.max(values)
        return values / (peak if peak != 0 else 1)

    def _compute_static_obs(self) -> None:
        """Cache the per-instance static node features, node mask and static edges.

        Raises:
            ValueError: If the nodes (virtual + real) or the static edges
                (virtual + real) do not fit into ``node_range`` / ``edge_range``.
        """
        xy, demands, n, _ = self._flc.get_instance()
        if n + 2 > self._node_range:
            # NOTE(review): this is only reported, not raised. With
            # n + 1 == node_range the padding index ``node_range - 1`` used for
            # edges coincides with the last real node; larger n is rejected by
            # ``_pad_mask`` below. Confirm whether this should become an error.
            print(n, self._node_range)
        # ``_pad_edge`` below reads ``self._n``, so it must be set first.
        self._n = n

        avg_distance, avg_cost = self._flc.get_avg_distance_and_cost()
        avg_distance = self._normalize_by_max(avg_distance)
        avg_cost = self._normalize_by_max(avg_cost)

        # Column order must match ``_virtual_static_node_feature``.
        self._static_node_features = np.stack([
            xy[:, 0],
            xy[:, 1],
            demands,
            avg_distance,
            avg_cost,
        ], axis=-1).astype(np.float32)

        static_adjacency_list = self._flc.get_static_adjacency_list()

        # One virtual node plus n real nodes are valid.
        obs_node_mask = np.full(1 + n, True)
        self._obs_node_mask = self._pad_mask(obs_node_mask, self._node_range, 'nodes')

        # ``_pad_edge`` prepends n virtual edges to the static edges.
        obs_static_edge_mask = np.full(n + static_adjacency_list.shape[0], True)
        self._obs_static_edge_mask = self._pad_mask(
            obs_static_edge_mask, self._edge_range, 'edges')
        self._static_adjacency_list = self._pad_edge(static_adjacency_list)

    def _reset_node_features(self) -> None:
        """Zero the node buffer, then write the virtual node and the static columns."""
        num_dynamic = len(self._virtual_dynamic_node_feature)
        self._node_features[:, :] = 0
        self._node_features[0] = self._virtual_node_feature
        self._node_features[1:self._n + 1, num_dynamic:] = self._static_node_features

    def _reset_action_mask(self) -> None:
        """Clear both action-mask buffers in place."""
        self._old_facility_mask[:] = False
        self._new_facility_mask[:] = False

    def get_obs(self, t: int) -> Dict[Text, np.ndarray]:
        """Return the current observation as a dict of padded arrays.

        Args:
            t: Time step; accepted for interface compatibility and not used here.

        Returns:
            A dict with keys ``node_features``, ``static_adjacency_list``,
            ``dynamic_adjacency_list``, ``node_mask``, ``static_edge_mask`` and
            ``dynamic_edge_mask``. ``node_features``, ``static_adjacency_list``,
            ``node_mask`` and ``static_edge_mask`` are the extractor's own
            buffers (not copies); ``node_features`` is overwritten in place by
            the next call.
        """
        (obs_nodes, obs_static_edges, obs_dynamic_edges,
         obs_node_mask, obs_static_edge_mask, obs_dynamic_edges_mask) = self._get_obs_graph()

        return {
            'node_features': obs_nodes,
            'static_adjacency_list': obs_static_edges,
            'dynamic_adjacency_list': obs_dynamic_edges,
            'node_mask': obs_node_mask,
            'static_edge_mask': obs_static_edge_mask,
            'dynamic_edge_mask': obs_dynamic_edges_mask,
        }

    def _get_obs_graph(self) -> Tuple[np.ndarray, np.ndarray, np.ndarray,
                                      np.ndarray, np.ndarray, np.ndarray]:
        """Refresh the dynamic node features and assemble the padded graph arrays.

        Returns:
            ``(nodes, static_edges, dynamic_edges, node_mask, static_edge_mask,
            dynamic_edge_mask)``.

        Raises:
            ValueError: If the dynamic edge list has more than ``edge_range`` rows.
        """
        facility = self._flc.get_current_solution().astype(np.float32)
        distance = self._normalize_by_max(self._flc.get_current_distance().astype(np.float32))
        cost = self._normalize_by_max(self._flc.get_current_cost().astype(np.float32))
        gain, loss = self._flc.get_gain_and_loss()
        gain = gain / (np.max(gain) + self._EPS)
        loss = loss / (np.max(loss) + self._EPS)

        # Column order must match ``_virtual_dynamic_node_feature``.
        dynamic_node_features = np.stack([
            facility,
            distance[:, 0],
            distance[:, 1],
            cost[:, 0],
            cost[:, 1],
            gain,
            loss,
        ], axis=-1)
        num_dynamic = len(self._virtual_dynamic_node_feature)
        self._node_features[1:self._n + 1, :num_dynamic] = dynamic_node_features

        obs_nodes = self._node_features
        obs_static_edges = self._static_adjacency_list
        obs_node_mask = self._obs_node_mask
        obs_static_edge_mask = self._obs_static_edge_mask

        # The mask is built from the unpadded edge count, then both are padded.
        raw_dynamic_edges = self._flc.get_dynamic_adjacency_list()
        obs_dynamic_edge_mask = np.full(raw_dynamic_edges.shape[0], True)
        obs_dynamic_edges = self._pad_edge_wo_virtual(raw_dynamic_edges)
        obs_dynamic_edge_mask = self._pad_mask(obs_dynamic_edge_mask, self._edge_range, 'edges')

        return (obs_nodes, obs_static_edges, obs_dynamic_edges,
                obs_node_mask, obs_static_edge_mask, obs_dynamic_edge_mask)

    def _get_obs_action_mask(self, t: int) -> Tuple[np.ndarray, np.ndarray]:
        """Combine the client's facility and tabu masks into padded action masks.

        Args:
            t: Time step forwarded to ``get_tabu_mask``.

        Returns:
            ``(old_facility_mask, new_facility_mask)``; both are the extractor's
            internal buffers, with real nodes stored at indices ``1 .. n``.

        Raises:
            ValueError: If either resulting mask has no ``True`` entry.
        """
        old_facility_mask, new_facility_mask = self._flc.get_facility_mask()
        old_tabu_mask, new_tabu_mask = self._flc.get_tabu_mask(t)

        self._old_facility_mask[1:self._n + 1] = np.logical_and(old_facility_mask, old_tabu_mask)
        self._new_facility_mask[1:self._n + 1] = np.logical_and(new_facility_mask, new_tabu_mask)

        obs_old_facility_mask = self._old_facility_mask
        obs_new_facility_mask = self._new_facility_mask
        if not np.any(obs_old_facility_mask) or not np.any(obs_new_facility_mask):
            raise ValueError('The action mask is empty.')
        return obs_old_facility_mask, obs_new_facility_mask

    @staticmethod
    def _pad_mask(mask: np.ndarray, max_num: int, name: Text) -> np.ndarray:
        """Right-pad a 1-D mask with ``False`` up to ``max_num`` entries.

        Args:
            mask: Mask to pad.
            max_num: Target length.
            name: Noun used in the error message (e.g. ``'nodes'``).

        Raises:
            ValueError: If ``mask`` already has more than ``max_num`` entries.
        """
        pad = (0, max_num - mask.size)
        if pad[1] < 0:
            raise ValueError(f'The number of {name} exceeds the maximum limit.')
        return np.pad(mask, pad, mode='constant', constant_values=False)

    def _pad_edge_rows(self, edge: np.ndarray) -> np.ndarray:
        """Append rows filled with ``node_range - 1`` until ``edge`` has ``edge_range`` rows.

        Raises:
            ValueError: If ``edge`` already has more than ``edge_range`` rows.
        """
        missing = self._edge_range - edge.shape[0]
        if missing < 0:
            raise ValueError('The number of edges exceeds the maximum limit.')
        return np.pad(edge, ((0, missing), (0, 0)),
                      mode='constant', constant_values=self._node_range - 1)

    def _pad_edge(self, edge: np.ndarray) -> np.ndarray:
        """Prepend virtual-node edges to ``edge`` and pad to ``edge_range`` rows.

        One edge ``(0, i)`` is added for every real node ``i`` in ``1 .. n``, and
        the indices in ``edge`` are shifted by one because index 0 is the
        virtual node.

        Raises:
            ValueError: If the combined edge list exceeds ``edge_range`` rows.
        """
        virtual_edge = np.stack(
            [np.zeros(self._n), np.arange(1, self._n + 1)], axis=-1).astype(np.int32)
        edge = np.concatenate([virtual_edge, edge + 1], axis=0)
        return self._pad_edge_rows(edge)

    def _pad_edge_wo_virtual(self, edge: np.ndarray) -> np.ndarray:
        """Shift ``edge`` indices by one and pad to ``edge_range`` rows.

        Unlike ``_pad_edge``, no virtual-node edges are added.

        Raises:
            ValueError: If ``edge`` has more than ``edge_range`` rows.
        """
        if edge.shape[0] > self._edge_range:
            # Report the offending sizes before the helper raises.
            print(self._edge_range, edge.shape[0])
        return self._pad_edge_rows(edge + 1)