ubuntu commited on
Commit
6b33608
·
1 Parent(s): 4019bc8

Initial Commit

Browse files
app.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ import shutil
3
+ import gradio as gr
4
+ from genn_astar import astar
5
+
6
+
7
+ GED_IMG_DEFAULT_PATH = "src/ged_default.png"
8
+ GED_SOLUTION_1_PATH = "src/ged_image_1.png"
9
+ GED_SOLUTION_2_PATH = "src/ged_image_2.png"
10
+ GED_SOLUTION_3_PATH = "src/ged_image_3.png"
11
+ GED_SOLUTION_4_PATH = "src/ged_image_4.png"
12
+ GED_SOLUTION_5_PATH = "src/ged_image_5.png"
13
+
14
+
15
+ def _handle_ged_solve(
16
+ gexf_1_path: str,
17
+ gexf_2_path: str
18
+ ):
19
+ if gexf_1_path is None:
20
+ raise gr.Error("Please upload file completely!")
21
+ if gexf_2_path is None:
22
+ raise gr.Error("Please upload file completely!")
23
+
24
+ start_time = time.time()
25
+ astar(
26
+ g1_path=gexf_1_path,
27
+ g2_path=gexf_2_path,
28
+ output_path="src",
29
+ filename="ged_image"
30
+ )
31
+ solved_time = time.time() - start_time
32
+
33
+ message = "Successfully solve the GED problem, using time ({:.3f}s).".format(solved_time)
34
+
35
+ return message, GED_SOLUTION_1_PATH, GED_SOLUTION_2_PATH, GED_SOLUTION_3_PATH, \
36
+ GED_SOLUTION_4_PATH, GED_SOLUTION_5_PATH
37
+
38
+
39
+ def handle_ged_solve(
40
+ gexf_1_path: str,
41
+ gexf_2_path: str
42
+ ):
43
+ try:
44
+ message = _handle_ged_solve(
45
+ gexf_1_path=gexf_1_path,
46
+ gexf_2_path=gexf_2_path
47
+ )
48
+ return message
49
+ except Exception as e:
50
+ message = str(e)
51
+ return message, GED_SOLUTION_1_PATH, GED_SOLUTION_2_PATH, GED_SOLUTION_3_PATH, \
52
+ GED_SOLUTION_4_PATH, GED_SOLUTION_5_PATH
53
+
54
+
55
+ def handle_ged_clear():
56
+ shutil.copy(
57
+ src=GED_IMG_DEFAULT_PATH,
58
+ dst=GED_SOLUTION_1_PATH
59
+ )
60
+ shutil.copy(
61
+ src=GED_IMG_DEFAULT_PATH,
62
+ dst=GED_SOLUTION_2_PATH
63
+ )
64
+ shutil.copy(
65
+ src=GED_IMG_DEFAULT_PATH,
66
+ dst=GED_SOLUTION_3_PATH
67
+ )
68
+ shutil.copy(
69
+ src=GED_IMG_DEFAULT_PATH,
70
+ dst=GED_SOLUTION_4_PATH
71
+ )
72
+ shutil.copy(
73
+ src=GED_IMG_DEFAULT_PATH,
74
+ dst=GED_SOLUTION_5_PATH
75
+ )
76
+
77
+ message = "successfully clear the files!"
78
+ return message, GED_SOLUTION_1_PATH, GED_SOLUTION_2_PATH, GED_SOLUTION_3_PATH, \
79
+ GED_SOLUTION_4_PATH, GED_SOLUTION_5_PATH
80
+
81
+
82
+ def convert_image_path_to_bytes(image_path):
83
+ with open(image_path, "rb") as f:
84
+ image_bytes = f.read()
85
+ return image_bytes
86
+
87
+
88
+ with gr.Blocks() as ged_page:
89
+
90
+ gr.Markdown(
91
+ '''
92
+ This space displays the solution to the Graph Edit Distance problem.
93
+ ## How to use this Space?
94
+ - Upload two '.gexf' files.
95
+ - The images of the GED problem and solution will be shown after you click the solve button.
96
+ - Click the 'clear' button to clear all the files.
97
+ ## Examples
98
+ - You can get the test examples from our [GED Dataset Repo.](https://huggingface.co/datasets/SJTU-TES/Graph-Edit-Distance)
99
+ '''
100
+ )
101
+
102
+ with gr.Row(variant="panel"):
103
+ with gr.Column(scale=2):
104
+ with gr.Row():
105
+ ged_img_1 = gr.File(
106
+ label="Upload .gexf File",
107
+ file_types=[".gexf"],
108
+ min_width=40,
109
+ )
110
+ ged_img_2 = gr.File(
111
+ label="Upload .gexf File",
112
+ file_types=[".gexf"],
113
+ min_width=40,
114
+ )
115
+ with gr.Column(scale=2):
116
+ info = gr.Textbox(
117
+ value="",
118
+ label="Log",
119
+ scale=4,
120
+ )
121
+ with gr.Row():
122
+ with gr.Column(scale=1, min_width=100):
123
+ solve_button = gr.Button(
124
+ value="Solve",
125
+ variant="primary",
126
+ scale=1
127
+ )
128
+ with gr.Column(scale=1, min_width=100):
129
+ clear_button = gr.Button(
130
+ "Clear",
131
+ variant="secondary",
132
+ scale=1
133
+ )
134
+ with gr.Column(scale=8):
135
+ pass
136
+ with gr.Row(variant="panel"):
137
+ ged_solution_1 = gr.Image(
138
+ value=GED_SOLUTION_1_PATH,
139
+ type="filepath",
140
+ label="1"
141
+ )
142
+ ged_solution_2 = gr.Image(
143
+ value=GED_SOLUTION_2_PATH,
144
+ type="filepath",
145
+ label="2"
146
+ )
147
+ ged_solution_3 = gr.Image(
148
+ value=GED_SOLUTION_3_PATH,
149
+ type="filepath",
150
+ label="3"
151
+ )
152
+ ged_solution_4 = gr.Image(
153
+ value=GED_SOLUTION_4_PATH,
154
+ type="filepath",
155
+ label="4"
156
+ )
157
+ ged_solution_5 = gr.Image(
158
+ value=GED_SOLUTION_5_PATH,
159
+ type="filepath",
160
+ label="5"
161
+ )
162
+
163
+
164
+ solve_button.click(
165
+ handle_ged_solve,
166
+ [ged_img_1, ged_img_2],
167
+ outputs=[info, ged_solution_1, ged_solution_2,
168
+ ged_solution_3, ged_solution_4, ged_solution_5]
169
+ )
170
+
171
+ clear_button.click(
172
+ handle_ged_clear,
173
+ inputs=None,
174
+ outputs=[info, ged_solution_1, ged_solution_2,
175
+ ged_solution_3, ged_solution_4, ged_solution_5]
176
+ )
177
+
178
+
179
+ if __name__ == "__main__":
180
+ ged_page.launch(debug = True)
genn_astar.py ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import networkx as nx
4
+ import pygmtools as pygm
5
+ import torch
6
+ try:
7
+ from torch_geometric.data import Data
8
+ except:
9
+ os.system("pip install --no-index torch-sparse -f https://pytorch-geometric.com/whl/torch-2.0.0%2Bcpu.html")
10
+ os.system("pip install --no-index torch-scatter -f https://pytorch-geometric.com/whl/torch-2.0.0%2Bcpu.html")
11
+ os.system("pip install --no-index torch-spline-conv -f https://pytorch-geometric.com/whl/torch-2.0.0%2Bcpu.html")
12
+ os.system("pip install --no-index torch-cluster -f https://pytorch-geometric.com/whl/torch-2.0.0%2Bcpu.html")
13
+ from torch_geometric.data import Data
14
+ from one_hot import one_hot
15
+ from torch_geometric.transforms import OneHotDegree
16
+ import matplotlib.pyplot as plt
17
+ import pygmtools as pygm
18
+ pygm.set_backend('pytorch')
19
+
20
+
21
+ ######################################################
22
+ # Constant Variable #
23
+ ######################################################
24
+
25
+ AIDS700NEF_TYPE = [
26
+ 'O', 'S', 'C', 'N', 'Cl', 'Br', 'B', 'Si', 'Hg', 'I', 'Bi', 'P', 'F',
27
+ 'Cu', 'Ho', 'Pd', 'Ru', 'Pt', 'Sn', 'Li', 'Ga', 'Tb', 'As', 'Co', 'Pb',
28
+ 'Sb', 'Se', 'Ni', 'Te'
29
+ ]
30
+
31
+
32
+ COLOR = [
33
+ '#FF69B4', # O - 热情的粉红色
34
+ '#00CED1', # S - 深蓝绿色
35
+ '#FFD700', # C - 金色
36
+ '#FFA500', # N - 橙色
37
+ '#FF6347', # Cl - 番茄红色
38
+ '#8B008B', # Br - 深洋红色
39
+ '#00FF7F', # B - 春天的绿色
40
+ '#40E0D0', # Si - 绿松石色
41
+ '#FF4500', # Hg - 橙红色
42
+ '#9932CC', # I - 深兰花紫色
43
+ '#9370DB', # Bi - 中紫色
44
+ '#FFA500', # P - 橙色
45
+ '#FFFF00', # F - 黄色
46
+ '#B8860B', # Cu - 深金色
47
+ '#7FFFD4', # Ho - 碧绿色
48
+ '#FFD700', # Pd - 金色
49
+ '#B22222', # Ru - 砖红色
50
+ '#E5E4E2', # Pt - 浅灰色
51
+ '#A9A9A9', # Sn - 深灰色
52
+ '#32CD32', # Li - 酸橙色
53
+ '#CD853F', # Ga - 秘鲁色
54
+ '#7FFFD4', # Tb - 碧绿色
55
+ '#8A2BE2', # As - 紫罗兰色
56
+ '#FFD700', # Co - 金色
57
+ '#808080', # Pb - 灰色
58
+ '#A9A9A9', # Sb - 深灰色
59
+ '#FA8072', # Se - 鲑鱼色
60
+ '#BEBEBE', # Ni - 浅灰色
61
+ '#800080' # Te - 紫色
62
+ ]
63
+
64
+
65
+ ######################################################
66
+ # Utils Func #
67
+ ######################################################
68
+
69
+ def from_gexf(filename: str, node_types: list=None):
70
+ r"""
71
+ Read Data from GEXF file
72
+ """
73
+ if not filename.endswith('.gexf'):
74
+ raise ValueError("File type error, 'from_gexf' function only supports GEXF files")
75
+ graph = nx.read_gexf(filename)
76
+ mapping = {name: j for j, name in enumerate(graph.nodes())}
77
+ graph = nx.relabel_nodes(graph, mapping)
78
+ edge_index = torch.from_numpy(np.array(graph.edges, dtype=np.int64).transpose())
79
+ x = None
80
+ labels = None
81
+ data = None
82
+ colors = None
83
+ if 'type' in graph.nodes(data=True)[0].keys():
84
+ labels = dict()
85
+ colors = list()
86
+ num_nodes = graph.number_of_nodes()
87
+ x = torch.zeros(num_nodes, dtype=torch.long)
88
+ node_types = AIDS700NEF_TYPE if node_types is None else node_types
89
+ for node, info in graph.nodes(data=True):
90
+ x[int(node)] = node_types.index(info['type'])
91
+ labels[int(node)] = str(int(node)) + info['type']
92
+ colors.append(COLOR[x[int(node)]])
93
+ x = one_hot(x, num_classes=len(node_types))
94
+ data = Data(x=x, edge_index=edge_index, edge_attr=None)
95
+ return graph, data, labels, colors
96
+
97
+
98
+ def draw(graph, colors, labels, filename, title, pos_type=None):
99
+ if pos_type is None:
100
+ pos = nx.kamada_kawai_layout(graph)
101
+ elif pos_type == "spring":
102
+ pos = nx.spring_layout(graph)
103
+ plt.figure()
104
+ plt.gca().set_title(title)
105
+ nx.draw(graph, pos, with_labels=True, node_color=colors, edge_color='gray', labels=labels)
106
+ plt.savefig(filename)
107
+ plt.clf()
108
+
109
+
110
+ ######################################################
111
+ # GED UI #
112
+ ######################################################
113
+
114
+ def astar(
115
+ g1_path: str,
116
+ g2_path: str,
117
+ output_path: str="examples",
118
+ filename: str="example",
119
+ device='cpu'
120
+ ):
121
+ if not os.path.exists(output_path):
122
+ os.mkdir(output_path)
123
+ output_filename = os.path.join(output_path, filename) + "_{}.png"
124
+
125
+ # Load data
126
+ g1, d1, l1, c1 = from_gexf(g1_path)
127
+ g2, d2, l2, c2 = from_gexf(g2_path)
128
+ if len(c1) > len(c2):
129
+ graph1, data1, labels1, colors1 = g2, d2, l2, c2
130
+ graph2, data2, labels2, colors2 = g1, d1, l1, c1
131
+ else:
132
+ graph1, data1, labels1, colors1 = g1, d1, l1, c1
133
+ graph2, data2, labels2, colors2 = g2, d2, l2, c2
134
+
135
+ # Build Graph and Adj Matrix
136
+ data1 = OneHotDegree(max_degree=6)(data1)
137
+ data2 = OneHotDegree(max_degree=6)(data2)
138
+ feat1 = data1.x.to(device)
139
+ feat2 = data2.x.to(device)
140
+ A1 = torch.tensor(pygm.utils.from_networkx(graph1)).float().to(device)
141
+ A2 = torch.tensor(pygm.utils.from_networkx(graph2)).float().to(device)
142
+
143
+ # Caculate the ged
144
+ x_pred = pygm.genn_astar(feat1, feat2, A1, A2, return_network=False)
145
+
146
+ # Plot
147
+ draw(graph1, colors1, labels1, output_filename.format(1), "Graph1")
148
+ draw(graph2, colors2, labels2, output_filename.format(5), f"Graph2")
149
+
150
+ # Match Process
151
+ total_cost = 0
152
+ labels1_1 = labels1.copy()
153
+ for i in range(x_pred.shape[0]):
154
+ target = torch.nonzero(x_pred[i])[0].item()
155
+ labels1_1[i] = labels1[i].replace(str(i), str(target))
156
+ title = "Node Match"
157
+ draw(graph1, colors1, labels1_1, output_filename.format(2), title)
158
+
159
+ # Node Change
160
+ cur_cost = 0
161
+ labels1_2 = labels1.copy()
162
+ colors1_2 = colors1.copy()
163
+ target2ori = dict()
164
+ targets = list()
165
+ for i in range(x_pred.shape[0]):
166
+ target = torch.nonzero(x_pred[i])[0].item()
167
+ if labels1_1[i] != labels2[target]:
168
+ cur_cost += 1
169
+ labels1_2[i] = labels2[target]
170
+ colors1_2[i] = colors2[target]
171
+ target2ori[target] = i
172
+ targets.append(target)
173
+ total_cost += cur_cost
174
+ title = f"Node Change"
175
+ draw(graph1, colors1_2, labels1_2, output_filename.format(3), title)
176
+
177
+ # Edge Change
178
+ leave_cost = np.array(graph2).shape[0] - np.array(graph1).shape[0]
179
+ leave_cost += graph2.number_of_nodes() - graph1.number_of_nodes()
180
+ e2 = np.array(graph2.edges)
181
+ new_edges = list()
182
+ for edge in e2:
183
+ if edge[0] in targets and edge[1] in targets:
184
+ new_edges.append([target2ori[edge[0]], target2ori[edge[1]]])
185
+ graph1.edges = nx.Graph(new_edges).edges
186
+ title = f"Edge Change"
187
+ draw(graph1, colors1_2, labels1_2, output_filename.format(4), title, pos_type="spring")
one_hot.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+
3
+ import torch
4
+ from torch import Tensor
5
+
6
+
7
+ def one_hot(
8
+ index: Tensor,
9
+ num_classes: Optional[int] = None,
10
+ dtype: Optional[torch.dtype] = None,
11
+ ) -> Tensor:
12
+ r"""Taskes a one-dimensional :obj:`index` tensor and returns a one-hot
13
+ encoded representation of it with shape :obj:`[*, num_classes]` that has
14
+ zeros everywhere except where the index of last dimension matches the
15
+ corresponding value of the input tensor, in which case it will be :obj:`1`.
16
+
17
+ .. note::
18
+ This is a more memory-efficient version of
19
+ :meth:`torch.nn.functional.one_hot` as you can customize the output
20
+ :obj:`dtype`.
21
+
22
+ Args:
23
+ index (torch.Tensor): The one-dimensional input tensor.
24
+ num_classes (int, optional): The total number of classes. If set to
25
+ :obj:`None`, the number of classes will be inferred as one greater
26
+ than the largest class value in the input tensor.
27
+ (default: :obj:`None`)
28
+ dtype (torch.dtype, optional): The :obj:`dtype` of the output tensor.
29
+ """
30
+ if index.dim() != 1:
31
+ raise ValueError("'index' tensor needs to be one-dimensional")
32
+
33
+ if num_classes is None:
34
+ num_classes = int(index.max()) + 1
35
+
36
+ out = torch.zeros((index.size(0), num_classes), dtype=dtype,
37
+ device=index.device)
38
+ return out.scatter_(1, index.unsqueeze(1), 1)
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ gradio
2
+ pygmtools==0.4.3
3
+ matplotlib
4
+ torch==2.0.0
5
+ torchvision==0.15.1
6
+ scikit-learn
7
+ torch_geometric==2.0.0
8
+ networkx==2.8.8
src/ged_default.png ADDED
src/ged_image_1.png ADDED
src/ged_image_2.png ADDED
src/ged_image_3.png ADDED
src/ged_image_4.png ADDED
src/ged_image_5.png ADDED