ubuntu commited on
Commit
fc0a115
·
1 Parent(s): 68e5f5a

try fix bugs

Browse files
__pycache__/genn_astar.cpython-39.pyc ADDED
Binary file (4.78 kB). View file
 
__pycache__/one_hot.cpython-39.pyc ADDED
Binary file (1.54 kB). View file
 
best_genn_AIDS700nef_gcn_astar.pt DELETED
Binary file (45.4 kB)
 
genn_astar.py CHANGED
@@ -1,187 +1,191 @@
1
- import os
2
- import numpy as np
3
- import networkx as nx
4
- import pygmtools as pygm
5
- import torch
6
- try:
7
- from torch_geometric.data import Data
8
- except:
9
- os.system("pip install --no-index torch-sparse -f https://pytorch-geometric.com/whl/torch-2.0.0%2Bcpu.html")
10
- os.system("pip install --no-index torch-scatter -f https://pytorch-geometric.com/whl/torch-2.0.0%2Bcpu.html")
11
- os.system("pip install --no-index torch-spline-conv -f https://pytorch-geometric.com/whl/torch-2.0.0%2Bcpu.html")
12
- os.system("pip install --no-index torch-cluster -f https://pytorch-geometric.com/whl/torch-2.0.0%2Bcpu.html")
13
- from torch_geometric.data import Data
14
- from one_hot import one_hot
15
- from torch_geometric.transforms import OneHotDegree
16
- import matplotlib.pyplot as plt
17
- import pygmtools as pygm
18
- pygm.set_backend('pytorch')
19
-
20
-
21
- ######################################################
22
- # Constant Variable #
23
- ######################################################
24
-
25
- AIDS700NEF_TYPE = [
26
- 'O', 'S', 'C', 'N', 'Cl', 'Br', 'B', 'Si', 'Hg', 'I', 'Bi', 'P', 'F',
27
- 'Cu', 'Ho', 'Pd', 'Ru', 'Pt', 'Sn', 'Li', 'Ga', 'Tb', 'As', 'Co', 'Pb',
28
- 'Sb', 'Se', 'Ni', 'Te'
29
- ]
30
-
31
-
32
- COLOR = [
33
- '#FF69B4', # O - 热情的粉红色
34
- '#00CED1', # S - 深蓝绿色
35
- '#FFD700', # C - 金色
36
- '#FFA500', # N - 橙色
37
- '#FF6347', # Cl - 番茄红色
38
- '#8B008B', # Br - 深洋红色
39
- '#00FF7F', # B - 春天的绿色
40
- '#40E0D0', # Si - 绿松石色
41
- '#FF4500', # Hg - 橙红色
42
- '#9932CC', # I - 深兰花紫色
43
- '#9370DB', # Bi - 中紫色
44
- '#FFA500', # P - 橙色
45
- '#FFFF00', # F - 黄色
46
- '#B8860B', # Cu - 深金色
47
- '#7FFFD4', # Ho - 碧绿色
48
- '#FFD700', # Pd - 金色
49
- '#B22222', # Ru - 砖红色
50
- '#E5E4E2', # Pt - 浅灰色
51
- '#A9A9A9', # Sn - 深灰色
52
- '#32CD32', # Li - 酸橙色
53
- '#CD853F', # Ga - 秘鲁色
54
- '#7FFFD4', # Tb - 碧绿色
55
- '#8A2BE2', # As - 紫罗兰色
56
- '#FFD700', # Co - 金色
57
- '#808080', # Pb - 灰色
58
- '#A9A9A9', # Sb - 深灰色
59
- '#FA8072', # Se - 鲑鱼色
60
- '#BEBEBE', # Ni - 浅灰色
61
- '#800080' # Te - 紫色
62
- ]
63
-
64
-
65
- ######################################################
66
- # Utils Func #
67
- ######################################################
68
-
69
- def from_gexf(filename: str, node_types: list=None):
70
- r"""
71
- Read Data from GEXF file
72
- """
73
- if not filename.endswith('.gexf'):
74
- raise ValueError("File type error, 'from_gexf' function only supports GEXF files")
75
- graph = nx.read_gexf(filename)
76
- mapping = {name: j for j, name in enumerate(graph.nodes())}
77
- graph = nx.relabel_nodes(graph, mapping)
78
- edge_index = torch.from_numpy(np.array(graph.edges, dtype=np.int64).transpose())
79
- x = None
80
- labels = None
81
- data = None
82
- colors = None
83
- if 'type' in graph.nodes(data=True)[0].keys():
84
- labels = dict()
85
- colors = list()
86
- num_nodes = graph.number_of_nodes()
87
- x = torch.zeros(num_nodes, dtype=torch.long)
88
- node_types = AIDS700NEF_TYPE if node_types is None else node_types
89
- for node, info in graph.nodes(data=True):
90
- x[int(node)] = node_types.index(info['type'])
91
- labels[int(node)] = str(int(node)) + info['type']
92
- colors.append(COLOR[x[int(node)]])
93
- x = one_hot(x, num_classes=len(node_types))
94
- data = Data(x=x, edge_index=edge_index, edge_attr=None)
95
- return graph, data, labels, colors
96
-
97
-
98
- def draw(graph, colors, labels, filename, title, pos_type=None):
99
- if pos_type is None:
100
- pos = nx.kamada_kawai_layout(graph)
101
- elif pos_type == "spring":
102
- pos = nx.spring_layout(graph)
103
- plt.figure()
104
- plt.gca().set_title(title)
105
- nx.draw(graph, pos, with_labels=True, node_color=colors, edge_color='gray', labels=labels)
106
- plt.savefig(filename)
107
- plt.clf()
108
-
109
-
110
- ######################################################
111
- # GED UI #
112
- ######################################################
113
-
114
- def astar(
115
- g1_path: str,
116
- g2_path: str,
117
- output_path: str="examples",
118
- filename: str="example",
119
- device='cpu'
120
- ):
121
- if not os.path.exists(output_path):
122
- os.mkdir(output_path)
123
- output_filename = os.path.join(output_path, filename) + "_{}.png"
124
-
125
- # Load data
126
- g1, d1, l1, c1 = from_gexf(g1_path)
127
- g2, d2, l2, c2 = from_gexf(g2_path)
128
- if len(c1) > len(c2):
129
- graph1, data1, labels1, colors1 = g2, d2, l2, c2
130
- graph2, data2, labels2, colors2 = g1, d1, l1, c1
131
- else:
132
- graph1, data1, labels1, colors1 = g1, d1, l1, c1
133
- graph2, data2, labels2, colors2 = g2, d2, l2, c2
134
-
135
- # Build Graph and Adj Matrix
136
- data1 = OneHotDegree(max_degree=6)(data1)
137
- data2 = OneHotDegree(max_degree=6)(data2)
138
- feat1 = data1.x.to(device)
139
- feat2 = data2.x.to(device)
140
- A1 = torch.tensor(pygm.utils.from_networkx(graph1)).float().to(device)
141
- A2 = torch.tensor(pygm.utils.from_networkx(graph2)).float().to(device)
142
-
143
- # Caculate the ged
144
- x_pred = pygm.genn_astar(feat1, feat2, A1, A2, return_network=False)
145
-
146
- # Plot
147
- draw(graph1, colors1, labels1, output_filename.format(1), "Graph1")
148
- draw(graph2, colors2, labels2, output_filename.format(5), f"Graph2")
149
-
150
- # Match Process
151
- total_cost = 0
152
- labels1_1 = labels1.copy()
153
- for i in range(x_pred.shape[0]):
154
- target = torch.nonzero(x_pred[i])[0].item()
155
- labels1_1[i] = labels1[i].replace(str(i), str(target))
156
- title = "Node Match"
157
- draw(graph1, colors1, labels1_1, output_filename.format(2), title)
158
-
159
- # Node Change
160
- cur_cost = 0
161
- labels1_2 = labels1.copy()
162
- colors1_2 = colors1.copy()
163
- target2ori = dict()
164
- targets = list()
165
- for i in range(x_pred.shape[0]):
166
- target = torch.nonzero(x_pred[i])[0].item()
167
- if labels1_1[i] != labels2[target]:
168
- cur_cost += 1
169
- labels1_2[i] = labels2[target]
170
- colors1_2[i] = colors2[target]
171
- target2ori[target] = i
172
- targets.append(target)
173
- total_cost += cur_cost
174
- title = f"Node Change"
175
- draw(graph1, colors1_2, labels1_2, output_filename.format(3), title)
176
-
177
- # Edge Change
178
- leave_cost = np.array(graph2).shape[0] - np.array(graph1).shape[0]
179
- leave_cost += graph2.number_of_nodes() - graph1.number_of_nodes()
180
- e2 = np.array(graph2.edges)
181
- new_edges = list()
182
- for edge in e2:
183
- if edge[0] in targets and edge[1] in targets:
184
- new_edges.append([target2ori[edge[0]], target2ori[edge[1]]])
185
- graph1.edges = nx.Graph(new_edges).edges
186
- title = f"Edge Change"
 
 
 
 
187
  draw(graph1, colors1_2, labels1_2, output_filename.format(4), title, pos_type="spring")
 
1
+ import os
2
+ import numpy as np
3
+ import networkx as nx
4
+ import pygmtools as pygm
5
+ import torch
6
+ try:
7
+ from torch_geometric.data import Data
8
+ except:
9
+ os.system("pip install --no-index torch-sparse -f https://pytorch-geometric.com/whl/torch-2.0.0%2Bcpu.html")
10
+ os.system("pip install --no-index torch-scatter -f https://pytorch-geometric.com/whl/torch-2.0.0%2Bcpu.html")
11
+ os.system("pip install --no-index torch-spline-conv -f https://pytorch-geometric.com/whl/torch-2.0.0%2Bcpu.html")
12
+ os.system("pip install --no-index torch-cluster -f https://pytorch-geometric.com/whl/torch-2.0.0%2Bcpu.html")
13
+ from torch_geometric.data import Data
14
+ from one_hot import one_hot
15
+ from torch_geometric.transforms import OneHotDegree
16
+ import matplotlib.pyplot as plt
17
+ import pygmtools as pygm
18
+ pygm.set_backend('pytorch')
19
+
20
+
21
+ ######################################################
22
+ # Constant Variable #
23
+ ######################################################
24
+
25
+ AIDS700NEF_TYPE = [
26
+ 'O', 'S', 'C', 'N', 'Cl', 'Br', 'B', 'Si', 'Hg', 'I', 'Bi', 'P', 'F',
27
+ 'Cu', 'Ho', 'Pd', 'Ru', 'Pt', 'Sn', 'Li', 'Ga', 'Tb', 'As', 'Co', 'Pb',
28
+ 'Sb', 'Se', 'Ni', 'Te'
29
+ ]
30
+
31
+
32
+ COLOR = [
33
+ '#FF69B4', # O - 热情的粉红色
34
+ '#00CED1', # S - 深蓝绿色
35
+ '#FFD700', # C - 金色
36
+ '#FFA500', # N - 橙色
37
+ '#FF6347', # Cl - 番茄红色
38
+ '#8B008B', # Br - 深洋红色
39
+ '#00FF7F', # B - 春天的绿色
40
+ '#40E0D0', # Si - 绿松石色
41
+ '#FF4500', # Hg - 橙红色
42
+ '#9932CC', # I - 深兰花紫色
43
+ '#9370DB', # Bi - 中紫色
44
+ '#FFA500', # P - 橙色
45
+ '#FFFF00', # F - 黄色
46
+ '#B8860B', # Cu - 深金色
47
+ '#7FFFD4', # Ho - 碧绿色
48
+ '#FFD700', # Pd - 金色
49
+ '#B22222', # Ru - 砖红色
50
+ '#E5E4E2', # Pt - 浅灰色
51
+ '#A9A9A9', # Sn - 深灰色
52
+ '#32CD32', # Li - 酸橙色
53
+ '#CD853F', # Ga - 秘鲁色
54
+ '#7FFFD4', # Tb - 碧绿色
55
+ '#8A2BE2', # As - 紫罗兰色
56
+ '#FFD700', # Co - 金色
57
+ '#808080', # Pb - 灰色
58
+ '#A9A9A9', # Sb - 深灰色
59
+ '#FA8072', # Se - 鲑鱼色
60
+ '#BEBEBE', # Ni - 浅灰色
61
+ '#800080' # Te - 紫色
62
+ ]
63
+
64
+
65
+ ######################################################
66
+ # Utils Func #
67
+ ######################################################
68
+
69
+ def from_gexf(filename: str, node_types: list=None):
70
+ r"""
71
+ Read Data from GEXF file
72
+ """
73
+ if not filename.endswith('.gexf'):
74
+ raise ValueError("File type error, 'from_gexf' function only supports GEXF files")
75
+ graph = nx.read_gexf(filename)
76
+ mapping = {name: j for j, name in enumerate(graph.nodes())}
77
+ graph = nx.relabel_nodes(graph, mapping)
78
+ edge_index = torch.from_numpy(np.array(graph.edges, dtype=np.int64).transpose())
79
+ x = None
80
+ labels = None
81
+ data = None
82
+ colors = None
83
+ if 'type' in graph.nodes(data=True)[0].keys():
84
+ labels = dict()
85
+ colors = list()
86
+ num_nodes = graph.number_of_nodes()
87
+ x = torch.zeros(num_nodes, dtype=torch.long)
88
+ node_types = AIDS700NEF_TYPE if node_types is None else node_types
89
+ for node, info in graph.nodes(data=True):
90
+ x[int(node)] = node_types.index(info['type'])
91
+ labels[int(node)] = str(int(node)) + info['type']
92
+ colors.append(COLOR[x[int(node)]])
93
+ x = one_hot(x, num_classes=len(node_types))
94
+ data = Data(x=x, edge_index=edge_index, edge_attr=None)
95
+ return graph, data, labels, colors
96
+
97
+
98
+ def draw(graph, colors, labels, filename, title, pos_type=None):
99
+ if pos_type is None:
100
+ pos = nx.kamada_kawai_layout(graph)
101
+ elif pos_type == "spring":
102
+ pos = nx.spring_layout(graph)
103
+ plt.figure()
104
+ plt.gca().set_title(title)
105
+ nx.draw(graph, pos, with_labels=True, node_color=colors, edge_color='gray', labels=labels)
106
+ plt.savefig(filename)
107
+ plt.clf()
108
+
109
+
110
+ ######################################################
111
+ # GED UI #
112
+ ######################################################
113
+
114
+ def astar(
115
+ g1_path: str,
116
+ g2_path: str,
117
+ output_path: str="examples",
118
+ filename: str="example",
119
+ device='cpu'
120
+ ):
121
+ if not os.path.exists(output_path):
122
+ os.mkdir(output_path)
123
+ output_filename = os.path.join(output_path, filename) + "_{}.png"
124
+
125
+ # Load data
126
+ g1, d1, l1, c1 = from_gexf(g1_path)
127
+ g2, d2, l2, c2 = from_gexf(g2_path)
128
+ if len(c1) > len(c2):
129
+ graph1, data1, labels1, colors1 = g2, d2, l2, c2
130
+ graph2, data2, labels2, colors2 = g1, d1, l1, c1
131
+ else:
132
+ graph1, data1, labels1, colors1 = g1, d1, l1, c1
133
+ graph2, data2, labels2, colors2 = g2, d2, l2, c2
134
+
135
+ # Build Graph and Adj Matrix
136
+ data1 = OneHotDegree(max_degree=6)(data1)
137
+ data2 = OneHotDegree(max_degree=6)(data2)
138
+ feat1 = data1.x.to(device)
139
+ feat2 = data2.x.to(device)
140
+ A1 = torch.tensor(pygm.utils.from_networkx(graph1)).float().to(device)
141
+ A2 = torch.tensor(pygm.utils.from_networkx(graph2)).float().to(device)
142
+
143
+ import site
144
+ site_path = site.getsitepackages()[0]
145
+ pygm_path = os.path.join(site_path, "pygmtools")
146
+ print(os.listdir(pygm_path))
147
+ # Caculate the ged
148
+ x_pred = pygm.genn_astar(feat1, feat2, A1, A2, return_network=False)
149
+
150
+ # Plot
151
+ draw(graph1, colors1, labels1, output_filename.format(1), "Graph1")
152
+ draw(graph2, colors2, labels2, output_filename.format(5), f"Graph2")
153
+
154
+ # Match Process
155
+ total_cost = 0
156
+ labels1_1 = labels1.copy()
157
+ for i in range(x_pred.shape[0]):
158
+ target = torch.nonzero(x_pred[i])[0].item()
159
+ labels1_1[i] = labels1[i].replace(str(i), str(target))
160
+ title = "Node Match"
161
+ draw(graph1, colors1, labels1_1, output_filename.format(2), title)
162
+
163
+ # Node Change
164
+ cur_cost = 0
165
+ labels1_2 = labels1.copy()
166
+ colors1_2 = colors1.copy()
167
+ target2ori = dict()
168
+ targets = list()
169
+ for i in range(x_pred.shape[0]):
170
+ target = torch.nonzero(x_pred[i])[0].item()
171
+ if labels1_1[i] != labels2[target]:
172
+ cur_cost += 1
173
+ labels1_2[i] = labels2[target]
174
+ colors1_2[i] = colors2[target]
175
+ target2ori[target] = i
176
+ targets.append(target)
177
+ total_cost += cur_cost
178
+ title = f"Node Change"
179
+ draw(graph1, colors1_2, labels1_2, output_filename.format(3), title)
180
+
181
+ # Edge Change
182
+ leave_cost = np.array(graph2).shape[0] - np.array(graph1).shape[0]
183
+ leave_cost += graph2.number_of_nodes() - graph1.number_of_nodes()
184
+ e2 = np.array(graph2.edges)
185
+ new_edges = list()
186
+ for edge in e2:
187
+ if edge[0] in targets and edge[1] in targets:
188
+ new_edges.append([target2ori[edge[0]], target2ori[edge[1]]])
189
+ graph1.edges = nx.Graph(new_edges).edges
190
+ title = f"Edge Change"
191
  draw(graph1, colors1_2, labels1_2, output_filename.format(4), title, pos_type="spring")
media/ged_image_1.png CHANGED
media/ged_image_2.png CHANGED
media/ged_image_3.png CHANGED
media/ged_image_4.png CHANGED
media/ged_image_5.png CHANGED