Spaces:
Running
Running
yjwtheonly
committed on
Commit
•
fce1f4b
1
Parent(s):
8ae6390
Agnostic
Browse files- DiseaseAgnostic/KG_extractor.py +473 -0
- DiseaseAgnostic/edge_to_abstract.py +652 -0
- DiseaseAgnostic/evaluation.py +219 -0
- DiseaseAgnostic/generate_target_and_attack.py +371 -0
- DiseaseAgnostic/model.py +520 -0
- DiseaseAgnostic/utils.py +187 -0
DiseaseAgnostic/KG_extractor.py
ADDED
@@ -0,0 +1,473 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#%%
# KG_extractor: read the abstracts generated for each disease-agnostic attack
# edge, feed them through an external parser, and extract the KG edges that
# the generated text actually expresses.
import torch
import numpy as np
from torch.autograd import Variable
from sklearn import metrics

import datetime
from typing import Dict, Tuple, List
import logging
import os
import utils
import pickle as pkl
import json
import torch.backends.cudnn as cudnn

from tqdm import tqdm

import sys
sys.path.append("..")
import Parameters

parser = utils.get_argument_parser()
parser.add_argument('--reasonable-rate', type = float, default=0.7, help = 'The added edge\'s existance rank prob greater than this rate')
parser.add_argument('--mode', type=str, default='sentence', help='sentence, finetune, biogpt, bioBART')
parser.add_argument('--action', type=str, default='parse', help='parse or extract')
parser.add_argument('--init-mode', type = str, default='random', help = 'How to select target nodes')
parser.add_argument('--ratio', type = str, default='', help='ratio of the number of changed words')
args = parser.parse_args()
args = utils.set_hyperparams(args)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

utils.seed_all(args.seed)
np.set_printoptions(precision=5)
cudnn.benchmark = False

# All file names are keyed by the experiment configuration.
data_path = '../DiseaseSpecific/processed_data/GNBR'
target_path = f'processed_data/target_{args.reasonable_rate}{args.init_mode}.pkl'
attack_path = f'processed_data/attack_edge_{args.model}_{args.reasonable_rate}{args.init_mode}.pkl'
modified_attack_path = f'processed_data/attack_edge_{args.model}_{args.reasonable_rate}{args.init_mode}{args.mode}.pkl'

with open(attack_path, 'rb') as fl:
    Attack_edge_list = pkl.load(fl)
# One attacking triple (s, r, o) per target node.
attack_data = np.array(Attack_edge_list).reshape(-1, 3)
#%%
with open(os.path.join(data_path, 'entities_reverse_dict.json')) as fl:
    id_to_meshid = json.load(fl)
with open(os.path.join(data_path, 'entities_dict.json'), 'r') as fl:
    meshid_to_id = json.load(fl)
with open(Parameters.GNBRfile+'entity_raw_name', 'rb') as fl:
    entity_raw_name = pkl.load(fl)
with open(Parameters.GNBRfile+'retieve_sentence_through_edgetype', 'rb') as fl:
    retieve_sentence_through_edgetype = pkl.load(fl)
with open(Parameters.GNBRfile+'raw_text_of_each_sentence', 'rb') as fl:
    raw_text_sen = pkl.load(fl)
with open(Parameters.GNBRfile+'original_entity_raw_name', 'rb') as fl:
    full_entity_raw_name = pkl.load(fl)
# Sanity check: every canonical entity name is among that entity's raw aliases.
for k, v in entity_raw_name.items():
    assert v in full_entity_raw_name[k]

# Leftovers from an earlier "unique name" filtering experiment; no longer
# filled in, kept only so the names remain defined.
once_set = set()
twice_set = set()

with open('../DiseaseSpecific/generate_abstract/valid_entity.json', 'r') as fl:
    valid_entity = json.load(fl)
valid_entity = set(valid_entity)

# Every raw alias of every entity counts as a recognisable name.
good_name = set()
for meshid, aliases in full_entity_raw_name.items():
    for name in list(aliases):
        good_name.add(name)

# Map each recognised alias to its entity type (the prefix of the mesh id)
# and to the mesh id itself.
name_to_type = {}
name_to_meshid = {}

for meshid, aliases in full_entity_raw_name.items():
    for name in list(aliases):
        if name in good_name:
            name_to_type[name] = meshid.split('_')[0]
            name_to_meshid[name] = meshid

import spacy
import networkx as nx
import pprint

def check(p, s):
    """Return True when position ``p`` of string ``s`` is a word boundary.

    Out-of-range positions (p < 1 or p >= len(s)) count as boundaries; an
    in-range position is a boundary iff the character there is not an ASCII
    letter or digit.  Deliberately ASCII-only, matching the GNBR data.
    """
    if p < 1 or p >= len(s):
        return True
    c = s[p]
    is_ascii_alnum = ('a' <= c <= 'z') or ('A' <= c <= 'Z') or ('0' <= c <= '9')
    return not is_ascii_alnum
108 |
+
|
def raw_to_format(sen):
    """Rewrite a raw sentence so each known entity name becomes one token.

    Scans left to right; at each non-space position tries the LONGEST
    substring first (greedy match) that is a known entity (``good_name`` /
    ``valid_entity``) and sits on word boundaries (via ``check``), and
    emits it with spaces replaced by underscores.  Characters that start
    no match are copied through unchanged.
    """
    text = sen
    pos = 0
    pieces = []
    while pos < len(text):
        matched = False
        if text[pos] != ' ':
            # Longest candidate first -- descending range is essential.
            for stop in range(len(text), pos, -1):
                cand = text[pos:stop]
                if (cand in good_name or cand in valid_entity) and check(pos-1, text) and check(stop, text):
                    pieces.append(cand.replace(' ', '_'))
                    pos = stop
                    matched = True
                    break
        if not matched:
            pieces.append(text[pos])
            pos += 1
    return ''.join(pieces)
128 |
+
|
# Each generation mode stores its generated abstracts in a different JSON
# file; pick the right one for args.mode.
_draft_files = {
    'sentence': f'generate_abstract/{args.init_mode}{args.reasonable_rate}_chat.json',
    'finetune': f'generate_abstract/{args.init_mode}{args.reasonable_rate}_sentence_finetune.json',
    'bioBART': f'generate_abstract/{args.init_mode}{args.reasonable_rate}{args.ratio}_bioBART_finetune.json',
    'biogpt': f'generate_abstract/{args.init_mode}{args.reasonable_rate}_biogpt.json',
}
if args.mode not in _draft_files:
    raise Exception('No!!!')
with open(_draft_files[args.mode], 'r') as fl:
    draft = json.load(fl)
nlp = spacy.load("en_core_web_sm")

# Collect every dependency-relation label that occurs in any stored
# dependency path (manually curated and auto-mined alike) across all 36
# GNBR edge types.  Each path is 'tok|REL|tok tok|REL|tok ...'.
type_set = set()
for edge_type in range(36):
    by_path = retieve_sentence_through_edgetype[edge_type]
    all_paths = list(by_path['manual'].keys()) + list(by_path['auto'].keys())
    for dependency in all_paths:
        for sub_dep in dependency.split(' '):
            parts = sub_dep.split('|')
            assert(len(parts) == 3)
            type_set.add(parts[1])

if args.action == 'parse':
    # Flatten every generated abstract into one sentence per line, with
    # recognised entity names underscore-joined so the external parser
    # treats each of them as a single token.
    lines = []
    for key, entry in draft.items():
        if entry['in'] == '':
            continue  # no generated text for this target
        abstract = entry['out'].replace('\n', ' ')
        for sen in nlp(abstract).sents:
            lines.append(raw_to_format(sen.text) + '\n')
    out = ''.join(lines)
    with open(f'generate_abstract/{args.init_mode}{args.reasonable_rate}_{args.mode}_parsein.txt', 'w') as fl:
        fl.write(out)
183 |
+
elif args.action == 'extract':
|
184 |
+
|
185 |
+
# dependency_to_type_id = {}
|
186 |
+
# for k, v in Parameters.edge_type_to_id.items():
|
187 |
+
# dependency_to_type_id[k] = {}
|
188 |
+
# for type in v:
|
189 |
+
# LL = list(retieve_sentence_through_edgetype[type]['manual'].keys()) + list(retieve_sentence_through_edgetype[type]['auto'].keys())
|
190 |
+
# for dp in LL:
|
191 |
+
# dependency_to_type_id[k][dp] = type
|
192 |
+
if os.path.exists('generate_abstract/dependency_to_type_id.pickle'):
|
193 |
+
with open('generate_abstract/dependency_to_type_id.pickle', 'rb') as fl:
|
194 |
+
dependency_to_type_id = pkl.load(fl)
|
195 |
+
else:
|
196 |
+
dependency_to_type_id = {}
|
197 |
+
print('Loading path data ...')
|
198 |
+
for k in Parameters.edge_type_to_id.keys():
|
199 |
+
start, end = k.split('-')
|
200 |
+
dependency_to_type_id[k] = {}
|
201 |
+
inner_edge_type_to_id = Parameters.edge_type_to_id[k]
|
202 |
+
inner_edge_type_dict = Parameters.edge_type_dict[k]
|
203 |
+
cal_manual_num = [0] * len(inner_edge_type_to_id)
|
204 |
+
with open('../GNBRdata/part-i-'+start+'-'+end+'-path-theme-distributions.txt', 'r') as fl:
|
205 |
+
for i, line in tqdm(list(enumerate(fl.readlines()))):
|
206 |
+
tmp = line.split('\t')
|
207 |
+
if i == 0:
|
208 |
+
head = [tmp[i] for i in range(1, len(tmp), 2)]
|
209 |
+
assert ' '.join(head) == ' '.join(inner_edge_type_dict[0])
|
210 |
+
continue
|
211 |
+
probability = [float(tmp[i]) for i in range(1, len(tmp), 2)]
|
212 |
+
flag_list = [int(tmp[i]) for i in range(2, len(tmp), 2)]
|
213 |
+
indices = np.where(np.asarray(flag_list) == 1)[0]
|
214 |
+
if len(indices) >= 1:
|
215 |
+
tmp_p = [cal_manual_num[i] for i in indices]
|
216 |
+
p = indices[np.argmin(tmp_p)]
|
217 |
+
cal_manual_num[p] += 1
|
218 |
+
else:
|
219 |
+
p = np.argmax(probability)
|
220 |
+
assert tmp[0].lower() not in dependency_to_type_id.keys()
|
221 |
+
dependency_to_type_id[k][tmp[0].lower()] = inner_edge_type_to_id[p]
|
222 |
+
with open('generate_abstract/dependency_to_type_id.pickle', 'wb') as fl:
|
223 |
+
pkl.dump(dependency_to_type_id, fl)
|
224 |
+
|
225 |
+
# record = []
|
226 |
+
# with open(f'generate_abstract/par_parseout.txt', 'r') as fl:
|
227 |
+
# Tmp = []
|
228 |
+
# tmp = []
|
229 |
+
# for i,line in enumerate(fl.readlines()):
|
230 |
+
# # print(len(line), line)
|
231 |
+
# line = line.replace('\n', '')
|
232 |
+
# if len(line) > 1:
|
233 |
+
# tmp.append(line)
|
234 |
+
# else:
|
235 |
+
# Tmp.append(tmp)
|
236 |
+
# tmp = []
|
237 |
+
# if len(Tmp) == 3:
|
238 |
+
# record.append(Tmp)
|
239 |
+
# Tmp = []
|
240 |
+
|
241 |
+
# print(len(record))
|
242 |
+
# record_index = 0
|
243 |
+
# add = 0
|
244 |
+
# Attack = []
|
245 |
+
# for ii in range(100):
|
246 |
+
|
247 |
+
# # input = v_dict['in']
|
248 |
+
# # output = v_dict['out']
|
249 |
+
# # output = output.replace('\n', ' ')
|
250 |
+
# s, r, o = attack_data[ii]
|
251 |
+
# dependency_sen_dict = retieve_sentence_through_edgetype[int(r)]['manual']
|
252 |
+
|
253 |
+
# target_dp = set()
|
254 |
+
# for dp_path, sen_list in dependency_sen_dict.items():
|
255 |
+
# target_dp.add(dp_path)
|
256 |
+
# DP_list = []
|
257 |
+
# for _ in range(1):
|
258 |
+
# dp_dict = {}
|
259 |
+
# data = record[record_index]
|
260 |
+
# record_index += 1
|
261 |
+
# dp_paths = data[2]
|
262 |
+
# nodes_list = []
|
263 |
+
# edges_list = []
|
264 |
+
# for line in dp_paths:
|
265 |
+
# ttp, tmp = line.split('(')
|
266 |
+
# assert tmp[-1] == ')'
|
267 |
+
# tmp = tmp[:-1]
|
268 |
+
# e1, e2 = tmp.split(', ')
|
269 |
+
# if not ttp in type_set and ':' in ttp:
|
270 |
+
# ttp = ttp.split(':')[0]
|
271 |
+
# dp_dict[f'{e1}_x_{e2}'] = [e1, ttp, e2]
|
272 |
+
# dp_dict[f'{e2}_x_{e1}'] = [e1, ttp, e2]
|
273 |
+
# nodes_list.append(e1)
|
274 |
+
# nodes_list.append(e2)
|
275 |
+
# edges_list.append((e1, e2))
|
276 |
+
# nodes_list = list(set(nodes_list))
|
277 |
+
# pure_name = [('-'.join(name.split('-')[:-1])).replace('_', ' ') for name in nodes_list]
|
278 |
+
# graph = nx.Graph(edges_list)
|
279 |
+
|
280 |
+
# type_list = [name_to_type[name] if name in good_name else '' for name in pure_name]
|
281 |
+
# # print(type_list)
|
282 |
+
# # for i in range(len(type_list)):
|
283 |
+
# # print(pure_name[i], type_list[i])
|
284 |
+
# for i in range(len(nodes_list)):
|
285 |
+
# if type_list[i] != '':
|
286 |
+
# for j in range(len(nodes_list)):
|
287 |
+
# if i != j and type_list[j] != '':
|
288 |
+
# if f'{type_list[i]}-{type_list[j]}' in Parameters.edge_type_to_id.keys():
|
289 |
+
# # print(f'{type_list[i]}_{type_list[j]}')
|
290 |
+
# ret_path = []
|
291 |
+
# sp = nx.shortest_path(graph, source=nodes_list[i], target=nodes_list[j])
|
292 |
+
# start = sp[0]
|
293 |
+
# end = sp[-1]
|
294 |
+
# for k in range(len(sp)-1):
|
295 |
+
# e1, ttp, e2 = dp_dict[f'{sp[k]}_x_{sp[k+1]}']
|
296 |
+
# if e1 == start:
|
297 |
+
# e1 = 'start_entity-x'
|
298 |
+
# if e2 == start:
|
299 |
+
# e2 = 'start_entity-x'
|
300 |
+
# if e1 == end:
|
301 |
+
# e1 = 'end_entity-x'
|
302 |
+
# if e2 == end:
|
303 |
+
# e2 = 'end_entity-x'
|
304 |
+
# ret_path.append(f'{"-".join(e1.split("-")[:-1])}|{ttp}|{"-".join(e2.split("-")[:-1])}'.lower())
|
305 |
+
# dependency_P = ' '.join(ret_path)
|
306 |
+
# DP_list.append((f'{type_list[i]}-{type_list[j]}',
|
307 |
+
# name_to_meshid[pure_name[i]],
|
308 |
+
# name_to_meshid[pure_name[j]],
|
309 |
+
# dependency_P))
|
310 |
+
|
311 |
+
# boo = False
|
312 |
+
# modified_attack = []
|
313 |
+
# for k, ss, tt, dp in DP_list:
|
314 |
+
# if dp in dependency_to_type_id[k].keys():
|
315 |
+
# tp = str(dependency_to_type_id[k][dp])
|
316 |
+
# id_ss = str(meshid_to_id[ss])
|
317 |
+
# id_tt = str(meshid_to_id[tt])
|
318 |
+
# modified_attack.append(f'{id_ss}*{tp}*{id_tt}')
|
319 |
+
# if int(dependency_to_type_id[k][dp]) == int(r):
|
320 |
+
# # if id_to_meshid[s] == ss and id_to_meshid[o] == tt:
|
321 |
+
# boo = True
|
322 |
+
# modified_attack = list(set(modified_attack))
|
323 |
+
# modified_attack = [k.split('*') for k in modified_attack]
|
324 |
+
# if boo:
|
325 |
+
# add += 1
|
326 |
+
# # else:
|
327 |
+
# # print(ii)
|
328 |
+
|
329 |
+
# # for i in range(len(type_list)):
|
330 |
+
# # if type_list[i]:
|
331 |
+
# # print(pure_name[i], type_list[i])
|
332 |
+
# # for k, ss, tt, dp in DP_list:
|
333 |
+
# # print(k, dp)
|
334 |
+
# # print(record[record_index - 1])
|
335 |
+
# # raise Exception('No!!')
|
336 |
+
# Attack.append(modified_attack)
|
337 |
+
|
338 |
+
record = []
|
339 |
+
with open(f'generate_abstract/{args.init_mode}{args.reasonable_rate}_{args.mode}_parseout.txt', 'r') as fl:
|
340 |
+
Tmp = []
|
341 |
+
tmp = []
|
342 |
+
for i,line in enumerate(fl.readlines()):
|
343 |
+
# print(len(line), line)
|
344 |
+
line = line.replace('\n', '')
|
345 |
+
if len(line) > 1:
|
346 |
+
tmp.append(line)
|
347 |
+
else:
|
348 |
+
if len(Tmp) == 2:
|
349 |
+
if len(tmp) == 1 and '/' in tmp[0].split(' ')[0]:
|
350 |
+
Tmp.append([])
|
351 |
+
record.append(Tmp)
|
352 |
+
Tmp = []
|
353 |
+
Tmp.append(tmp)
|
354 |
+
if len(Tmp) == 2 and tmp[0][:5] != '(ROOT':
|
355 |
+
print(record[-1][2])
|
356 |
+
raise Exception('??')
|
357 |
+
tmp = []
|
358 |
+
if len(Tmp) == 3:
|
359 |
+
record.append(Tmp)
|
360 |
+
Tmp = []
|
361 |
+
with open(f'generate_abstract/{args.init_mode}{args.reasonable_rate}_{args.mode}_parsein.txt', 'r') as fl:
|
362 |
+
parsin = fl.readlines()
|
363 |
+
|
364 |
+
print('Record len', len(record), 'Parsin len:', len(parsin))
|
365 |
+
record_index = 0
|
366 |
+
add = 0
|
367 |
+
|
368 |
+
Attack = []
|
369 |
+
for ii, (k, v_dict) in enumerate(tqdm(draft.items())):
|
370 |
+
|
371 |
+
input = v_dict['in']
|
372 |
+
output = v_dict['out']
|
373 |
+
output = output.replace('\n', ' ')
|
374 |
+
s, r, o = attack_data[ii]
|
375 |
+
s = str(s)
|
376 |
+
r = str(r)
|
377 |
+
o = str(o)
|
378 |
+
assert ii == int(k.split('_')[-1])
|
379 |
+
|
380 |
+
DP_list = []
|
381 |
+
if input != '':
|
382 |
+
|
383 |
+
dependency_sen_dict = retieve_sentence_through_edgetype[int(r)]['manual']
|
384 |
+
target_dp = set()
|
385 |
+
for dp_path, sen_list in dependency_sen_dict.items():
|
386 |
+
target_dp.add(dp_path)
|
387 |
+
doc = nlp(output)
|
388 |
+
|
389 |
+
for sen in doc.sents:
|
390 |
+
dp_dict = {}
|
391 |
+
if record_index >= len(record):
|
392 |
+
break
|
393 |
+
data = record[record_index]
|
394 |
+
record_index += 1
|
395 |
+
dp_paths = data[2]
|
396 |
+
nodes_list = []
|
397 |
+
edges_list = []
|
398 |
+
for line in dp_paths:
|
399 |
+
aa = line.split('(')
|
400 |
+
if len(aa) == 1:
|
401 |
+
print(ii)
|
402 |
+
print(sen)
|
403 |
+
print(data)
|
404 |
+
raise Exception
|
405 |
+
ttp, tmp = aa[0], aa[1]
|
406 |
+
assert tmp[-1] == ')'
|
407 |
+
tmp = tmp[:-1]
|
408 |
+
e1, e2 = tmp.split(', ')
|
409 |
+
if not ttp in type_set and ':' in ttp:
|
410 |
+
ttp = ttp.split(':')[0]
|
411 |
+
dp_dict[f'{e1}_x_{e2}'] = [e1, ttp, e2]
|
412 |
+
dp_dict[f'{e2}_x_{e1}'] = [e1, ttp, e2]
|
413 |
+
nodes_list.append(e1)
|
414 |
+
nodes_list.append(e2)
|
415 |
+
edges_list.append((e1, e2))
|
416 |
+
nodes_list = list(set(nodes_list))
|
417 |
+
pure_name = [('-'.join(name.split('-')[:-1])).replace('_', ' ') for name in nodes_list]
|
418 |
+
graph = nx.Graph(edges_list)
|
419 |
+
|
420 |
+
type_list = [name_to_type[name] if name in good_name else '' for name in pure_name]
|
421 |
+
# print(type_list)
|
422 |
+
for i in range(len(nodes_list)):
|
423 |
+
if type_list[i] != '':
|
424 |
+
for j in range(len(nodes_list)):
|
425 |
+
if i != j and type_list[j] != '':
|
426 |
+
if f'{type_list[i]}-{type_list[j]}' in Parameters.edge_type_to_id.keys():
|
427 |
+
# print(f'{type_list[i]}_{type_list[j]}')
|
428 |
+
ret_path = []
|
429 |
+
sp = nx.shortest_path(graph, source=nodes_list[i], target=nodes_list[j])
|
430 |
+
start = sp[0]
|
431 |
+
end = sp[-1]
|
432 |
+
for k in range(len(sp)-1):
|
433 |
+
e1, ttp, e2 = dp_dict[f'{sp[k]}_x_{sp[k+1]}']
|
434 |
+
if e1 == start:
|
435 |
+
e1 = 'start_entity-x'
|
436 |
+
if e2 == start:
|
437 |
+
e2 = 'start_entity-x'
|
438 |
+
if e1 == end:
|
439 |
+
e1 = 'end_entity-x'
|
440 |
+
if e2 == end:
|
441 |
+
e2 = 'end_entity-x'
|
442 |
+
ret_path.append(f'{"-".join(e1.split("-")[:-1])}|{ttp}|{"-".join(e2.split("-")[:-1])}'.lower())
|
443 |
+
dependency_P = ' '.join(ret_path)
|
444 |
+
DP_list.append((f'{type_list[i]}-{type_list[j]}',
|
445 |
+
name_to_meshid[pure_name[i]],
|
446 |
+
name_to_meshid[pure_name[j]],
|
447 |
+
dependency_P))
|
448 |
+
|
449 |
+
boo = False
|
450 |
+
modified_attack = []
|
451 |
+
for k, ss, tt, dp in DP_list:
|
452 |
+
if dp in dependency_to_type_id[k].keys():
|
453 |
+
tp = str(dependency_to_type_id[k][dp])
|
454 |
+
id_ss = str(meshid_to_id[ss])
|
455 |
+
id_tt = str(meshid_to_id[tt])
|
456 |
+
modified_attack.append(f'{id_ss}*{tp}*{id_tt}')
|
457 |
+
if int(dependency_to_type_id[k][dp]) == int(r):
|
458 |
+
if id_to_meshid[s] == ss and id_to_meshid[o] == tt:
|
459 |
+
boo = True
|
460 |
+
modified_attack = list(set(modified_attack))
|
461 |
+
modified_attack = [k.split('*') for k in modified_attack]
|
462 |
+
if boo:
|
463 |
+
# print(DP_list)
|
464 |
+
add += 1
|
465 |
+
Attack.append(modified_attack)
|
466 |
+
print(add)
|
467 |
+
print('End record_index:', record_index)
|
468 |
+
final_Attack = Attack
|
469 |
+
print('Len of Attack:', len(Attack))
|
470 |
+
with open(modified_attack_path, 'wb') as fl:
|
471 |
+
pkl.dump(final_Attack, fl)
|
472 |
+
else:
|
473 |
+
raise Exception('Wrong action !!')
|
DiseaseAgnostic/edge_to_abstract.py
ADDED
@@ -0,0 +1,652 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#%%
# edge_to_abstract: for every disease-agnostic attack edge, select the GNBR
# evidence sentence that BioGPT scores as most fluent, to seed abstract
# generation downstream.
import torch
import numpy as np
from torch.autograd import Variable
from sklearn import metrics

import datetime
from typing import Dict, Tuple, List
import logging
import os
import utils
import pickle as pkl
import json
import torch.backends.cudnn as cudnn

from tqdm import tqdm

import sys
sys.path.append("..")
import Parameters

parser = utils.get_argument_parser()
parser.add_argument('--reasonable-rate', type = float, default=0.7, help = 'The added edge\'s existance rank prob greater than this rate')
parser.add_argument('--mode', type=str, default='sentence', help='sentence, biogpt or finetune')
parser.add_argument('--init-mode', type = str, default='random', help = 'How to select target nodes')
parser.add_argument('--ratio', type = str, default='', help='ratio of the number of changed words')
args = parser.parse_args()
args = utils.set_hyperparams(args)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

utils.seed_all(args.seed)
np.set_printoptions(precision=5)
cudnn.benchmark = False

data_path = '../DiseaseSpecific/processed_data/GNBR'
target_path = f'processed_data/target_{args.reasonable_rate}{args.init_mode}.pkl'
attack_path = f'processed_data/attack_edge_{args.model}_{args.reasonable_rate}{args.init_mode}.pkl'

# target_data = utils.load_data(target_path)
with open(target_path, 'rb') as fl:
    Target_node_list = pkl.load(fl)
with open(attack_path, 'rb') as fl:
    Attack_edge_list = pkl.load(fl)
# One attacking triple (s, r, o) per target; s == -1 marks "no edge found".
attack_data = np.array(Attack_edge_list).reshape(-1, 3)
#%%

with open('../DiseaseSpecific/processed_data/GNBR/entities_reverse_dict.json') as fl:
    id_to_meshid = json.load(fl)
with open(Parameters.GNBRfile+'entity_raw_name', 'rb') as fl:
    entity_raw_name = pkl.load(fl)
with open(Parameters.GNBRfile+'retieve_sentence_through_edgetype', 'rb') as fl:
    retieve_sentence_through_edgetype = pkl.load(fl)
with open(Parameters.GNBRfile+'raw_text_of_each_sentence', 'rb') as fl:
    raw_text_sen = pkl.load(fl)

if args.mode == 'sentence':
    import torch
    from torch.nn.modules.loss import CrossEntropyLoss
    from transformers import AutoTokenizer
    from transformers import BioGptForCausalLM

    # Per-token cross entropy, kept unreduced so it can be averaged per sentence.
    criterion = CrossEntropyLoss(reduction="none")

    print('Generating GPT input ...')

    tokenizer = AutoTokenizer.from_pretrained('microsoft/biogpt')
    tokenizer.pad_token = tokenizer.eos_token
    model = BioGptForCausalLM.from_pretrained('microsoft/biogpt', pad_token_id=tokenizer.eos_token_id)
    model.to(device)
    model.eval()

    GPT_batch_size = 24
    # PTB-style bracket escapes used by the GNBR tokenisation.
    bracket_escapes = [('-LRB-', '('), ('-RRB-', ')'), ('-LSB-', '['),
                      ('-RSB-', ']'), ('-LCB-', '{'), ('-RCB-', '}')]
    single_sentence = {}
    test_text = []
    test_dp = []
    test_parse = []
    for i, (s, r, o) in enumerate(tqdm(attack_data)):

        s = str(s)
        r = str(r)
        o = str(o)
        if int(s) != -1:
            # Gather up to ~500 evidence sentences for this edge type,
            # sampled evenly across its dependency paths.
            dependency_sen_dict = retieve_sentence_through_edgetype[int(r)]['manual']
            candidate_sen = []
            Dp_path = []
            bound = 500 // len(dependency_sen_dict.keys())
            if bound == 0:
                bound = 1
            for dp_path, sen_list in dependency_sen_dict.items():
                if len(sen_list) > bound:
                    index = np.random.choice(np.array(range(len(sen_list))), bound, replace=False)
                    sen_list = [sen_list[aa] for aa in index]
                candidate_sen += sen_list
                Dp_path += [dp_path] * len(sen_list)

            text_s = entity_raw_name[id_to_meshid[s]]
            text_o = entity_raw_name[id_to_meshid[o]]
            candidate_text_sen = []
            candidate_ori_sen = []
            candidate_parse_sen = []

            for paper_id, sen_id in candidate_sen:
                sen = raw_text_sen[paper_id][sen_id]
                text = sen['text']
                candidate_ori_sen.append(text)
                ss = sen['start_formatted']
                oo = sen['end_formatted']
                for escaped, plain in bracket_escapes:
                    text = text.replace(escaped, plain)
                # Parser-oriented variant keeps entity names underscore-joined.
                parse_text = text
                parse_text = parse_text.replace(ss, text_s.replace(' ', '_'))
                parse_text = parse_text.replace(oo, text_o.replace(' ', '_'))
                # Natural-language variant is what BioGPT scores.
                text = text.replace(ss, text_s)
                text = text.replace(oo, text_o)
                text = text.replace('_', ' ')
                candidate_text_sen.append(text)
                candidate_parse_sen.append(parse_text)

            tokens = tokenizer( candidate_text_sen,
                                truncation = True,
                                padding = True,
                                max_length = 300,
                                return_tensors="pt")
            target_ids = tokens['input_ids'].to(device)
            attention_mask = tokens['attention_mask'].to(device)

            total = len(candidate_text_sen)
            assert total > 0
            ret_log_L = []
            # Score candidates in mini-batches: mean token NLL under BioGPT.
            for lo in range(0, total, GPT_batch_size):
                hi = min(total, lo + GPT_batch_size)
                target = target_ids[lo:hi, :]
                attention = attention_mask[lo:hi, :]
                outputs = model(input_ids = target,
                                attention_mask = attention,
                                labels = target)
                logits = outputs.logits
                shift_logits = logits[..., :-1, :].contiguous()
                shift_labels = target[..., 1:].contiguous()
                Loss = criterion(shift_logits.view(-1, shift_logits.shape[-1]), shift_labels.view(-1))
                Loss = Loss.view(-1, shift_logits.shape[1])
                attention = attention[..., 1:].contiguous()
                # Length-normalised loss with padding positions masked out.
                log_Loss = (torch.mean(Loss * attention.float(), dim = 1) / torch.mean(attention.float(), dim = 1))
                ret_log_L.append(log_Loss.detach())

            ret_log_L = list(torch.cat(ret_log_L, -1).cpu().numpy())
            sen_score = list(zip(candidate_text_sen, ret_log_L, candidate_ori_sen, Dp_path, candidate_parse_sen))
            sen_score.sort(key = lambda x: x[1])
            # Keep the most fluent (lowest-loss) candidate.
            test_text.append(sen_score[0][2])
            test_dp.append(sen_score[0][3])
            test_parse.append(sen_score[0][4])
            single_sentence.update({f'{s}_{r}_{o}_{i}': sen_score[0][0]})

        else:
            # No attack edge was found for this target: emit an empty seed.
            single_sentence.update({f'{s}_{r}_{o}_{i}': ''})

    with open(f'generate_abstract/{args.init_mode}{args.reasonable_rate}_sentence.json', 'w') as fl:
        json.dump(single_sentence, fl, indent=4)
    with open (f'generate_abstract/path/{args.init_mode}{args.reasonable_rate}_path.json', 'w') as fl:
        fl.write('\n'.join(test_dp))
    with open (f'generate_abstract/path/{args.init_mode}{args.reasonable_rate}_temp.json', 'w') as fl:
        fl.write('\n'.join(test_text))

174 |
+
elif args.mode == 'biogpt':
|
175 |
+
pass
|
176 |
+
# from biogpt_generate import GPT_eval
|
177 |
+
# import spacy
|
178 |
+
|
179 |
+
# model = GPT_eval(args.seed)
|
180 |
+
|
181 |
+
# nlp = spacy.load("en_core_web_sm")
|
182 |
+
# with open(f'generate_abstract/{args.target_split}_{args.reasonable_rate}_sentence.json', 'r') as fl:
|
183 |
+
# data = json.load(fl)
|
184 |
+
|
185 |
+
# KK = []
|
186 |
+
# input = []
|
187 |
+
# for i,(k, v) in enumerate(data.items()):
|
188 |
+
# KK.append(k)
|
189 |
+
# input.append(v)
|
190 |
+
# output = model.eval(input)
|
191 |
+
|
192 |
+
# ret = {}
|
193 |
+
# for i, o in enumerate(output):
|
194 |
+
|
195 |
+
# o = o.replace('<|abstract|>', '')
|
196 |
+
# doc = nlp(o)
|
197 |
+
# sen_list = []
|
198 |
+
# sen_set = set()
|
199 |
+
# for sen in doc.sents:
|
200 |
+
# txt = sen.text
|
201 |
+
# if not (txt.lower() in sen_set):
|
202 |
+
# sen_set.add(txt.lower())
|
203 |
+
# sen_list.append(txt)
|
204 |
+
# O = ' '.join(sen_list)
|
205 |
+
# ret[KK[i]] = {'in' : input[i], 'out' : O}
|
206 |
+
|
207 |
+
# with open(f'generate_abstract/{args.target_split}_{args.reasonable_rate}_biogpt.json', 'w') as fl:
|
208 |
+
# json.dump(ret, fl, indent=4)
|
209 |
+
|
210 |
+
elif args.mode == 'finetune':
|
211 |
+
|
212 |
+
import spacy
|
213 |
+
import pprint
|
214 |
+
from transformers import AutoModel, AutoTokenizer,BartForConditionalGeneration
|
215 |
+
|
216 |
+
print('Finetuning ...')
|
217 |
+
|
218 |
+
with open(f'generate_abstract/{args.init_mode}{args.reasonable_rate}_chat.json', 'r') as fl:
|
219 |
+
draft = json.load(fl)
|
220 |
+
with open (f'generate_abstract/path/{args.init_mode}{args.reasonable_rate}_path.json', 'r') as fl:
|
221 |
+
dpath = fl.readlines()
|
222 |
+
|
223 |
+
nlp = spacy.load("en_core_web_sm")
|
224 |
+
if os.path.exists(f'generate_abstract/bioBART/{args.init_mode}{args.reasonable_rate}{args.ratio}_candidates.json'):
|
225 |
+
with open(f'generate_abstract/bioBART/{args.init_mode}{args.reasonable_rate}{args.ratio}_candidates.json', 'r') as fl:
|
226 |
+
ret_candidates = json.load(fl)
|
227 |
+
else:
|
228 |
+
|
229 |
+
def find_mini_span(vec, words, check_set):
|
230 |
+
|
231 |
+
|
232 |
+
def cal(text, sset):
|
233 |
+
add = 0
|
234 |
+
for tt in sset:
|
235 |
+
if tt in text:
|
236 |
+
add += 1
|
237 |
+
return add
|
238 |
+
text = ' '.join(words)
|
239 |
+
max_add = cal(text, check_set)
|
240 |
+
|
241 |
+
minn = 10000000
|
242 |
+
span = ''
|
243 |
+
rc = None
|
244 |
+
for i in range(len(vec)):
|
245 |
+
if vec[i] == True:
|
246 |
+
p = -1
|
247 |
+
for j in range(i+1, len(vec)+1):
|
248 |
+
if vec[j-1] == True:
|
249 |
+
text = ' '.join(words[i:j])
|
250 |
+
if cal(text, check_set) == max_add:
|
251 |
+
p = j
|
252 |
+
break
|
253 |
+
if p > 0:
|
254 |
+
if (p-i) < minn:
|
255 |
+
minn = p-i
|
256 |
+
span = ' '.join(words[i:p])
|
257 |
+
rc = (i, p)
|
258 |
+
if rc:
|
259 |
+
for i in range(rc[0], rc[1]):
|
260 |
+
vec[i] = True
|
261 |
+
return vec, span
|
262 |
+
|
263 |
+
# def mask_func(tokenized_sen, position):
|
264 |
+
|
265 |
+
# if len(tokenized_sen) == 0:
|
266 |
+
# return []
|
267 |
+
# token_list = []
|
268 |
+
# # for sen in tokenized_sen:
|
269 |
+
# # for token in sen:
|
270 |
+
# # token_list.append(token)
|
271 |
+
# for sen in tokenized_sen:
|
272 |
+
# token_list += sen.text.split(' ')
|
273 |
+
# l_p = 0
|
274 |
+
# r_p = 1
|
275 |
+
# assert position == 'front' or position == 'back'
|
276 |
+
# if position == 'back':
|
277 |
+
# l_p, r_p = r_p, l_p
|
278 |
+
# P = np.linspace(start = l_p, stop = r_p, num = len(token_list))
|
279 |
+
# P = (P ** 3) * 0.4
|
280 |
+
|
281 |
+
# ret_list = []
|
282 |
+
# for t, p in zip(token_list, list(P)):
|
283 |
+
# if '.' in t or '(' in t or ')' in t or '[' in t or ']' in t:
|
284 |
+
# ret_list.append(t)
|
285 |
+
# else:
|
286 |
+
# if np.random.rand() < p:
|
287 |
+
# ret_list.append('<mask>')
|
288 |
+
# else:
|
289 |
+
# ret_list.append(t)
|
290 |
+
# return [' '.join(ret_list)]
|
291 |
+
def mask_func(tokenized_sen):
|
292 |
+
|
293 |
+
if len(tokenized_sen) == 0:
|
294 |
+
return []
|
295 |
+
token_list = []
|
296 |
+
# for sen in tokenized_sen:
|
297 |
+
# for token in sen:
|
298 |
+
# token_list.append(token)
|
299 |
+
for sen in tokenized_sen:
|
300 |
+
token_list += sen.text.split(' ')
|
301 |
+
if args.ratio == '':
|
302 |
+
P = 0.3
|
303 |
+
else:
|
304 |
+
P = float(args.ratio)
|
305 |
+
|
306 |
+
ret_list = []
|
307 |
+
i = 0
|
308 |
+
mask_num = 0
|
309 |
+
while i < len(token_list):
|
310 |
+
t = token_list[i]
|
311 |
+
if '.' in t or '(' in t or ')' in t or '[' in t or ']' in t:
|
312 |
+
ret_list.append(t)
|
313 |
+
i += 1
|
314 |
+
mask_num = 0
|
315 |
+
else:
|
316 |
+
length = np.random.poisson(3)
|
317 |
+
if np.random.rand() < P and length > 0:
|
318 |
+
if mask_num < 8:
|
319 |
+
ret_list.append('<mask>')
|
320 |
+
mask_num += 1
|
321 |
+
i += length
|
322 |
+
else:
|
323 |
+
ret_list.append(t)
|
324 |
+
i += 1
|
325 |
+
mask_num = 0
|
326 |
+
return [' '.join(ret_list)]
|
327 |
+
|
328 |
+
model = BartForConditionalGeneration.from_pretrained('GanjinZero/biobart-large')
|
329 |
+
model.eval()
|
330 |
+
model.to(device)
|
331 |
+
tokenizer = AutoTokenizer.from_pretrained('GanjinZero/biobart-large')
|
332 |
+
|
333 |
+
ret_candidates = {}
|
334 |
+
dpath_i = 0
|
335 |
+
|
336 |
+
for i,(k, v) in enumerate(tqdm(draft.items())):
|
337 |
+
|
338 |
+
input = v['in'].replace('\n', '')
|
339 |
+
output = v['out'].replace('\n', '')
|
340 |
+
s, r, o = attack_data[i]
|
341 |
+
s = str(s)
|
342 |
+
o = str(o)
|
343 |
+
r = str(r)
|
344 |
+
|
345 |
+
if int(s) == -1:
|
346 |
+
ret_candidates[str(i)] = {'span': '', 'prompt' : '', 'out' : [], 'in': [], 'assist': []}
|
347 |
+
continue
|
348 |
+
|
349 |
+
path_text = dpath[dpath_i].replace('\n', '')
|
350 |
+
dpath_i += 1
|
351 |
+
text_s = entity_raw_name[id_to_meshid[s]]
|
352 |
+
text_o = entity_raw_name[id_to_meshid[o]]
|
353 |
+
|
354 |
+
doc = nlp(output)
|
355 |
+
words= input.split(' ')
|
356 |
+
tokenized_sens = [sen for sen in doc.sents]
|
357 |
+
sens = np.array([sen.text for sen in doc.sents])
|
358 |
+
|
359 |
+
checkset = set([text_s, text_o])
|
360 |
+
e_entity = set(['start_entity', 'end_entity'])
|
361 |
+
for path in path_text.split(' '):
|
362 |
+
a, b, c = path.split('|')
|
363 |
+
if a not in e_entity:
|
364 |
+
checkset.add(a)
|
365 |
+
if c not in e_entity:
|
366 |
+
checkset.add(c)
|
367 |
+
vec = []
|
368 |
+
l = 0
|
369 |
+
while(l < len(words)):
|
370 |
+
bo =False
|
371 |
+
for j in range(len(words), l, -1): # reversing is important !!!
|
372 |
+
cc = ' '.join(words[l:j])
|
373 |
+
if (cc in checkset):
|
374 |
+
vec += [True] * (j-l)
|
375 |
+
l = j
|
376 |
+
bo = True
|
377 |
+
break
|
378 |
+
if not bo:
|
379 |
+
vec.append(False)
|
380 |
+
l += 1
|
381 |
+
vec, span = find_mini_span(vec, words, checkset)
|
382 |
+
# vec = np.vectorize(lambda x: x in checkset)(words)
|
383 |
+
vec[-1] = True
|
384 |
+
prompt = []
|
385 |
+
mask_num = 0
|
386 |
+
for j, bo in enumerate(vec):
|
387 |
+
if not bo:
|
388 |
+
mask_num += 1
|
389 |
+
else:
|
390 |
+
if mask_num > 0:
|
391 |
+
# mask_num = mask_num // 3 # span length ~ poisson distribution (lambda = 3)
|
392 |
+
mask_num = max(mask_num, 1)
|
393 |
+
mask_num= min(8, mask_num)
|
394 |
+
prompt += ['<mask>'] * mask_num
|
395 |
+
prompt.append(words[j])
|
396 |
+
mask_num = 0
|
397 |
+
prompt = ' '.join(prompt)
|
398 |
+
Text = []
|
399 |
+
Assist = []
|
400 |
+
|
401 |
+
for j in range(len(sens)):
|
402 |
+
Bart_input = list(sens[:j]) + [prompt] +list(sens[j+1:])
|
403 |
+
assist = list(sens[:j]) + [input] +list(sens[j+1:])
|
404 |
+
Text.append(' '.join(Bart_input))
|
405 |
+
Assist.append(' '.join(assist))
|
406 |
+
|
407 |
+
for j in range(len(sens)):
|
408 |
+
Bart_input = mask_func(tokenized_sens[:j]) + [input] + mask_func(tokenized_sens[j+1:])
|
409 |
+
assist = list(sens[:j]) + [input] +list(sens[j+1:])
|
410 |
+
Text.append(' '.join(Bart_input))
|
411 |
+
Assist.append(' '.join(assist))
|
412 |
+
|
413 |
+
batch_size = len(Text) // 2
|
414 |
+
Outs = []
|
415 |
+
for l in range(2):
|
416 |
+
A = tokenizer(Text[batch_size * l:batch_size * (l+1)],
|
417 |
+
truncation = True,
|
418 |
+
padding = True,
|
419 |
+
max_length = 1024,
|
420 |
+
return_tensors="pt")
|
421 |
+
input_ids = A['input_ids'].to(device)
|
422 |
+
attention_mask = A['attention_mask'].to(device)
|
423 |
+
aaid = model.generate(input_ids, num_beams = 5, max_length = 1024)
|
424 |
+
outs = tokenizer.batch_decode(aaid, skip_special_tokens=True, clean_up_tokenization_spaces=False)
|
425 |
+
Outs += outs
|
426 |
+
ret_candidates[str(i)] = {'span': span, 'prompt' : prompt, 'out' : Outs, 'in': Text, 'assist': Assist}
|
427 |
+
with open(f'generate_abstract/bioBART/{args.init_mode}{args.reasonable_rate}{args.ratio}_candidates.json', 'w') as fl:
|
428 |
+
json.dump(ret_candidates, fl, indent = 4)
|
429 |
+
|
430 |
+
from torch.nn.modules.loss import CrossEntropyLoss
|
431 |
+
from transformers import BioGptForCausalLM
|
432 |
+
criterion = CrossEntropyLoss(reduction="none")
|
433 |
+
|
434 |
+
tokenizer = AutoTokenizer.from_pretrained('microsoft/biogpt')
|
435 |
+
tokenizer.pad_token = tokenizer.eos_token
|
436 |
+
model = BioGptForCausalLM.from_pretrained('microsoft/biogpt', pad_token_id=tokenizer.eos_token_id)
|
437 |
+
model.to(device)
|
438 |
+
model.eval()
|
439 |
+
|
440 |
+
scored = {}
|
441 |
+
ret = {}
|
442 |
+
case_study = {}
|
443 |
+
p_ret = {}
|
444 |
+
add = 0
|
445 |
+
dpath_i = 0
|
446 |
+
inner_better = 0
|
447 |
+
outter_better = 0
|
448 |
+
better_than_gpt = 0
|
449 |
+
for i,(k, v) in enumerate(tqdm(draft.items())):
|
450 |
+
|
451 |
+
span = ret_candidates[str(i)]['span']
|
452 |
+
prompt = ret_candidates[str(i)]['prompt']
|
453 |
+
sen_list = ret_candidates[str(i)]['out']
|
454 |
+
BART_in = ret_candidates[str(i)]['in']
|
455 |
+
Assist = ret_candidates[str(i)]['assist']
|
456 |
+
|
457 |
+
s, r, o = attack_data[i]
|
458 |
+
s = str(s)
|
459 |
+
r = str(r)
|
460 |
+
o = str(o)
|
461 |
+
|
462 |
+
if int(s) == -1:
|
463 |
+
ret[k] = {'prompt': '', 'in':'', 'out': ''}
|
464 |
+
p_ret[k] = {'prompt': '', 'in':'', 'out': ''}
|
465 |
+
continue
|
466 |
+
|
467 |
+
text_s = entity_raw_name[id_to_meshid[s]]
|
468 |
+
text_o = entity_raw_name[id_to_meshid[o]]
|
469 |
+
|
470 |
+
def process(text):
|
471 |
+
|
472 |
+
for i in range(ord('A'), ord('Z')+1):
|
473 |
+
text = text.replace(f'.{chr(i)}', f'. {chr(i)}')
|
474 |
+
return text
|
475 |
+
|
476 |
+
sen_list = [process(text) for text in sen_list]
|
477 |
+
path_text = dpath[dpath_i].replace('\n', '')
|
478 |
+
dpath_i += 1
|
479 |
+
|
480 |
+
checkset = set([text_s, text_o])
|
481 |
+
e_entity = set(['start_entity', 'end_entity'])
|
482 |
+
for path in path_text.split(' '):
|
483 |
+
a, b, c = path.split('|')
|
484 |
+
if a not in e_entity:
|
485 |
+
checkset.add(a)
|
486 |
+
if c not in e_entity:
|
487 |
+
checkset.add(c)
|
488 |
+
|
489 |
+
input = v['in'].replace('\n', '')
|
490 |
+
output = v['out'].replace('\n', '')
|
491 |
+
|
492 |
+
doc = nlp(output)
|
493 |
+
gpt_sens = [sen.text for sen in doc.sents]
|
494 |
+
assert len(gpt_sens) == len(sen_list) // 2
|
495 |
+
|
496 |
+
word_sets = []
|
497 |
+
for sen in gpt_sens:
|
498 |
+
word_sets.append(set(sen.split(' ')))
|
499 |
+
|
500 |
+
def sen_align(word_sets, modified_word_sets):
|
501 |
+
|
502 |
+
l = 0
|
503 |
+
while(l < len(modified_word_sets)):
|
504 |
+
if len(word_sets[l].intersection(modified_word_sets[l])) > len(word_sets[l]) * 0.8:
|
505 |
+
l += 1
|
506 |
+
else:
|
507 |
+
break
|
508 |
+
if l == len(modified_word_sets):
|
509 |
+
return -1, -1, -1, -1
|
510 |
+
r = l + 1
|
511 |
+
r1 = None
|
512 |
+
r2 = None
|
513 |
+
for pos1 in range(r, len(word_sets)):
|
514 |
+
for pos2 in range(r, len(modified_word_sets)):
|
515 |
+
if len(word_sets[pos1].intersection(modified_word_sets[pos2])) > len(word_sets[pos1]) * 0.8:
|
516 |
+
r1 = pos1
|
517 |
+
r2 = pos2
|
518 |
+
break
|
519 |
+
if r1 is not None:
|
520 |
+
break
|
521 |
+
if r1 is None:
|
522 |
+
r1 = len(word_sets)
|
523 |
+
r2 = len(modified_word_sets)
|
524 |
+
return l, r1, l, r2
|
525 |
+
|
526 |
+
replace_sen_list = []
|
527 |
+
boundary = []
|
528 |
+
assert len(sen_list) % 2 == 0
|
529 |
+
for j in range(len(sen_list) // 2):
|
530 |
+
doc = nlp(sen_list[j])
|
531 |
+
sens = [sen.text for sen in doc.sents]
|
532 |
+
modified_word_sets = [set(sen.split(' ')) for sen in sens]
|
533 |
+
l1, r1, l2, r2 = sen_align(word_sets, modified_word_sets)
|
534 |
+
boundary.append((l1, r1, l2, r2))
|
535 |
+
if l1 == -1:
|
536 |
+
replace_sen_list.append(sen_list[j])
|
537 |
+
continue
|
538 |
+
check_text = ' '.join(sens[l2: r2])
|
539 |
+
replace_sen_list.append(' '.join(gpt_sens[:l1] + [check_text] + gpt_sens[r1:]))
|
540 |
+
sen_list = replace_sen_list + sen_list[len(sen_list) // 2:]
|
541 |
+
|
542 |
+
old_L = len(sen_list)
|
543 |
+
sen_list.append(output)
|
544 |
+
sen_list += Assist
|
545 |
+
tokens = tokenizer( sen_list,
|
546 |
+
truncation = True,
|
547 |
+
padding = True,
|
548 |
+
max_length = 1024,
|
549 |
+
return_tensors="pt")
|
550 |
+
target_ids = tokens['input_ids'].to(device)
|
551 |
+
attention_mask = tokens['attention_mask'].to(device)
|
552 |
+
L = len(sen_list)
|
553 |
+
ret_log_L = []
|
554 |
+
for l in range(0, L, 5):
|
555 |
+
R = min(L, l + 5)
|
556 |
+
target = target_ids[l:R, :]
|
557 |
+
attention = attention_mask[l:R, :]
|
558 |
+
outputs = model(input_ids = target,
|
559 |
+
attention_mask = attention,
|
560 |
+
labels = target)
|
561 |
+
logits = outputs.logits
|
562 |
+
shift_logits = logits[..., :-1, :].contiguous()
|
563 |
+
shift_labels = target[..., 1:].contiguous()
|
564 |
+
Loss = criterion(shift_logits.view(-1, shift_logits.shape[-1]), shift_labels.view(-1))
|
565 |
+
Loss = Loss.view(-1, shift_logits.shape[1])
|
566 |
+
attention = attention[..., 1:].contiguous()
|
567 |
+
log_Loss = (torch.mean(Loss * attention.float(), dim = 1) / torch.mean(attention.float(), dim = 1))
|
568 |
+
ret_log_L.append(log_Loss.detach())
|
569 |
+
log_Loss = torch.cat(ret_log_L, -1).cpu().numpy()
|
570 |
+
|
571 |
+
real_log_Loss = log_Loss.copy()
|
572 |
+
|
573 |
+
log_Loss = log_Loss[:old_L]
|
574 |
+
# sen_list = sen_list[:old_L]
|
575 |
+
|
576 |
+
# mini_span should be preserved
|
577 |
+
# for j in range(len(log_Loss)):
|
578 |
+
# doc = nlp(sen_list[j])
|
579 |
+
# sens = [sen.text for sen in doc.sents]
|
580 |
+
# Len = len(sen_list)
|
581 |
+
# check_text = ' '.join(sens[j : max(0,len(sens) - Len) + j + 1])
|
582 |
+
# if span not in check_text:
|
583 |
+
# log_Loss[j] += 1
|
584 |
+
|
585 |
+
p = np.argmin(log_Loss)
|
586 |
+
if p < old_L // 2:
|
587 |
+
inner_better += 1
|
588 |
+
else:
|
589 |
+
outter_better += 1
|
590 |
+
content = []
|
591 |
+
for i in range(len(real_log_Loss)):
|
592 |
+
content.append([sen_list[i], str(real_log_Loss[i])])
|
593 |
+
scored[k] = {'path':path_text, 'prompt': prompt, 'in':input, 's':text_s, 'o':text_o, 'out': content, 'bound': boundary}
|
594 |
+
p_p = p
|
595 |
+
# print('Old_L:', old_L)
|
596 |
+
|
597 |
+
if real_log_Loss[p] > real_log_Loss[p+1+old_L]:
|
598 |
+
p_p = p+1+old_L
|
599 |
+
if real_log_Loss[p] > real_log_Loss[p+1+old_L]:
|
600 |
+
add += 1
|
601 |
+
|
602 |
+
if real_log_Loss[p] < real_log_Loss[old_L]:
|
603 |
+
better_than_gpt += 1
|
604 |
+
else:
|
605 |
+
if real_log_Loss[p] > real_log_Loss[p+1+old_L]:
|
606 |
+
p = p+1+old_L
|
607 |
+
# case_study[k] = {'path':path_text, 'entity_0': text_s, 'entity_1': text_o, 'GPT_in': input, 'Prompt': prompt, 'GPT_out': {'text': output, 'perplexity': str(np.exp(real_log_Loss[old_L]))}, 'BART_in': BART_in[p], 'BART_out': {'text': sen_list[p], 'perplexity': str(np.exp(real_log_Loss[p]))}, 'Assist': {'text': Assist[p], 'perplexity': str(np.exp(real_log_Loss[p+1+old_L]))}}
|
608 |
+
ret[k] = {'prompt': prompt, 'in':input, 'out': sen_list[p]}
|
609 |
+
p_ret[k] = {'prompt': prompt, 'in':input, 'out': sen_list[p_p]}
|
610 |
+
print(add)
|
611 |
+
print('inner_better:', inner_better)
|
612 |
+
print('outter_better:', outter_better)
|
613 |
+
print('better_than_gpt:', better_than_gpt)
|
614 |
+
print('better_than_replace', add)
|
615 |
+
with open(f'generate_abstract/{args.init_mode}{args.reasonable_rate}{args.ratio}_bioBART_finetune.json', 'w') as fl:
|
616 |
+
json.dump(ret, fl, indent=4)
|
617 |
+
# with open(f'generate_abstract/bioBART/case_{args.target_split}_{args.reasonable_rate}_bioBART_finetune.json', 'w') as fl:
|
618 |
+
# json.dump(case_study, fl, indent=4)
|
619 |
+
with open(f'generate_abstract/bioBART/{args.init_mode}{args.reasonable_rate}{args.ratio}_scored.json', 'w') as fl:
|
620 |
+
json.dump(scored, fl, indent=4)
|
621 |
+
with open(f'generate_abstract/bioBART/{args.init_mode}{args.reasonable_rate}{args.ratio}_perplexity.json', 'w') as fl:
|
622 |
+
json.dump(p_ret, fl, indent=4)
|
623 |
+
|
624 |
+
# with open(Parameters.GNBRfile+'original_entity_raw_name', 'rb') as fl:
|
625 |
+
# full_entity_raw_name = pkl.load(fl)
|
626 |
+
# for k, v in entity_raw_name.items():
|
627 |
+
# assert v in full_entity_raw_name[k]
|
628 |
+
|
629 |
+
# nlp = spacy.load("en_core_web_sm")
|
630 |
+
# type_set = set()
|
631 |
+
# for aa in range(36):
|
632 |
+
# dependency_sen_dict = retieve_sentence_through_edgetype[aa]['manual']
|
633 |
+
# tmp_dict = retieve_sentence_through_edgetype[aa]['auto']
|
634 |
+
# dependencys = list(dependency_sen_dict.keys()) + list(tmp_dict.keys())
|
635 |
+
# for dependency in dependencys:
|
636 |
+
# dep_list = dependency.split(' ')
|
637 |
+
# for sub_dep in dep_list:
|
638 |
+
# sub_dep_list = sub_dep.split('|')
|
639 |
+
# assert(len(sub_dep_list) == 3)
|
640 |
+
# type_set.add(sub_dep_list[1])
|
641 |
+
|
642 |
+
# fine_dict = {}
|
643 |
+
# for k, v_dict in draft.items():
|
644 |
+
|
645 |
+
# input = v_dict['in']
|
646 |
+
# output = v_dict['out']
|
647 |
+
# fine_dict[k] = {'in':input, 'out': input + ' ' + output}
|
648 |
+
|
649 |
+
# with open(f'generate_abstract/{args.target_split}_{args.reasonable_rate}_sentence_finetune.json', 'w') as fl:
|
650 |
+
# json.dump(fine_dict, fl, indent=4)
|
651 |
+
else:
|
652 |
+
raise Exception('Wrong mode !!')
|
DiseaseAgnostic/evaluation.py
ADDED
@@ -0,0 +1,219 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#%%
|
2 |
+
import logging
|
3 |
+
from symbol import parameters
|
4 |
+
from textwrap import indent
|
5 |
+
import os
|
6 |
+
import tempfile
|
7 |
+
import sys
|
8 |
+
from matplotlib import collections
|
9 |
+
import pandas as pd
|
10 |
+
import json
|
11 |
+
from glob import glob
|
12 |
+
from tqdm import tqdm
|
13 |
+
import numpy as np
|
14 |
+
from pprint import pprint
|
15 |
+
import torch
|
16 |
+
import pickle as pkl
|
17 |
+
from collections import Counter
|
18 |
+
# print(dir(collections))
|
19 |
+
import networkx as nx
|
20 |
+
from collections import Counter
|
21 |
+
import utils
|
22 |
+
from torch.nn import functional as F
|
23 |
+
sys.path.append("..")
|
24 |
+
import Parameters
|
25 |
+
from DiseaseSpecific.attack import calculate_edge_bound, get_model_loss_without_softmax
|
26 |
+
|
27 |
+
#%%
|
28 |
+
def load_data(file_name):
    """Read a tab-separated triple file and return its unique rows.

    Every column is loaded as ``str`` (ids must not be re-typed) and exact
    duplicate rows are dropped before the underlying array is returned.
    """
    frame = pd.read_csv(file_name, sep='\t', header=None, names=None, dtype=str)
    return frame.drop_duplicates().values
|
32 |
+
|
33 |
+
# ----- Command-line configuration -----------------------------------------
parser = utils.get_argument_parser()
parser.add_argument('--reasonable-rate', type = float, default=0.7, help = 'The added edge\'s existance rank prob greater than this rate')
parser.add_argument('--mode', type = str, default='', help = ' "" or chat or bioBART')
parser.add_argument('--init-mode', type = str, default='random', help = 'How to select target nodes') # 'single' for case study
parser.add_argument('--added-edge-num', type = str, default = '', help = 'Added edge num')

args = parser.parse_args()
args = utils.set_hyperparams(args)
utils.seed_all(args.seed)

# Inputs: GNBR edge list, id -> MeSH-id mapping, the trained KGE model, and
# the target/attack artefacts produced by generate_target_and_attack.py.
graph_edge_path = '../DiseaseSpecific/processed_data/GNBR/all.txt'
idtomeshid_path = '../DiseaseSpecific/processed_data/GNBR/entities_reverse_dict.json'
model_path = f'../DiseaseSpecific/saved_models/GNBR_{args.model}_128_0.2_0.3_0.3.model'
data_path = '../DiseaseSpecific/processed_data/GNBR'
target_path = f'processed_data/target_{args.reasonable_rate}{args.init_mode}.pkl'
attack_path = f'processed_data/attack_edge_{args.model}_{args.reasonable_rate}{args.init_mode}{args.added_edge_num}{args.mode}.pkl'

with open(Parameters.GNBRfile+'original_entity_raw_name', 'rb') as fl:
    full_entity_raw_name = pkl.load(fl)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# device = torch.device("cpu")

args.device = device
n_ent, n_rel, ent_to_id, rel_to_id = utils.generate_dicts(data_path)
model = utils.load_model(model_path, args, n_ent, n_rel, args.device)

graph_edge = utils.load_data(graph_edge_path)
with open(idtomeshid_path, 'r') as fl:
    idtomeshid = json.load(fl)
print(graph_edge.shape, len(idtomeshid))

# Edge-loss statistics over the existing graph; consumed by
# check_reasonable() below to judge whether an injected edge looks plausible.
divide_bound, data_mean, data_std = calculate_edge_bound(graph_edge, model, args.device, n_ent)
print('Defender ...')
print(divide_bound, data_mean, data_std)

# Tally entities per GNBR type (chemical / disease / gene) — sanity check only.
meshids = list(idtomeshid.values())
cal = {
    'chemical' : 0,
    'disease' : 0,
    'gene' : 0
}
for meshid in meshids:
    cal[meshid.split('_')[0]] += 1
# pprint(cal)
|
78 |
+
|
79 |
+
def check_reasonable(s, r, o):
    """Decide whether the triple ``(s, r, o)`` looks plausible to the defender.

    The triple's raw model loss is standardised with the pre-computed corpus
    statistics (``data_mean`` / ``data_std``) and squashed through a sigmoid
    centred at ``divide_bound``; the triple passes when the resulting
    existence probability exceeds ``1 - args.reasonable_rate``.

    Returns:
        tuple: ``(passes, probability)`` — acceptance flag and the raw
        existence probability.
    """
    triple = np.asarray([[s, r, o]]).astype('int64')
    triple = torch.from_numpy(triple).to(device)
    raw_loss = get_model_loss_without_softmax(triple, model, device).squeeze()
    # edge_losse_log_prob = torch.log(F.softmax(-edge_loss, dim = -1))

    # Standardise against the loss distribution of all existing graph edges.
    z_loss = (raw_loss.item() - data_mean) / data_std
    exist_prob = 1 / (1 + np.exp(z_loss - divide_bound))
    threshold = 1 - args.reasonable_rate

    return (exist_prob > threshold), exist_prob
|
92 |
+
|
93 |
+
# NOTE(review): indentation below was reconstructed from a whitespace-mangled
# source — verify the loop/conditional boundaries against the original file.

# Map every numeric relation id to its GNBR edge-type name and to a flag
# saying whether the stored direction must be reversed when building the graph.
edgeid_to_edgetype = {}
edgeid_to_reversemask = {}
for k, id_list in Parameters.edge_type_to_id.items():
    for iid, mask in zip(id_list, Parameters.reverse_mask[k]):
        edgeid_to_edgetype[str(iid)] = k
        edgeid_to_reversemask[str(iid)] = mask

# Load the attack targets and the adversarial edges proposed for them.
with open(target_path, 'rb') as fl:
    Target_node_list = pkl.load(fl)
with open(attack_path, 'rb') as fl:
    Attack_edge_list = pkl.load(fl)
with open(Parameters.UMLSfile+'drug_term', 'rb') as fl:
    drug_term = pkl.load(fl)
with open(Parameters.GNBRfile+'entity_raw_name', 'rb') as fl:
    entity_raw_name = pkl.load(fl)

# Chemicals whose raw name appears in the UMLS drug-term list.
drug_meshid = []
for meshid, nm in entity_raw_name.items():
    if nm.lower() in drug_term and meshid.split('_')[0] == 'chemical':
        drug_meshid.append(meshid)
drug_meshid = set(drug_meshid)

if args.init_mode == 'single':
    # Case study: dump the human-readable names of the selected targets.
    name_list = []
    for target in Target_node_list:
        name = entity_raw_name[idtomeshid[str(target)]]
        name_list.append(name)
    with open(f'results/name_list_{args.reasonable_rate}{args.init_mode}.txt', 'w') as fl:
        fl.write('\n'.join(name_list))
# print(Target_node_list)
# # print(Attack_edge_list)
# addset = set()
# if args.added_edge_num == 1:
#     for edge in Attack_edge_list:
#         addset.add(edge[2])
# else:
#     for edge_list in Attack_edge_list:
#         for edge in edge_list:
#             addset.add(edge[2])
# print(addset)
# print(len(addset))
# typeset = set()
# for iid in addset:
#     typeset.add(idtomeshid[str(iid)].split('_')[0])
# print(typeset)
# raise Exception('done')

if args.init_mode == 'single':
    # One target / one attack edge per evaluation group.
    Target_node_list = [[Target_node_list[i]] for i in range(len(Target_node_list))]
    Attack_edge_list = [[Attack_edge_list[i]] for i in range(len(Attack_edge_list))]
else:
    print(len(Attack_edge_list), len(Target_node_list))
    # Split targets into groups of 50; each group shares one rebuilt graph.
    tmp_target_node_list = []
    tmp_attack_edge_list = []
    for l in range(0,len(Target_node_list), 50):
        r = min(l+50, len(Target_node_list))
        tmp_target_node_list.append(Target_node_list[l:r])
        tmp_attack_edge_list.append(Attack_edge_list[l:r])
    Target_node_list = tmp_target_node_list
    Attack_edge_list = tmp_attack_edge_list

# for i, init_p in enumerate([0.1, 0.3, 0.5, 0.7, 0.9]):

# target_node_list = Target_node_list[i]
# attack_edge_list = Attack_edge_list[i]
Init = []
After = []
# final_init = []
# final_after = []
for i, (target_node_list, attack_edge_list) in enumerate(zip(Target_node_list, Attack_edge_list)):

    # Rebuild the clean GNBR graph for this group of targets.
    G = nx.DiGraph()
    for s, r, o in graph_edge:
        assert idtomeshid[s].split('_')[0] == edgeid_to_edgetype[r].split('-')[0]
        if edgeid_to_reversemask[r] == 1:
            G.add_edge(int(o), int(s))
        else:
            G.add_edge(int(s), int(o))

    pagerank_value_1 = nx.pagerank(G, max_iter = 200, tol=1.0e-7)

    # Rank of each target among chemical nodes BEFORE the attack
    # (rank 1 = highest PageRank, since pr is sorted ascending).
    for target, attack_list in tqdm(list(zip(target_node_list, attack_edge_list))):
        pr = list(pagerank_value_1.items())
        pr.sort(key = lambda x: x[1])
        list_iid = []
        for iid, score in pr:
            tp = idtomeshid[str(iid)].split('_')[0]
            if tp == 'chemical':
                # if idtomeshid[str(iid)] in drug_meshid:
                list_iid.append(iid)
        init_rank = len(list_iid) - list_iid.index(target)
        # init_rank = 1 - list_iid.index(target) / len(list_iid)
        Init.append(init_rank)

    # Inject every adversarial edge that passes the plausibility defender.
    for target, attack_list in tqdm(list(zip(target_node_list, attack_edge_list))):

        if args.mode == '' and (args.added_edge_num == '' or int(args.added_edge_num) == 1):
            # Single-edge format: a leading -1 marks "no attack found".
            if int(attack_list[0]) == -1:
                attack_list = []
            else:
                attack_list = [attack_list]
        if len(attack_list) > 0:
            for s, r, o in attack_list:
                bo, prob = check_reasonable(s, r, o)
                if bo:
                    if edgeid_to_reversemask[str(r)] == 1:
                        G.add_edge(int(o), int(s))
                    else:
                        G.add_edge(int(s), int(o))
    pagerank_value_1 = nx.pagerank(G, max_iter = 200, tol=1.0e-7)
    # Rank of each target among chemical nodes AFTER the attack.
    for target, attack_list in tqdm(list(zip(target_node_list, attack_edge_list))):
        pr = list(pagerank_value_1.items())
        pr.sort(key = lambda x: x[1])
        list_iid = []
        for iid, score in pr:
            tp = idtomeshid[str(iid)].split('_')[0]
            if tp == 'chemical':
                # if idtomeshid[str(iid)] in drug_meshid:
                list_iid.append(iid)
        after_rank = len(list_iid) - list_iid.index(target)
        # after_rank = 1 - list_iid.index(target) / len(list_iid)
        After.append(after_rank)

# Persist the before/after ranks and report their summary statistics.
with open(f'results/Init_{args.reasonable_rate}{args.init_mode}.pkl', 'wb') as fl:
    pkl.dump(Init, fl)
with open(f'results/After_{args.model}_{args.reasonable_rate}{args.init_mode}{args.added_edge_num}{args.mode}.pkl', 'wb') as fl:
    pkl.dump(After, fl)
print(np.mean(Init), np.std(Init))
print(np.mean(After), np.std(After))
|
DiseaseAgnostic/generate_target_and_attack.py
ADDED
@@ -0,0 +1,371 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#%%
|
2 |
+
import logging
|
3 |
+
from symbol import parameters
|
4 |
+
from textwrap import indent
|
5 |
+
import os
|
6 |
+
import tempfile
|
7 |
+
import sys
|
8 |
+
from matplotlib import collections
|
9 |
+
import pandas as pd
|
10 |
+
import json
|
11 |
+
from glob import glob
|
12 |
+
from tqdm import tqdm
|
13 |
+
import numpy as np
|
14 |
+
from pprint import pprint
|
15 |
+
import torch
|
16 |
+
import pickle as pkl
|
17 |
+
from collections import Counter
|
18 |
+
# print(dir(collections))
|
19 |
+
import networkx as nx
|
20 |
+
from collections import Counter
|
21 |
+
import utils
|
22 |
+
from torch.nn import functional as F
|
23 |
+
sys.path.append("..")
|
24 |
+
import Parameters
|
25 |
+
from DiseaseSpecific.attack import calculate_edge_bound, get_model_loss_without_softmax
|
26 |
+
|
27 |
+
#%%
|
28 |
+
def load_data(file_name):
    """Load a tab-separated file as string-typed rows with duplicates removed."""
    rows = pd.read_csv(file_name, sep='\t', header=None, names=None, dtype=str)
    deduped = rows.drop_duplicates()
    return deduped.values
|
32 |
+
|
33 |
+
# ----- Command-line configuration -----------------------------------------
parser = utils.get_argument_parser()
parser.add_argument('--reasonable-rate', type = float, default=0.7, help = 'The added edge\'s existance rank prob greater than this rate')
parser.add_argument('--init-mode', type = str, default='single', help = 'How to select target nodes') # 'single' for case study
parser.add_argument('--added-edge-num', type = str, default = '', help = 'Added edge num')

args = parser.parse_args()
args = utils.set_hyperparams(args)
utils.seed_all(args.seed)

# Inputs: GNBR edge list, id -> MeSH-id mapping and the trained KGE model.
graph_edge_path = '../DiseaseSpecific/processed_data/GNBR/all.txt'
idtomeshid_path = '../DiseaseSpecific/processed_data/GNBR/entities_reverse_dict.json'
model_path = f'../DiseaseSpecific/saved_models/GNBR_{args.model}_128_0.2_0.3_0.3.model'
data_path = '../DiseaseSpecific/processed_data/GNBR'
with open(Parameters.GNBRfile+'original_entity_raw_name', 'rb') as fl:
    full_entity_raw_name = pkl.load(fl)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
args.device = device
n_ent, n_rel, ent_to_id, rel_to_id = utils.generate_dicts(data_path)
model = utils.load_model(model_path, args, n_ent, n_rel, args.device)
print(device)

graph_edge = utils.load_data(graph_edge_path)
with open(idtomeshid_path, 'r') as fl:
    idtomeshid = json.load(fl)
print(graph_edge.shape, len(idtomeshid))

# Edge-loss statistics over the existing graph; consumed by
# check_reasonable() below as the plausibility defender.
divide_bound, data_mean, data_std = calculate_edge_bound(graph_edge, model, args.device, n_ent)
print('Defender ...')
print(divide_bound, data_mean, data_std)

# Tally entities per GNBR type (chemical / disease / gene) — sanity check only.
meshids = list(idtomeshid.values())
cal = {
    'chemical' : 0,
    'disease' : 0,
    'gene' : 0
}
for meshid in meshids:
    cal[meshid.split('_')[0]] += 1
# pprint(cal)
|
72 |
+
|
73 |
+
def check_reasonable(s, r, o):
    """Decide whether candidate triple (s, r, o) is plausible enough to evade
    the defender.

    The triple's model loss is standardised with the precomputed
    ``data_mean``/``data_std`` and squashed through a sigmoid centred at
    ``divide_bound``; the triple passes when that probability exceeds
    ``1 - args.reasonable_rate``.

    Returns
    -------
    (bool-like, float): (passes_threshold, plausibility_probability)
    """
    triple = torch.from_numpy(np.asarray([[s, r, o]]).astype('int64')).to(device)
    raw_loss = get_model_loss_without_softmax(triple, model, device).squeeze()
    # edge_losse_log_prob = torch.log(F.softmax(-edge_loss, dim = -1))

    standardized = (raw_loss.item() - data_mean) / data_std
    # Sigmoid of the negated, shifted loss: lower loss -> higher plausibility.
    plausibility = 1 / (1 + np.exp(standardized - divide_bound))
    threshold = 1 - args.reasonable_rate

    return (plausibility > threshold), plausibility
|
86 |
+
|
87 |
+
# ---- Relation-id bookkeeping and directed-graph construction ----
# edgeid_to_edgetype: relation id (as str) -> 'srcType-dstType' label.
# edgeid_to_reversemask: 1 means the stored triple is semantically reversed,
# so the graph edge is added o -> s instead of s -> o.
edgeid_to_edgetype = {}
edgeid_to_reversemask = {}
for k, id_list in Parameters.edge_type_to_id.items():
    for iid, mask in zip(id_list, Parameters.reverse_mask[k]):
        edgeid_to_edgetype[str(iid)] = k
        edgeid_to_reversemask[str(iid)] = mask
reverse_tot = 0
G = nx.DiGraph()
for s, r, o in graph_edge:
    # Sanity check: the subject's entity type must match the relation's source type.
    assert idtomeshid[s].split('_')[0] == edgeid_to_edgetype[r].split('-')[0]
    if edgeid_to_reversemask[r] == 1:
        reverse_tot += 1
        G.add_edge(int(o), int(s))
    else:
        G.add_edge(int(s), int(o))
# print(reverse_tot)
print('Edge num:', G.number_of_edges(), 'Node num:', G.number_of_nodes())
# PageRank over the directed KG; used below to rank candidate/target nodes.
pagerank_value_1 = nx.pagerank(G, max_iter = 200, tol=1.0e-7)
|
105 |
+
|
106 |
+
#%%
|
107 |
+
with open(Parameters.UMLSfile+'drug_term', 'rb') as fl:
    drug_term = pkl.load(fl)
with open(Parameters.GNBRfile+'entity_raw_name', 'rb') as fl:
    entity_raw_name = pkl.load(fl)
# Keep only chemicals whose raw name is a known UMLS drug term.
drug_meshid = []
for meshid, nm in entity_raw_name.items():
    if nm.lower() in drug_term and meshid.split('_')[0] == 'chemical':
        drug_meshid.append(meshid)
drug_meshid = set(drug_meshid)
# Sort all nodes by ascending PageRank score.
pr = list(pagerank_value_1.items())
pr.sort(key = lambda x: x[1])
sorted_rank = { 'chemical' : [],
                'gene' : [],
                'disease': [],
                'merged' : []}
for iid, score in pr:
    tp = idtomeshid[str(iid)].split('_')[0]
    if tp == 'chemical':
        # Only drug chemicals are eligible targets; non-drug chemicals are
        # excluded from the 'chemical' list but still enter 'merged'.
        if idtomeshid[str(iid)] in drug_meshid:
            sorted_rank[tp].append((iid, score))
    else:
        sorted_rank[tp].append((iid, score))
    sorted_rank['merged'].append((iid, score))
# Keep only the top quarter of nodes by PageRank as attack-edge sources.
llen = len(sorted_rank['merged'])
sorted_rank['merged'] = sorted_rank['merged'][llen * 3 // 4 : ]
print(len(sorted_rank['chemical']))
print(len(sorted_rank['gene']), len(sorted_rank['disease']), len(sorted_rank['merged']))
|
134 |
+
|
135 |
+
#%%
|
136 |
+
# ---------------------------------------------------------------------------
# Target selection and malicious-edge search.
# For each chosen target drug, scan every high-PageRank candidate node and
# every type-compatible relation, pick the lowest-loss relation per candidate,
# keep only candidates that pass the defender check, and finally select the
# edge(s) maximising (PageRank-derived score) * softmax(-loss).
# ---------------------------------------------------------------------------
Target_node_list = []
Attack_edge_list = []
if args.init_mode == '':

    if args.added_edge_num != '' and args.added_edge_num != '1':
        raise Exception('added_edge_num must be 1 when init_mode=='' ')
    # Sweep five percentile positions of the drug ranking; probe 20 targets
    # around each position.
    for init_p in [0.1, 0.3, 0.5, 0.7, 0.9]:

        p = len(sorted_rank['chemical']) * init_p
        print('Init p:', init_p)
        target_node_list = []
        attack_edge_list = []
        num_max_eq = 0
        mean_rank_of_total_max = 0
        for pp in tqdm(range(int(p)-10, int(p)+10)):
            target = sorted_rank['chemical'][pp][0]
            target_node_list.append(target)

            candidate_list = []
            score_list = []
            loss_list = []
            for iid, score in sorted_rank['merged']:
                # Skip candidates already linked to the target.
                a = G.number_of_edges(iid, target) + 1
                if a != 1:
                    continue
                b = G.out_degree(iid) + 1
                tp = idtomeshid[str(iid)].split('_')[0]
                edge_losses = []
                r_list = []
                # Collect losses for every relation whose type signature can
                # connect this candidate to a chemical (in either direction).
                for r in range(len(edgeid_to_edgetype)):
                    r_tp = edgeid_to_edgetype[str(r)]
                    if (edgeid_to_reversemask[str(r)] == 0 and r_tp.split('-')[0] == tp and r_tp.split('-')[1] == 'chemical'):
                        train_trip = np.array([[iid, r, target]])
                        train_trip = torch.from_numpy(train_trip.astype('int64')).to(device)
                        edge_loss = get_model_loss_without_softmax(train_trip, model, device).squeeze()
                        edge_losses.append(edge_loss.unsqueeze(0).detach())
                        r_list.append(r)
                    elif(edgeid_to_reversemask[str(r)] == 1 and r_tp.split('-')[0] == 'chemical' and r_tp.split('-')[1] == tp):
                        train_trip = np.array([[iid, r, target]]) # add batch dim
                        train_trip = torch.from_numpy(train_trip.astype('int64')).to(device)
                        edge_loss = get_model_loss_without_softmax(train_trip, model, device).squeeze()
                        edge_losses.append(edge_loss.unsqueeze(0).detach())
                        r_list.append(r)
                if len(edge_losses)==0:
                    continue
                # Keep the single most model-plausible relation for this candidate.
                min_index = torch.argmin(torch.cat(edge_losses, dim = 0))
                r = r_list[min_index]
                r_tp = edgeid_to_edgetype[str(r)]

                if (edgeid_to_reversemask[str(r)] == 0):
                    bo, prob = check_reasonable(iid, r, target)
                    if bo:
                        candidate_list.append((iid, r, target))
                        score_list.append(score * a / b)
                        loss_list.append(edge_losses[min_index].item())
                if (edgeid_to_reversemask[str(r)] == 1):
                    bo, prob = check_reasonable(target, r, iid)
                    if bo:
                        candidate_list.append((target, r, iid))
                        score_list.append(score * a / b)
                        loss_list.append(edge_losses[min_index].item())

            if len(candidate_list) == 0:
                # Sentinel: no viable edge for this target.
                attack_edge_list.append((-1, -1, -1))
                continue
            norm_score = np.array(score_list) / np.sum(score_list)
            norm_loss = np.exp(-np.array(loss_list)) / np.sum(np.exp(-np.array(loss_list)))

            total_score = norm_score * norm_loss
            max_index = np.argmax(total_score)
            attack_edge_list.append(candidate_list[max_index])

            # Diagnostics: how often the combined pick equals the pure
            # PageRank-score pick, and its mean rank in the score ordering.
            score_max_index = np.argmax(norm_score)
            if score_max_index == max_index:
                num_max_eq += 1

            score_index_list = list(zip(list(range(len(norm_score))), norm_score))
            score_index_list.sort(key = lambda x: x[1], reverse = True)
            max_index_in_score = score_index_list.index((max_index, norm_score[max_index]))
            mean_rank_of_total_max += max_index_in_score / len(norm_score)
        print('num_max_eq:', num_max_eq)
        print('mean_rank_of_total_max:', mean_rank_of_total_max / 20)
        Target_node_list.append(target_node_list)
        Attack_edge_list.append(attack_edge_list)
else:
    assert args.init_mode == 'random' or args.init_mode == 'single'
    print(f'Init mode : {args.init_mode}')
    utils.seed_all(args.seed)

    if args.init_mode == 'random':
        index = np.random.choice(len(sorted_rank['chemical']), 400, replace = False)
    else:
        # index = [5807, 6314, 5799, 5831, 3954, 5654, 5649, 5624, 2412, 2407]

        # 'single' (case-study) mode: re-draw the same 400 random positions,
        # then keep the 10 whose PageRank dropped most in a previous attack run.
        index = np.random.choice(len(sorted_rank['chemical']), 400, replace = False)
        with open(f'../pagerank/results/After_distmult_0.7random10.pkl', 'rb') as fl:
            edge = pkl.load(fl)
        with open('../pagerank/results/Init_0.7random.pkl', 'rb') as fl:
            init = pkl.load(fl)
        increase = (np.array(init) - np.array(edge)) / np.array(init)
        increase = increase.reshape(-1)
        selected_index = np.argsort(increase)[::-1][:10]
        # print(selected_index)
        # print(increase[selected_index])
        # print(np.array(init)[selected_index])
        # print(np.array(edge)[selected_index])
        index = [index[i] for i in selected_index]
        # llen = len(sorted_rank['chemical'])
        # index = np.random.choice(range(llen//4, llen), 4, replace = False)
        # index = selected_index + list(index)
        # for i in index:
        #     ii = str(sorted_rank['chemical'][i][0])
        #     nm = entity_raw_name[idtomeshid[ii]]
        #     nmset = full_entity_raw_name[idtomeshid[ii]]
        #     print('**'*10)
        #     print(i)
        #     print(nm)
        #     print(nmset)
        # raise Exception('stop')
    target_node_list = []
    attack_edge_list = []
    num_max_eq = 0
    mean_rank_of_total_max = 0

    for pp in tqdm(index):
        target = sorted_rank['chemical'][pp][0]
        target_node_list.append(target)

        print('Target:', entity_raw_name[idtomeshid[str(target)]])

        candidate_list = []
        score_list = []
        loss_list = []
        main_dict = {}
        for iid, score in sorted_rank['merged']:
            a = G.number_of_edges(iid, target) + 1
            if a != 1:
                continue
            b = G.out_degree(iid) + 1
            tp = idtomeshid[str(iid)].split('_')[0]
            edge_losses = []
            r_list = []
            for r in range(len(edgeid_to_edgetype)):
                r_tp = edgeid_to_edgetype[str(r)]
                if (edgeid_to_reversemask[str(r)] == 0 and r_tp.split('-')[0] == tp and r_tp.split('-')[1] == 'chemical'):
                    train_trip = np.array([[iid, r, target]])
                    train_trip = torch.from_numpy(train_trip.astype('int64')).to(device)
                    edge_loss = get_model_loss_without_softmax(train_trip, model, device).squeeze()
                    edge_losses.append(edge_loss.unsqueeze(0).detach())
                    r_list.append(r)
                elif(edgeid_to_reversemask[str(r)] == 1 and r_tp.split('-')[0] == 'chemical' and r_tp.split('-')[1] == tp):
                    train_trip = np.array([[iid, r, target]]) # add batch dim
                    train_trip = torch.from_numpy(train_trip.astype('int64')).to(device)
                    edge_loss = get_model_loss_without_softmax(train_trip, model, device).squeeze()
                    edge_losses.append(edge_loss.unsqueeze(0).detach())
                    r_list.append(r)
            if len(edge_losses)==0:
                continue
            min_index = torch.argmin(torch.cat(edge_losses, dim = 0))
            r = r_list[min_index]
            r_tp = edgeid_to_edgetype[str(r)]


            old_len = len(candidate_list)
            if (edgeid_to_reversemask[str(r)] == 0):
                bo, prob = check_reasonable(iid, r, target)
                if bo:
                    candidate_list.append((iid, r, target))
                    score_list.append(score * a / b)
                    loss_list.append(edge_losses[min_index].item())
            if (edgeid_to_reversemask[str(r)] == 1):
                bo, prob = check_reasonable(target, r, iid)
                if bo:
                    candidate_list.append((target, r, iid))
                    score_list.append(score * a / b)
                    loss_list.append(edge_losses[min_index].item())

            if len(candidate_list) != old_len:
                # NOTE(review): 'main_iid' is not defined anywhere in this
                # file, so this branch raises NameError when reached unless it
                # is injected from elsewhere (e.g. a notebook cell) — confirm.
                if int(iid) in main_iid:
                    main_dict[iid] = len(candidate_list) - 1

        if len(candidate_list) == 0:
            if args.added_edge_num == '' or int(args.added_edge_num) == 1:
                attack_edge_list.append((-1,-1,-1))
            else:
                attack_edge_list.append([])
            continue
        norm_score = np.array(score_list) / np.sum(score_list)
        norm_loss = np.exp(-np.array(loss_list)) / np.sum(np.exp(-np.array(loss_list)))

        total_score = norm_score * norm_loss
        total_score_index = list(zip(range(len(total_score)), total_score))
        total_score_index.sort(key = lambda x: x[1], reverse = True)

        norm_score_index = np.argsort(norm_score)[::-1]
        norm_loss_index = np.argsort(norm_loss)[::-1]
        total_index = np.argsort(total_score)[::-1]
        assert total_index[0] == total_score_index[0][0]
        # find rank of main index
        for k, v in main_dict.items():
            k = int(k)
            # NOTE(review): this rebinds the loop/index variable 'index' used
            # by the enclosing target loop's source list; harmless only
            # because 'index' is not read again afterwards — confirm.
            index = v
            print(f'score rank of {entity_raw_name[idtomeshid[str(k)]]}: ', norm_score_index.tolist().index(index))
            print(f'loss rank of {entity_raw_name[idtomeshid[str(k)]]}: ', norm_loss_index.tolist().index(index))
            print(f'total rank of {entity_raw_name[idtomeshid[str(k)]]}: ', total_index.tolist().index(index))

        max_index = np.argmax(total_score)
        assert max_index == total_score_index[0][0]

        # Emit either the single best edge or the top-k edges by total score.
        tmp_add = []
        add_num = 1
        if args.added_edge_num == '' or int(args.added_edge_num) == 1:
            attack_edge_list.append(candidate_list[max_index])
        else:
            add_num = int(args.added_edge_num)
            for i in range(add_num):
                tmp_add.append(candidate_list[total_score_index[i][0]])
            attack_edge_list.append(tmp_add)

        score_max_index = np.argmax(norm_score)
        if score_max_index == max_index:
            num_max_eq += 1
        score_index_list = list(zip(list(range(len(norm_score))), norm_score))
        score_index_list.sort(key = lambda x: x[1], reverse = True)
        max_index_in_score = score_index_list.index((max_index, norm_score[max_index]))
        mean_rank_of_total_max += max_index_in_score / len(norm_score)
    print('num_max_eq:', num_max_eq)
    # NOTE(review): the divisor 400 matches 'random' mode's sample size but
    # not 'single' mode's 10 targets — the printed mean is scaled down there.
    print('mean_rank_of_total_max:', mean_rank_of_total_max / 400)
    Target_node_list = target_node_list
    Attack_edge_list = attack_edge_list
print(np.array(Target_node_list).shape)
print(np.array(Attack_edge_list).shape)
# with open(f'processed_data/target_{args.reasonable_rate}{args.init_mode}.pkl', 'wb') as fl:
#     pkl.dump(Target_node_list, fl)
# with open(f'processed_data/attack_edge_{args.model}_{args.reasonable_rate}{args.init_mode}{args.added_edge_num}.pkl', 'wb') as fl:
#     pkl.dump(Attack_edge_list, fl)
|
DiseaseAgnostic/model.py
ADDED
@@ -0,0 +1,520 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch.nn import functional as F, Parameter
|
3 |
+
from torch.autograd import Variable
|
4 |
+
from torch.nn.init import xavier_normal_, xavier_uniform_
|
5 |
+
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
|
6 |
+
|
7 |
+
class Distmult(torch.nn.Module):
    """DistMult knowledge-graph embedding scorer.

    Entities and relations each get one real-valued embedding of size
    ``args.embedding_dim``; a triple (s, r, o) is scored by the trilinear
    dot product ``sum_k s_k * r_k * o_k``.
    """

    def __init__(self, args, num_entities, num_relations):
        super(Distmult, self).__init__()

        dim = args.embedding_dim
        if args.max_norm:
            # Constrain entity embeddings to unit max-norm.
            self.emb_e = torch.nn.Embedding(num_entities, dim, max_norm=1.0)
            self.emb_rel = torch.nn.Embedding(num_relations, dim)
        else:
            self.emb_e = torch.nn.Embedding(num_entities, dim, padding_idx=None)
            self.emb_rel = torch.nn.Embedding(num_relations, dim, padding_idx=None)

        self.inp_drop = torch.nn.Dropout(args.input_drop)
        self.loss = torch.nn.CrossEntropyLoss()

        self.init()

    def init(self):
        """(Re)initialise both embedding tables with Xavier-normal weights."""
        xavier_normal_(self.emb_e.weight)
        xavier_normal_(self.emb_rel.weight)

    def score_sr(self, sub, rel, sigmoid = False):
        """Score (sub, rel, ?) against every entity. No dropout applied."""
        s = self.emb_e(sub).squeeze(dim=1)
        r = self.emb_rel(rel).squeeze(dim=1)
        pred = (s * r).mm(self.emb_e.weight.t())
        return torch.sigmoid(pred) if sigmoid else pred

    def score_or(self, obj, rel, sigmoid = False):
        """Score (?, rel, obj) against every entity; same maths as score_sr
        because DistMult is symmetric in subject and object."""
        o = self.emb_e(obj).squeeze(dim=1)
        r = self.emb_rel(rel).squeeze(dim=1)
        pred = (o * r).mm(self.emb_e.weight.t())
        return torch.sigmoid(pred) if sigmoid else pred


    def forward(self, sub_emb, rel_emb, mode='rhs', sigmoid=False):
        '''
        When mode is 'rhs' we expect (s,r); for 'lhs', we expect (o,r)
        For distmult, computations for both modes are equivalent, so we do not need if-else block
        '''
        # Unlike the score_* helpers, input dropout IS applied here.
        lhs = self.inp_drop(sub_emb)
        rhs = self.inp_drop(rel_emb)

        pred = (lhs * rhs).mm(self.emb_e.weight.t())

        return torch.sigmoid(pred) if sigmoid else pred

    def score_triples(self, sub, rel, obj, sigmoid=False):
        '''
        Inputs - subject, relation, object
        Return - score
        '''
        s = self.emb_e(sub).squeeze(dim=1)
        r = self.emb_rel(rel).squeeze(dim=1)
        o = self.emb_e(obj).squeeze(dim=1)

        pred = (s * r * o).sum(dim=-1)

        return torch.sigmoid(pred) if sigmoid else pred

    def score_emb(self, emb_s, emb_r, emb_o, sigmoid=False):
        '''
        Inputs - embeddings of subject, relation, object
        Return - score
        '''
        pred = (emb_s * emb_r * emb_o).sum(dim=-1)

        return torch.sigmoid(pred) if sigmoid else pred

    def score_triples_vec(self, sub, rel, obj, sigmoid=False):
        '''
        Inputs - subject, relation, object
        Return - a vector score for the triple instead of reducing over the embedding dimension
        '''
        s = self.emb_e(sub).squeeze(dim=1)
        r = self.emb_rel(rel).squeeze(dim=1)
        o = self.emb_e(obj).squeeze(dim=1)

        pred = s * r * o

        return torch.sigmoid(pred) if sigmoid else pred
|
110 |
+
|
111 |
+
class Complex(torch.nn.Module):
    """ComplEx knowledge-graph embedding scorer.

    Each embedding row stores a complex vector as [real half | imaginary
    half], hence the tables are 2 * args.embedding_dim wide. A triple is
    scored as Re(<s, r, conj(o)>).
    """
    def __init__(self, args, num_entities, num_relations):
        super(Complex, self).__init__()

        if args.max_norm:
            self.emb_e = torch.nn.Embedding(num_entities, 2*args.embedding_dim, max_norm=1.0)
            self.emb_rel = torch.nn.Embedding(num_relations, 2*args.embedding_dim)
        else:
            self.emb_e = torch.nn.Embedding(num_entities, 2*args.embedding_dim, padding_idx=None)
            self.emb_rel = torch.nn.Embedding(num_relations, 2*args.embedding_dim, padding_idx=None)

        self.inp_drop = torch.nn.Dropout(args.input_drop)
        self.loss = torch.nn.CrossEntropyLoss()

        self.init()

    def init(self):
        # Xavier-normal init for both tables (real and imaginary halves together).
        xavier_normal_(self.emb_e.weight)
        xavier_normal_(self.emb_rel.weight)

    def score_sr(self, sub, rel, sigmoid = False):
        """Score (sub, rel, ?) against all entities. No dropout applied."""
        sub_emb = self.emb_e(sub).squeeze(dim=1)
        rel_emb = self.emb_rel(rel).squeeze(dim=1)

        # NOTE(review): s_* is chunked from rel_emb and rel_* from sub_emb —
        # the names look swapped, but the product below is symmetric in the
        # two factors, so the score is unaffected; confirm this was intended.
        s_real, s_img = torch.chunk(rel_emb, 2, dim=-1)
        rel_real, rel_img = torch.chunk(sub_emb, 2, dim=-1)
        emb_e_real, emb_e_img = torch.chunk(self.emb_e.weight, 2, dim=-1)

        #s_real = self.inp_drop(s_real)
        #s_img = self.inp_drop(s_img)
        #rel_real = self.inp_drop(rel_real)
        #rel_img = self.inp_drop(rel_img)

        # complex space bilinear product (equivalent to HolE)
        # realrealreal = torch.mm(s_real*rel_real, emb_e_real.transpose(1,0))
        # realimgimg = torch.mm(s_real*rel_img, emb_e_img.transpose(1,0))
        # imgrealimg = torch.mm(s_img*rel_real, emb_e_img.transpose(1,0))
        # imgimgreal = torch.mm(s_img*rel_img, emb_e_real.transpose(1,0))
        # pred = realrealreal + realimgimg + imgrealimg - imgimgreal

        # Real part of s*r, matched against entity real halves.
        realo_realreal = s_real*rel_real
        realo_imgimg = s_img*rel_img
        realo = realo_realreal - realo_imgimg
        real = torch.mm(realo, emb_e_real.transpose(1,0))

        # Imaginary part of s*r, matched against entity imaginary halves.
        imgo_realimg = s_real*rel_img
        imgo_imgreal = s_img*rel_real
        imgo = imgo_realimg + imgo_imgreal
        img = torch.mm(imgo, emb_e_img.transpose(1,0))

        pred = real + img

        if sigmoid:
            pred = torch.sigmoid(pred)
        return pred


    def score_or(self, obj, rel, sigmoid = False):
        """Score (?, rel, obj) against all entities. No dropout applied."""
        obj_emb = self.emb_e(obj).squeeze(dim=1)
        rel_emb = self.emb_rel(rel).squeeze(dim=1)

        rel_real, rel_img = torch.chunk(rel_emb, 2, dim=-1)
        o_real, o_img = torch.chunk(obj_emb, 2, dim=-1)
        emb_e_real, emb_e_img = torch.chunk(self.emb_e.weight, 2, dim=-1)

        #rel_real = self.inp_drop(rel_real)
        #rel_img = self.inp_drop(rel_img)
        #o_real = self.inp_drop(o_real)
        #o_img = self.inp_drop(o_img)

        # complex space bilinear product (equivalent to HolE)
        # realrealreal = torch.mm(rel_real*o_real, emb_e_real.transpose(1,0))
        # realimgimg = torch.mm(rel_img*o_img, emb_e_real.transpose(1,0))
        # imgrealimg = torch.mm(rel_real*o_img, emb_e_img.transpose(1,0))
        # imgimgreal = torch.mm(rel_img*o_real, emb_e_img.transpose(1,0))
        # pred = realrealreal + realimgimg + imgrealimg - imgimgreal

        # Note the sign flips relative to score_sr: the candidate entity sits
        # in the conjugated (object) slot here.
        reals_realreal = rel_real*o_real
        reals_imgimg = rel_img*o_img
        reals = reals_realreal + reals_imgimg
        real = torch.mm(reals, emb_e_real.transpose(1,0))

        imgs_realimg = rel_real*o_img
        imgs_imgreal = rel_img*o_real
        imgs = imgs_realimg - imgs_imgreal
        img = torch.mm(imgs, emb_e_img.transpose(1,0))

        pred = real + img

        if sigmoid:
            pred = torch.sigmoid(pred)
        return pred


    def forward(self, sub_emb, rel_emb, mode='rhs', sigmoid=False):
        '''
        When mode is 'rhs' we expect (s,r); for 'lhs', we expect (o,r)

        '''
        # Dropout IS applied in forward (training path), unlike score_sr/score_or.
        if mode == 'lhs':
            rel_real, rel_img = torch.chunk(rel_emb, 2, dim=-1)
            o_real, o_img = torch.chunk(sub_emb, 2, dim=-1)
            emb_e_real, emb_e_img = torch.chunk(self.emb_e.weight, 2, dim=-1)

            rel_real = self.inp_drop(rel_real)
            rel_img = self.inp_drop(rel_img)
            o_real = self.inp_drop(o_real)
            o_img = self.inp_drop(o_img)

            # complex space bilinear product (equivalent to HolE)
            # realrealreal = torch.mm(rel_real*o_real, emb_e_real.transpose(1,0))
            # realimgimg = torch.mm(rel_img*o_img, emb_e_real.transpose(1,0))
            # imgrealimg = torch.mm(rel_real*o_img, emb_e_img.transpose(1,0))
            # imgimgreal = torch.mm(rel_img*o_real, emb_e_img.transpose(1,0))
            # pred = realrealreal + realimgimg + imgrealimg - imgimgreal
            reals_realreal = rel_real*o_real
            reals_imgimg = rel_img*o_img
            reals = reals_realreal + reals_imgimg
            real = torch.mm(reals, emb_e_real.transpose(1,0))

            imgs_realimg = rel_real*o_img
            imgs_imgreal = rel_img*o_real
            imgs = imgs_realimg - imgs_imgreal
            img = torch.mm(imgs, emb_e_img.transpose(1,0))

            pred = real + img

        else:
            # NOTE(review): as in score_sr, s_* / rel_* names are chunked from
            # the opposite embeddings; harmless by symmetry — confirm intent.
            s_real, s_img = torch.chunk(rel_emb, 2, dim=-1)
            rel_real, rel_img = torch.chunk(sub_emb, 2, dim=-1)
            emb_e_real, emb_e_img = torch.chunk(self.emb_e.weight, 2, dim=-1)

            s_real = self.inp_drop(s_real)
            s_img = self.inp_drop(s_img)
            rel_real = self.inp_drop(rel_real)
            rel_img = self.inp_drop(rel_img)

            # complex space bilinear product (equivalent to HolE)
            # realrealreal = torch.mm(s_real*rel_real, emb_e_real.transpose(1,0))
            # realimgimg = torch.mm(s_real*rel_img, emb_e_img.transpose(1,0))
            # imgrealimg = torch.mm(s_img*rel_real, emb_e_img.transpose(1,0))
            # imgimgreal = torch.mm(s_img*rel_img, emb_e_real.transpose(1,0))
            # pred = realrealreal + realimgimg + imgrealimg - imgimgreal

            realo_realreal = s_real*rel_real
            realo_imgimg = s_img*rel_img
            realo = realo_realreal - realo_imgimg
            real = torch.mm(realo, emb_e_real.transpose(1,0))

            imgo_realimg = s_real*rel_img
            imgo_imgreal = s_img*rel_real
            imgo = imgo_realimg + imgo_imgreal
            img = torch.mm(imgo, emb_e_img.transpose(1,0))

            pred = real + img

        if sigmoid:
            pred = torch.sigmoid(pred)

        return pred

    def score_triples(self, sub, rel, obj, sigmoid=False):
        '''
        Inputs - subject, relation, object
        Return - score
        '''
        sub_emb = self.emb_e(sub).squeeze(dim=1)
        rel_emb = self.emb_rel(rel).squeeze(dim=1)
        obj_emb = self.emb_e(obj).squeeze(dim=1)

        s_real, s_img = torch.chunk(sub_emb, 2, dim=-1)
        rel_real, rel_img = torch.chunk(rel_emb, 2, dim=-1)
        o_real, o_img = torch.chunk(obj_emb, 2, dim=-1)

        # Standard ComplEx decomposition: Re(<s, r, conj(o)>).
        realrealreal = torch.sum(s_real*rel_real*o_real, dim=-1)
        realimgimg = torch.sum(s_real*rel_img*o_img, axis=-1)
        imgrealimg = torch.sum(s_img*rel_real*o_img, axis=-1)
        imgimgreal = torch.sum(s_img*rel_img*o_real, axis=-1)

        pred = realrealreal + realimgimg + imgrealimg - imgimgreal

        if sigmoid:
            pred = torch.sigmoid(pred)

        return pred

    def score_emb(self, emb_s, emb_r, emb_o, sigmoid=False):
        '''
        Inputs - embeddings of subject, relation, object
        Return - score
        '''

        s_real, s_img = torch.chunk(emb_s, 2, dim=-1)
        rel_real, rel_img = torch.chunk(emb_r, 2, dim=-1)
        o_real, o_img = torch.chunk(emb_o, 2, dim=-1)

        realrealreal = torch.sum(s_real*rel_real*o_real, dim=-1)
        realimgimg = torch.sum(s_real*rel_img*o_img, axis=-1)
        imgrealimg = torch.sum(s_img*rel_real*o_img, axis=-1)
        imgimgreal = torch.sum(s_img*rel_img*o_real, axis=-1)

        pred = realrealreal + realimgimg + imgrealimg - imgimgreal

        if sigmoid:
            pred = torch.sigmoid(pred)

        return pred

    def score_triples_vec(self, sub, rel, obj, sigmoid=False):
        '''
        Inputs - subject, relation, object
        Return - a vector score for the triple instead of reducing over the embedding dimension
        '''
        sub_emb = self.emb_e(sub).squeeze(dim=1)
        rel_emb = self.emb_rel(rel).squeeze(dim=1)
        obj_emb = self.emb_e(obj).squeeze(dim=1)

        s_real, s_img = torch.chunk(sub_emb, 2, dim=-1)
        rel_real, rel_img = torch.chunk(rel_emb, 2, dim=-1)
        o_real, o_img = torch.chunk(obj_emb, 2, dim=-1)

        realrealreal = s_real*rel_real*o_real
        realimgimg = s_real*rel_img*o_img
        imgrealimg = s_img*rel_real*o_img
        imgimgreal = s_img*rel_img*o_real

        pred = realrealreal + realimgimg + imgrealimg - imgimgreal

        if sigmoid:
            pred = torch.sigmoid(pred)

        return pred
|
343 |
+
|
344 |
+
class Conve(torch.nn.Module):
|
345 |
+
|
346 |
+
#Too slow !!!!
|
347 |
+
|
348 |
+
def __init__(self, args, num_entities, num_relations):
|
349 |
+
super(Conve, self).__init__()
|
350 |
+
|
351 |
+
if args.max_norm:
|
352 |
+
self.emb_e = torch.nn.Embedding(num_entities, args.embedding_dim, max_norm=1.0)
|
353 |
+
self.emb_rel = torch.nn.Embedding(num_relations, args.embedding_dim)
|
354 |
+
else:
|
355 |
+
self.emb_e = torch.nn.Embedding(num_entities, args.embedding_dim, padding_idx=None)
|
356 |
+
self.emb_rel = torch.nn.Embedding(num_relations, args.embedding_dim, padding_idx=None)
|
357 |
+
|
358 |
+
self.inp_drop = torch.nn.Dropout(args.input_drop)
|
359 |
+
self.hidden_drop = torch.nn.Dropout(args.hidden_drop)
|
360 |
+
self.feature_drop = torch.nn.Dropout2d(args.feat_drop)
|
361 |
+
|
362 |
+
self.embedding_dim = args.embedding_dim #default is 200
|
363 |
+
self.num_filters = args.num_filters # default is 32
|
364 |
+
self.kernel_size = args.kernel_size # default is 3
|
365 |
+
self.stack_width = args.stack_width # default is 20
|
366 |
+
self.stack_height = args.embedding_dim // self.stack_width
|
367 |
+
|
368 |
+
self.bn0 = torch.nn.BatchNorm2d(1)
|
369 |
+
self.bn1 = torch.nn.BatchNorm2d(self.num_filters)
|
370 |
+
self.bn2 = torch.nn.BatchNorm1d(args.embedding_dim)
|
371 |
+
|
372 |
+
self.conv1 = torch.nn.Conv2d(1, out_channels=self.num_filters,
|
373 |
+
kernel_size=(self.kernel_size, self.kernel_size),
|
374 |
+
stride=1, padding=0, bias=args.use_bias)
|
375 |
+
#self.conv1 = torch.nn.Conv2d(1, 32, (3, 3), 1, 0, bias=args.use_bias) # <-- default
|
376 |
+
|
377 |
+
flat_sz_h = int(2*self.stack_width) - self.kernel_size + 1
|
378 |
+
flat_sz_w = self.stack_height - self.kernel_size + 1
|
379 |
+
self.flat_sz = flat_sz_h*flat_sz_w*self.num_filters
|
380 |
+
self.fc = torch.nn.Linear(self.flat_sz, args.embedding_dim)
|
381 |
+
|
382 |
+
self.register_parameter('b', Parameter(torch.zeros(num_entities)))
|
383 |
+
self.loss = torch.nn.CrossEntropyLoss()
|
384 |
+
|
385 |
+
self.init()
|
386 |
+
|
387 |
+
def init(self):
|
388 |
+
xavier_normal_(self.emb_e.weight)
|
389 |
+
xavier_normal_(self.emb_rel.weight)
|
390 |
+
|
391 |
+
def concat(self, e1_embed, rel_embed, form='plain'):
|
392 |
+
if form == 'plain':
|
393 |
+
e1_embed = e1_embed. view(-1, 1, self.stack_width, self.stack_height)
|
394 |
+
rel_embed = rel_embed.view(-1, 1, self.stack_width, self.stack_height)
|
395 |
+
stack_inp = torch.cat([e1_embed, rel_embed], 2)
|
396 |
+
|
397 |
+
elif form == 'alternate':
|
398 |
+
e1_embed = e1_embed. view(-1, 1, self.embedding_dim)
|
399 |
+
rel_embed = rel_embed.view(-1, 1, self.embedding_dim)
|
400 |
+
stack_inp = torch.cat([e1_embed, rel_embed], 1)
|
401 |
+
stack_inp = torch.transpose(stack_inp, 2, 1).reshape((-1, 1, 2*self.stack_width, self.stack_height))
|
402 |
+
|
403 |
+
else: raise NotImplementedError
|
404 |
+
return stack_inp
|
405 |
+
|
406 |
+
def conve_architecture(self, sub_emb, rel_emb):
    """Run the ConvE feature pipeline on a (head, relation) embedding pair.

    stack -> bn0 -> dropout -> conv -> bn1 -> relu -> dropout -> flatten
    -> fc -> dropout -> bn2 -> relu.  Returns a (batch, embedding_dim) tensor.
    """
    stacked = self.bn0(self.concat(sub_emb, rel_emb))
    out = self.conv1(self.inp_drop(stacked))
    out = self.feature_drop(F.relu(self.bn1(out)))
    # flatten the conv feature maps before the fully connected projection
    out = self.fc(out.view(-1, self.flat_sz))
    return F.relu(self.bn2(self.hidden_drop(out)))
def score_sr(self, sub, rel, sigmoid = False):
    """Score (sub, rel, ?) against every entity; returns (batch, num_entities)."""
    features = self.conve_architecture(self.emb_e(sub), self.emb_rel(rel))
    scores = features @ self.emb_e.weight.t()
    scores = scores + self.b.expand_as(scores)
    return torch.sigmoid(scores) if sigmoid else scores
def score_or(self, obj, rel, sigmoid = False):
    """Score (?, rel, obj) against every entity; returns (batch, num_entities)."""
    features = self.conve_architecture(self.emb_e(obj), self.emb_rel(rel))
    scores = features @ self.emb_e.weight.t()
    scores = scores + self.b.expand_as(scores)
    return torch.sigmoid(scores) if sigmoid else scores
def forward(self, sub_emb, rel_emb, mode='rhs', sigmoid=False):
    '''
    When mode is 'rhs' we expect (s,r); for 'lhs', we expect (o,r).
    For ConvE the computation is the same either way, so `mode` does not
    branch; it is kept for interface parity with the other models.
    '''
    features = self.conve_architecture(sub_emb, rel_emb)
    scores = features @ self.emb_e.weight.t()
    scores = scores + self.b.expand_as(scores)
    return torch.sigmoid(scores) if sigmoid else scores
def score_triples(self, sub, rel, obj, sigmoid=False):
    '''
    Inputs - subject, relation, object index tensors of equal length
    Return - 1-D tensor with one score per input triple
    '''
    sub_emb = self.emb_e(sub)
    rel_emb = self.emb_rel(rel)
    obj_emb = self.emb_e(obj)
    x = self.conve_architecture(sub_emb, rel_emb)

    # Row-wise dot product scores each triple directly.  The previous version
    # built a (num_trip x num_trip) matrix via mm() and took its diagonal --
    # same values, but O(B^2) time/memory.  This is the equivalent the old
    # code itself suggested in a comment: torch.sum(x*obj_emb, dim=-1).
    pred = torch.sum(x * obj_emb, dim=-1)
    pred = pred + self.b[obj]  # per-object bias term

    if sigmoid:
        pred = torch.sigmoid(pred)

    return pred
def score_emb(self, emb_s, emb_r, emb_o, sigmoid=False):
    '''
    Inputs - embeddings of subject, relation, object (one row per triple)
    Return - 1-D tensor with one score per input row
    '''
    x = self.conve_architecture(emb_s, emb_r)

    # Row-wise dot product instead of mm() followed by torch.diagonal():
    # identical values without materialising the O(B^2) score matrix.
    # Note: unlike score_triples, this variant intentionally has no bias term
    # (matching the original behaviour).
    pred = torch.sum(x * emb_o, dim=-1)

    if sigmoid:
        pred = torch.sigmoid(pred)

    return pred
def score_triples_vec(self, sub, rel, obj, sigmoid=False):
    '''
    Inputs - subject, relation, object
    Return - a vector score for the triple instead of reducing over the embedding dimension
    '''
    head = self.emb_e(sub)
    relation = self.emb_rel(rel)
    tail = self.emb_e(obj)

    features = self.conve_architecture(head, relation)
    pred = features * tail  # elementwise, no reduction over the embedding dim

    return torch.sigmoid(pred) if sigmoid else pred
DiseaseAgnostic/utils.py
ADDED
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''
|
2 |
+
A file modified on https://github.com/PeruBhardwaj/AttributionAttack/blob/main/KGEAttack/ConvE/utils.py
|
3 |
+
'''
|
4 |
+
#%%
|
5 |
+
import logging
|
6 |
+
import time
|
7 |
+
from tqdm import tqdm
|
8 |
+
import io
|
9 |
+
import pandas as pd
|
10 |
+
import numpy as np
|
11 |
+
import os
|
12 |
+
import json
|
13 |
+
|
14 |
+
import argparse
|
15 |
+
import torch
|
16 |
+
import random
|
17 |
+
|
18 |
+
from yaml import parse
|
19 |
+
|
20 |
+
from model import Conve, Distmult, Complex
|
21 |
+
|
22 |
+
logger = logging.getLogger(__name__)
|
23 |
+
#%%
|
24 |
+
def generate_dicts(data_path):
    """Load entity/relation id maps from JSON files under *data_path*.

    Returns (num_entities, num_relations, ent_to_id, rel_to_id).
    """
    def _read_json(name):
        # each dict maps a string identifier to an integer id
        with open(os.path.join(data_path, name), 'r') as f:
            return json.load(f)

    ent_to_id = _read_json('entities_dict.json')
    rel_to_id = _read_json('relations_dict.json')
    return len(ent_to_id), len(rel_to_id), ent_to_id, rel_to_id
def save_data(file_name, data):
    """Write an iterable of triples/rows to *file_name*, tab-separated, one per line."""
    lines = ("\t".join(str(field) for field in row) + "\n" for row in data)
    with open(file_name, 'w') as fl:
        fl.writelines(lines)
def load_data(file_name):
    """Read a tab-separated triple file as strings and return its rows.

    Duplicate rows are dropped; the result is a numpy array of dtype=str rows.
    """
    frame = pd.read_csv(file_name, sep='\t', header=None, names=None, dtype=str)
    return frame.drop_duplicates().values
def seed_all(seed=1):
    """Seed every RNG in play (hash, python, numpy, torch CPU/CUDA) for reproducibility."""
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # force deterministic cuDNN kernels (may trade off speed)
    torch.backends.cudnn.deterministic = True
def add_model(args, n_ent, n_rel):
    """Instantiate the KGE model named by args.model (None falls back to DistMult)."""
    choice = 'distmult' if args.model is None else args.model
    if choice == 'distmult':
        return Distmult(args, n_ent, n_rel)
    if choice == 'complex':
        return Complex(args, n_ent, n_rel)
    if choice == 'conve':
        return Conve(args, n_ent, n_rel)
    raise Exception("Unknown model!")
def load_model(model_path, args, n_ent, n_rel, device):
    """Build a model, restore its pre-trained parameters from *model_path*,
    move it to *device*, and switch it to eval mode."""
    model = add_model(args, n_ent, n_rel)
    model.to(device)

    logger.info('Loading saved model from {0}'.format(model_path))
    state = torch.load(model_path)
    model_params = state['state_dict']
    # log each restored tensor so mismatches are easy to spot
    for key, value in model_params.items():
        logger.info('Key:{0}, Size:{1}, Count:{2}'.format(key, value.size(), value.numel()))

    model.load_state_dict(model_params)
    model.eval()
    logger.info(model)

    return model
def add_eval_parameters(parser):
    """Register the attack-evaluation CLI options on *parser* and return it."""
    parser.add_argument('--eval-mode', type=str, default='all',
                        help='Method to evaluate the attack performance. Default: all. (all or single)')
    parser.add_argument('--cuda-name', type=str, required=True,
                        help='Start a main thread on each cuda.')
    parser.add_argument('--direct', action='store_true',
                        help='Directly add edge or not.')
    parser.add_argument('--seperate', action='store_true',
                        help='Evaluate seperatly or not')
    return parser
def add_attack_parameters(parser):
    """Register the attack-generation CLI options on *parser* and return it."""
    # target selection
    parser.add_argument('--target-split', type=str, default='min',
                        help='Methods for target triple selection. Default: min. (min or top_?, top means top_0.1)')
    parser.add_argument('--target-size', type=int, default=50,
                        help='Number of target triples. Default: 50')
    parser.add_argument('--target-existed', action='store_true',
                        help='Whether the targeted s_?_o already exists.')

    # attack strategy
    parser.add_argument('--attack-goal', type=str, default='single',
                        help='Attack goal. Default: single. (single or global)')
    parser.add_argument('--neighbor-num', type=int, default=20,
                        help='Max neighbor num for each side. Default: 20')
    parser.add_argument('--candidate-mode', type=str, default='quadratic',
                        help='The method to generate candidate edge. Default: quadratic. (quadratic or linear)')
    parser.add_argument('--reasonable-rate', type=float, default=0.7,
                        help="The added edge's existance rank prob greater than this rate")
    parser.add_argument('--attack-batch-size', type=int, default=256,
                        help='Batch size for processing neighbours of target')
    parser.add_argument('--template-mode', type=str, default='manual',
                        help='Template mode for transforming edge to single sentense. Default: manual. (manual or auto)')

    parser.add_argument('--update-lissa', action='store_true',
                        help='Update lissa cache or not.')

    # language-model scoring
    parser.add_argument('--GPT-batch-size', type=int, default=64,
                        help='Batch size for GPT2 when calculating LM score. Default: 64')
    parser.add_argument('--LM-softmax', action='store_true',
                        help='Use a softmax head on LM prob or not.')
    parser.add_argument('--LMprob-mode', type=str, default='relative',
                        help='Use the absolute LM score or calculate the destruction score when target word is replaced. Default: absolute. (absolute or relative)')

    return parser
def get_argument_parser():
    """Build and return the base argument parser for graph-embedding training."""
    parser = argparse.ArgumentParser(description='Graph embedding')

    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='Random seed (default: 1)')

    # dataset / model choice
    parser.add_argument('--data', type=str, default='GNBR',
                        help='Dataset to use: { GNBR }')
    parser.add_argument('--model', type=str, default='distmult',
                        help='Choose from: {distmult, complex, transe, conve}')

    # TransE-specific
    parser.add_argument('--transe-margin', type=float, default=0.0,
                        help='Margin value for TransE scoring function. Default:0.0')
    parser.add_argument('--transe-norm', type=int, default=2,
                        help='P-norm value for TransE scoring function. Default:2')

    # optimisation
    parser.add_argument('--epochs', type=int, default=100,
                        help='Number of epochs to train (default: 100)')
    parser.add_argument('--lr', type=float, default=0.001,
                        help='Learning rate (default: 0.001)')
    parser.add_argument('--lr-decay', type=float, default=0.0,
                        help='Weight decay value to use in the optimizer. Default: 0.0')
    parser.add_argument('--max-norm', action='store_true',
                        help='Option to add unit max norm constraint to entity embeddings')

    # batching
    parser.add_argument('--train-batch-size', type=int, default=64,
                        help='Batch size for train split (default: 128)')
    parser.add_argument('--test-batch-size', type=int, default=128,
                        help='Batch size for test split (default: 128)')
    parser.add_argument('--valid-batch-size', type=int, default=128,
                        help='Batch size for valid split (default: 128)')
    parser.add_argument('--KG-valid-rate', type=float, default=0.1,
                        help='Validation rate during KG embedding training. (default: 0.1)')

    parser.add_argument('--save-influence-map', action='store_true',
                        help='Save the influence map during training for gradient rollback.')
    parser.add_argument('--add-reciprocals', action='store_true')

    # ConvE architecture
    parser.add_argument('--embedding-dim', type=int, default=128,
                        help='The embedding dimension (1D). Default: 128')
    parser.add_argument('--stack-width', type=int, default=16,
                        help='The first dimension of the reshaped/stacked 2D embedding. Second dimension is inferred. Default: 20')
    parser.add_argument('--hidden-drop', type=float, default=0.3,
                        help='Dropout for the hidden layer. Default: 0.3.')
    parser.add_argument('--input-drop', type=float, default=0.2,
                        help='Dropout for the input embeddings. Default: 0.2.')
    parser.add_argument('--feat-drop', type=float, default=0.3,
                        help='Dropout for the convolutional features. Default: 0.2.')
    # NOTE: single-dash spellings kept for backward CLI compatibility
    parser.add_argument('-num-filters', default=32, type=int,
                        help='Number of filters for convolution')
    parser.add_argument('-kernel-size', default=3, type=int,
                        help='Kernel Size for convolution')
    parser.add_argument('--use-bias', action='store_true',
                        help='Use a bias in the convolutional layer. Default: True')

    # regularisation
    parser.add_argument('--reg-weight', type=float, default=5e-2,
                        help='Weight for regularization. Default: 5e-2')
    parser.add_argument('--reg-norm', type=int, default=3,
                        help='Norm for regularization. Default: 2')

    return parser
def set_hyperparams(args):
    """Overwrite *args* with the tuned hyperparameters for the chosen model.

    Unknown model names keep their parser defaults; the LiSSA settings are
    applied unconditionally.  Returns the mutated *args*.
    """
    per_model = {
        'distmult': {'lr': 0.005, 'train_batch_size': 1024, 'reg_norm': 3},
        'complex':  {'lr': 0.005, 'reg_norm': 3, 'input_drop': 0.4,
                     'train_batch_size': 1024},
        'conve':    {'lr': 0.005, 'train_batch_size': 1024, 'reg_weight': 0.0},
    }
    for name, value in per_model.get(args.model, {}).items():
        setattr(args, name, value)

    # LiSSA (influence-function approximation) settings shared by all models
    args.damping = 0.01
    args.lissa_repeat = 1
    args.lissa_depth = 1
    args.scale = 400
    args.lissa_batch_size = 300
    return args