yjwtheonly committed
Commit fce1f4b
1 Parent(s): 8ae6390
DiseaseAgnostic/KG_extractor.py ADDED
@@ -0,0 +1,473 @@
1
+ #%%
2
+ import torch
3
+ import numpy as np
4
+ from torch.autograd import Variable
5
+ from sklearn import metrics
6
+
7
+ import datetime
8
+ from typing import Dict, Tuple, List
9
+ import logging
10
+ import os
11
+ import utils
12
+ import pickle as pkl
13
+ import json
14
+ import torch.backends.cudnn as cudnn
15
+
16
+ from tqdm import tqdm
17
+
18
+ import sys
19
+ sys.path.append("..")
20
+ import Parameters
21
+
22
+ parser = utils.get_argument_parser()
23
+ parser.add_argument('--reasonable-rate', type = float, default=0.7, help = 'The added edge\'s existence rank probability must be greater than this rate')
24
+ parser.add_argument('--mode', type=str, default='sentence', help='sentence, finetune, biogpt, bioBART')
25
+ parser.add_argument('--action', type=str, default='parse', help='parse or extract')
26
+ parser.add_argument('--init-mode', type = str, default='random', help = 'How to select target nodes')
27
+ parser.add_argument('--ratio', type = str, default='', help='ratio of the number of changed words')
28
+ args = parser.parse_args()
29
+ args = utils.set_hyperparams(args)
30
+
31
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
32
+
33
+ utils.seed_all(args.seed)
34
+ np.set_printoptions(precision=5)
35
+ cudnn.benchmark = False
36
+
37
+ data_path = '../DiseaseSpecific/processed_data/GNBR'
38
+ target_path = f'processed_data/target_{args.reasonable_rate}{args.init_mode}.pkl'
39
+ attack_path = f'processed_data/attack_edge_{args.model}_{args.reasonable_rate}{args.init_mode}.pkl'
40
+ modified_attack_path = f'processed_data/attack_edge_{args.model}_{args.reasonable_rate}{args.init_mode}{args.mode}.pkl'
41
+
42
+ with open(attack_path, 'rb') as fl:
43
+ Attack_edge_list = pkl.load(fl)
44
+ attack_data = np.array(Attack_edge_list).reshape(-1, 3)
45
+ #%%
46
+ with open(os.path.join(data_path, 'entities_reverse_dict.json')) as fl:
47
+ id_to_meshid = json.load(fl)
48
+ with open(os.path.join(data_path, 'entities_dict.json'), 'r') as fl:
49
+ meshid_to_id = json.load(fl)
50
+ with open(Parameters.GNBRfile+'entity_raw_name', 'rb') as fl:
51
+ entity_raw_name = pkl.load(fl)
52
+ with open(Parameters.GNBRfile+'retieve_sentence_through_edgetype', 'rb') as fl:
53
+ retieve_sentence_through_edgetype = pkl.load(fl)
54
+ with open(Parameters.GNBRfile+'raw_text_of_each_sentence', 'rb') as fl:
55
+ raw_text_sen = pkl.load(fl)
56
+ with open(Parameters.GNBRfile+'original_entity_raw_name', 'rb') as fl:
57
+ full_entity_raw_name = pkl.load(fl)
58
+ for k, v in entity_raw_name.items():
59
+ assert v in full_entity_raw_name[k]
60
+
61
+ #find unique
62
+ once_set = set()
63
+ twice_set = set()
64
+
65
+ with open('../DiseaseSpecific/generate_abstract/valid_entity.json', 'r') as fl:
66
+ valid_entity = json.load(fl)
67
+ valid_entity = set(valid_entity)
68
+
69
+ good_name = set()
70
+ for k, v in full_entity_raw_name.items():
71
+ names = list(v)
72
+ for name in names:
73
+ # if name == 'in a':
74
+ # print(names)
75
+ good_name.add(name)
76
+ # if name not in once_set:
77
+ # once_set.add(name)
78
+ # else:
79
+ # twice_set.add(name)
80
+ # assert 'WNK4' in once_set
81
+ # good_name = set.difference(once_set, twice_set)
82
+ # assert 'in a' not in good_name
83
+ # assert 'STE20' not in good_name
84
+ # assert 'STE20' not in valid_entity
85
+ # assert 'STE20-related proline-alanine-rich kinase' not in good_name
86
+ # assert 'STE20-related proline-alanine-rich kinase' not in valid_entity
87
+ # raise Exception
88
+
89
+ name_to_type = {}
90
+ name_to_meshid = {}
91
+
92
+ for k, v in full_entity_raw_name.items():
93
+ names = list(v)
94
+ for name in names:
95
+ if name in good_name:
96
+ name_to_type[name] = k.split('_')[0]
97
+ name_to_meshid[name] = k
98
+
99
+ import spacy
100
+ import networkx as nx
101
+ import pprint
102
+
103
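+ # Word-boundary test: positions outside the string (and index 0) or non-alphanumeric characters count as boundaries, so entity names are only matched as whole words.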
+ def check(p, s):
104
+
105
+ if p < 1 or p >= len(s):
106
+ return True
107
+ return not((s[p]>='a' and s[p]<='z') or (s[p]>='A' and s[p]<='Z') or (s[p]>='0' and s[p]<='9'))
108
+
109
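+ # Greedily match the longest known entity name (from good_name / valid_entity) at each position and join its words with underscores, so the dependency parser keeps multi-word entities as single tokens.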
+ def raw_to_format(sen):
110
+
111
+ text = sen
112
+ l = 0
113
+ ret = []
114
+ while(l < len(text)):
115
+ bo = False
116
+ if text[l] != ' ':
117
+ for i in range(len(text), l, -1): # reversing is important !!!
118
+ cc = text[l:i]
119
+ if (cc in good_name or cc in valid_entity) and check(l-1, text) and check(i, text):
120
+ ret.append(cc.replace(' ', '_'))
121
+ l = i
122
+ bo = True
123
+ break
124
+ if not bo:
125
+ ret.append(text[l])
126
+ l += 1
127
+ return ''.join(ret)
128
+
129
+ if args.mode == 'sentence':
130
+ with open(f'generate_abstract/{args.init_mode}{args.reasonable_rate}_chat.json', 'r') as fl:
131
+ draft = json.load(fl)
132
+ elif args.mode == 'finetune':
133
+ with open(f'generate_abstract/{args.init_mode}{args.reasonable_rate}_sentence_finetune.json', 'r') as fl:
134
+ draft = json.load(fl)
135
+ elif args.mode == 'bioBART':
136
+ with open(f'generate_abstract/{args.init_mode}{args.reasonable_rate}{args.ratio}_bioBART_finetune.json', 'r') as fl:
137
+ draft = json.load(fl)
138
+ elif args.mode == 'biogpt':
139
+ with open(f'generate_abstract/{args.init_mode}{args.reasonable_rate}_biogpt.json', 'r') as fl:
140
+ draft = json.load(fl)
141
+ else:
142
+ raise Exception(f'Unknown mode: {args.mode}')
143
+
144
+ nlp = spacy.load("en_core_web_sm")
145
+
146
+ type_set = set()
147
+ for aa in range(36):
148
+ dependency_sen_dict = retieve_sentence_through_edgetype[aa]['manual']
149
+ tmp_dict = retieve_sentence_through_edgetype[aa]['auto']
150
+ dependencys = list(dependency_sen_dict.keys()) + list(tmp_dict.keys())
151
+ for dependency in dependencys:
152
+ dep_list = dependency.split(' ')
153
+ for sub_dep in dep_list:
154
+ sub_dep_list = sub_dep.split('|')
155
+ assert(len(sub_dep_list) == 3)
156
+ type_set.add(sub_dep_list[1])
157
+ # print('Type:', type_set)
158
+
159
+ if args.action == 'parse':
160
+ # dp_path, sen_list = list(dependency_sen_dict.items())[0]
161
+ # check
162
+ # paper_id, sen_id = sen_list[0]
163
+ # sen = raw_text_sen[paper_id][sen_id]
164
+ # doc = nlp(sen['text'])
165
+ # print(dp_path, '\n')
166
+ # pprint.pprint(sen)
167
+ # print()
168
+ # for token in doc:
169
+ # print((token.head.text, token.text, token.dep_))
170
+
171
+ out = ''
172
+ for k, v_dict in draft.items():
173
+ input = v_dict['in']
174
+ output = v_dict['out']
175
+ if input == '':
176
+ continue
177
+ output = output.replace('\n', ' ')
178
+ doc = nlp(output)
179
+ for sen in doc.sents:
180
+ out += raw_to_format(sen.text) + '\n'
181
+ with open(f'generate_abstract/{args.init_mode}{args.reasonable_rate}_{args.mode}_parsein.txt', 'w') as fl:
182
+ fl.write(out)
183
+ elif args.action == 'extract':
184
+
185
+ # dependency_to_type_id = {}
186
+ # for k, v in Parameters.edge_type_to_id.items():
187
+ # dependency_to_type_id[k] = {}
188
+ # for type in v:
189
+ # LL = list(retieve_sentence_through_edgetype[type]['manual'].keys()) + list(retieve_sentence_through_edgetype[type]['auto'].keys())
190
+ # for dp in LL:
191
+ # dependency_to_type_id[k][dp] = type
192
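+ # Build a mapping from each GNBR dependency path to an edge-type id using the part-i-*-path-theme-distributions.txt files: manually flagged themes are preferred (balanced via cal_manual_num), otherwise the highest-probability theme is taken; the result is cached as a pickle.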
+ if os.path.exists('generate_abstract/dependency_to_type_id.pickle'):
193
+ with open('generate_abstract/dependency_to_type_id.pickle', 'rb') as fl:
194
+ dependency_to_type_id = pkl.load(fl)
195
+ else:
196
+ dependency_to_type_id = {}
197
+ print('Loading path data ...')
198
+ for k in Parameters.edge_type_to_id.keys():
199
+ start, end = k.split('-')
200
+ dependency_to_type_id[k] = {}
201
+ inner_edge_type_to_id = Parameters.edge_type_to_id[k]
202
+ inner_edge_type_dict = Parameters.edge_type_dict[k]
203
+ cal_manual_num = [0] * len(inner_edge_type_to_id)
204
+ with open('../GNBRdata/part-i-'+start+'-'+end+'-path-theme-distributions.txt', 'r') as fl:
205
+ for i, line in tqdm(list(enumerate(fl.readlines()))):
206
+ tmp = line.split('\t')
207
+ if i == 0:
208
+ head = [tmp[i] for i in range(1, len(tmp), 2)]
209
+ assert ' '.join(head) == ' '.join(inner_edge_type_dict[0])
210
+ continue
211
+ probability = [float(tmp[i]) for i in range(1, len(tmp), 2)]
212
+ flag_list = [int(tmp[i]) for i in range(2, len(tmp), 2)]
213
+ indices = np.where(np.asarray(flag_list) == 1)[0]
214
+ if len(indices) >= 1:
215
+ tmp_p = [cal_manual_num[i] for i in indices]
216
+ p = indices[np.argmin(tmp_p)]
217
+ cal_manual_num[p] += 1
218
+ else:
219
+ p = np.argmax(probability)
220
+ assert tmp[0].lower() not in dependency_to_type_id.keys()
221
+ dependency_to_type_id[k][tmp[0].lower()] = inner_edge_type_to_id[p]
222
+ with open('generate_abstract/dependency_to_type_id.pickle', 'wb') as fl:
223
+ pkl.dump(dependency_to_type_id, fl)
224
+
225
+ # record = []
226
+ # with open(f'generate_abstract/par_parseout.txt', 'r') as fl:
227
+ # Tmp = []
228
+ # tmp = []
229
+ # for i,line in enumerate(fl.readlines()):
230
+ # # print(len(line), line)
231
+ # line = line.replace('\n', '')
232
+ # if len(line) > 1:
233
+ # tmp.append(line)
234
+ # else:
235
+ # Tmp.append(tmp)
236
+ # tmp = []
237
+ # if len(Tmp) == 3:
238
+ # record.append(Tmp)
239
+ # Tmp = []
240
+
241
+ # print(len(record))
242
+ # record_index = 0
243
+ # add = 0
244
+ # Attack = []
245
+ # for ii in range(100):
246
+
247
+ # # input = v_dict['in']
248
+ # # output = v_dict['out']
249
+ # # output = output.replace('\n', ' ')
250
+ # s, r, o = attack_data[ii]
251
+ # dependency_sen_dict = retieve_sentence_through_edgetype[int(r)]['manual']
252
+
253
+ # target_dp = set()
254
+ # for dp_path, sen_list in dependency_sen_dict.items():
255
+ # target_dp.add(dp_path)
256
+ # DP_list = []
257
+ # for _ in range(1):
258
+ # dp_dict = {}
259
+ # data = record[record_index]
260
+ # record_index += 1
261
+ # dp_paths = data[2]
262
+ # nodes_list = []
263
+ # edges_list = []
264
+ # for line in dp_paths:
265
+ # ttp, tmp = line.split('(')
266
+ # assert tmp[-1] == ')'
267
+ # tmp = tmp[:-1]
268
+ # e1, e2 = tmp.split(', ')
269
+ # if not ttp in type_set and ':' in ttp:
270
+ # ttp = ttp.split(':')[0]
271
+ # dp_dict[f'{e1}_x_{e2}'] = [e1, ttp, e2]
272
+ # dp_dict[f'{e2}_x_{e1}'] = [e1, ttp, e2]
273
+ # nodes_list.append(e1)
274
+ # nodes_list.append(e2)
275
+ # edges_list.append((e1, e2))
276
+ # nodes_list = list(set(nodes_list))
277
+ # pure_name = [('-'.join(name.split('-')[:-1])).replace('_', ' ') for name in nodes_list]
278
+ # graph = nx.Graph(edges_list)
279
+
280
+ # type_list = [name_to_type[name] if name in good_name else '' for name in pure_name]
281
+ # # print(type_list)
282
+ # # for i in range(len(type_list)):
283
+ # # print(pure_name[i], type_list[i])
284
+ # for i in range(len(nodes_list)):
285
+ # if type_list[i] != '':
286
+ # for j in range(len(nodes_list)):
287
+ # if i != j and type_list[j] != '':
288
+ # if f'{type_list[i]}-{type_list[j]}' in Parameters.edge_type_to_id.keys():
289
+ # # print(f'{type_list[i]}_{type_list[j]}')
290
+ # ret_path = []
291
+ # sp = nx.shortest_path(graph, source=nodes_list[i], target=nodes_list[j])
292
+ # start = sp[0]
293
+ # end = sp[-1]
294
+ # for k in range(len(sp)-1):
295
+ # e1, ttp, e2 = dp_dict[f'{sp[k]}_x_{sp[k+1]}']
296
+ # if e1 == start:
297
+ # e1 = 'start_entity-x'
298
+ # if e2 == start:
299
+ # e2 = 'start_entity-x'
300
+ # if e1 == end:
301
+ # e1 = 'end_entity-x'
302
+ # if e2 == end:
303
+ # e2 = 'end_entity-x'
304
+ # ret_path.append(f'{"-".join(e1.split("-")[:-1])}|{ttp}|{"-".join(e2.split("-")[:-1])}'.lower())
305
+ # dependency_P = ' '.join(ret_path)
306
+ # DP_list.append((f'{type_list[i]}-{type_list[j]}',
307
+ # name_to_meshid[pure_name[i]],
308
+ # name_to_meshid[pure_name[j]],
309
+ # dependency_P))
310
+
311
+ # boo = False
312
+ # modified_attack = []
313
+ # for k, ss, tt, dp in DP_list:
314
+ # if dp in dependency_to_type_id[k].keys():
315
+ # tp = str(dependency_to_type_id[k][dp])
316
+ # id_ss = str(meshid_to_id[ss])
317
+ # id_tt = str(meshid_to_id[tt])
318
+ # modified_attack.append(f'{id_ss}*{tp}*{id_tt}')
319
+ # if int(dependency_to_type_id[k][dp]) == int(r):
320
+ # # if id_to_meshid[s] == ss and id_to_meshid[o] == tt:
321
+ # boo = True
322
+ # modified_attack = list(set(modified_attack))
323
+ # modified_attack = [k.split('*') for k in modified_attack]
324
+ # if boo:
325
+ # add += 1
326
+ # # else:
327
+ # # print(ii)
328
+
329
+ # # for i in range(len(type_list)):
330
+ # # if type_list[i]:
331
+ # # print(pure_name[i], type_list[i])
332
+ # # for k, ss, tt, dp in DP_list:
333
+ # # print(k, dp)
334
+ # # print(record[record_index - 1])
335
+ # # raise Exception('No!!')
336
+ # Attack.append(modified_attack)
337
+
338
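+ # Read the external parser output: blocks are separated by blank lines and each parsed sentence is expected to contribute three blocks, the last one holding typed dependencies of the form 'rel(word-1, word-2)' that are consumed below.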
+ record = []
339
+ with open(f'generate_abstract/{args.init_mode}{args.reasonable_rate}_{args.mode}_parseout.txt', 'r') as fl:
340
+ Tmp = []
341
+ tmp = []
342
+ for i,line in enumerate(fl.readlines()):
343
+ # print(len(line), line)
344
+ line = line.replace('\n', '')
345
+ if len(line) > 1:
346
+ tmp.append(line)
347
+ else:
348
+ if len(Tmp) == 2:
349
+ if len(tmp) == 1 and '/' in tmp[0].split(' ')[0]:
350
+ Tmp.append([])
351
+ record.append(Tmp)
352
+ Tmp = []
353
+ Tmp.append(tmp)
354
+ if len(Tmp) == 2 and tmp[0][:5] != '(ROOT':
355
+ print(record[-1][2])
356
+ raise Exception('??')
357
+ tmp = []
358
+ if len(Tmp) == 3:
359
+ record.append(Tmp)
360
+ Tmp = []
361
+ with open(f'generate_abstract/{args.init_mode}{args.reasonable_rate}_{args.mode}_parsein.txt', 'r') as fl:
362
+ parsin = fl.readlines()
363
+
364
+ print('Record len', len(record), 'Parsin len:', len(parsin))
365
+ record_index = 0
366
+ add = 0
367
+
368
+ Attack = []
369
+ for ii, (k, v_dict) in enumerate(tqdm(draft.items())):
370
+
371
+ input = v_dict['in']
372
+ output = v_dict['out']
373
+ output = output.replace('\n', ' ')
374
+ s, r, o = attack_data[ii]
375
+ s = str(s)
376
+ r = str(r)
377
+ o = str(o)
378
+ assert ii == int(k.split('_')[-1])
379
+
380
+ DP_list = []
381
+ if input != '':
382
+
383
+ dependency_sen_dict = retieve_sentence_through_edgetype[int(r)]['manual']
384
+ target_dp = set()
385
+ for dp_path, sen_list in dependency_sen_dict.items():
386
+ target_dp.add(dp_path)
387
+ doc = nlp(output)
388
+
389
+ for sen in doc.sents:
390
+ dp_dict = {}
391
+ if record_index >= len(record):
392
+ break
393
+ data = record[record_index]
394
+ record_index += 1
395
+ dp_paths = data[2]
396
+ nodes_list = []
397
+ edges_list = []
398
+ for line in dp_paths:
399
+ aa = line.split('(')
400
+ if len(aa) == 1:
401
+ print(ii)
402
+ print(sen)
403
+ print(data)
404
+ raise Exception
405
+ ttp, tmp = aa[0], aa[1]
406
+ assert tmp[-1] == ')'
407
+ tmp = tmp[:-1]
408
+ e1, e2 = tmp.split(', ')
409
+ if not ttp in type_set and ':' in ttp:
410
+ ttp = ttp.split(':')[0]
411
+ dp_dict[f'{e1}_x_{e2}'] = [e1, ttp, e2]
412
+ dp_dict[f'{e2}_x_{e1}'] = [e1, ttp, e2]
413
+ nodes_list.append(e1)
414
+ nodes_list.append(e2)
415
+ edges_list.append((e1, e2))
416
+ nodes_list = list(set(nodes_list))
417
+ pure_name = [('-'.join(name.split('-')[:-1])).replace('_', ' ') for name in nodes_list]
418
+ graph = nx.Graph(edges_list)
419
+
420
+ type_list = [name_to_type[name] if name in good_name else '' for name in pure_name]
421
+ # print(type_list)
422
+ for i in range(len(nodes_list)):
423
+ if type_list[i] != '':
424
+ for j in range(len(nodes_list)):
425
+ if i != j and type_list[j] != '':
426
+ if f'{type_list[i]}-{type_list[j]}' in Parameters.edge_type_to_id.keys():
427
+ # print(f'{type_list[i]}_{type_list[j]}')
428
+ ret_path = []
429
+ sp = nx.shortest_path(graph, source=nodes_list[i], target=nodes_list[j])
430
+ start = sp[0]
431
+ end = sp[-1]
432
+ for k in range(len(sp)-1):
433
+ e1, ttp, e2 = dp_dict[f'{sp[k]}_x_{sp[k+1]}']
434
+ if e1 == start:
435
+ e1 = 'start_entity-x'
436
+ if e2 == start:
437
+ e2 = 'start_entity-x'
438
+ if e1 == end:
439
+ e1 = 'end_entity-x'
440
+ if e2 == end:
441
+ e2 = 'end_entity-x'
442
+ ret_path.append(f'{"-".join(e1.split("-")[:-1])}|{ttp}|{"-".join(e2.split("-")[:-1])}'.lower())
443
+ dependency_P = ' '.join(ret_path)
444
+ DP_list.append((f'{type_list[i]}-{type_list[j]}',
445
+ name_to_meshid[pure_name[i]],
446
+ name_to_meshid[pure_name[j]],
447
+ dependency_P))
448
+
449
+ boo = False
450
+ modified_attack = []
451
+ for k, ss, tt, dp in DP_list:
452
+ if dp in dependency_to_type_id[k].keys():
453
+ tp = str(dependency_to_type_id[k][dp])
454
+ id_ss = str(meshid_to_id[ss])
455
+ id_tt = str(meshid_to_id[tt])
456
+ modified_attack.append(f'{id_ss}*{tp}*{id_tt}')
457
+ if int(dependency_to_type_id[k][dp]) == int(r):
458
+ if id_to_meshid[s] == ss and id_to_meshid[o] == tt:
459
+ boo = True
460
+ modified_attack = list(set(modified_attack))
461
+ modified_attack = [k.split('*') for k in modified_attack]
462
+ if boo:
463
+ # print(DP_list)
464
+ add += 1
465
+ Attack.append(modified_attack)
466
+ print(add)
467
+ print('End record_index:', record_index)
468
+ final_Attack = Attack
469
+ print('Len of Attack:', len(Attack))
470
+ with open(modified_attack_path, 'wb') as fl:
471
+ pkl.dump(final_Attack, fl)
472
+ else:
473
+ raise Exception('Wrong action !!')
DiseaseAgnostic/edge_to_abstract.py ADDED
@@ -0,0 +1,652 @@
1
+ #%%
2
+ import torch
3
+ import numpy as np
4
+ from torch.autograd import Variable
5
+ from sklearn import metrics
6
+
7
+ import datetime
8
+ from typing import Dict, Tuple, List
9
+ import logging
10
+ import os
11
+ import utils
12
+ import pickle as pkl
13
+ import json
14
+ import torch.backends.cudnn as cudnn
15
+
16
+ from tqdm import tqdm
17
+
18
+ import sys
19
+ sys.path.append("..")
20
+ import Parameters
21
+
22
+ parser = utils.get_argument_parser()
23
+ parser.add_argument('--reasonable-rate', type = float, default=0.7, help = 'The added edge\'s existence rank probability must be greater than this rate')
24
+ parser.add_argument('--mode', type=str, default='sentence', help='sentence, biogpt or finetune')
25
+ parser.add_argument('--init-mode', type = str, default='random', help = 'How to select target nodes')
26
+ parser.add_argument('--ratio', type = str, default='', help='ratio of the number of changed words')
27
+ args = parser.parse_args()
28
+ args = utils.set_hyperparams(args)
29
+
30
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
31
+
32
+ utils.seed_all(args.seed)
33
+ np.set_printoptions(precision=5)
34
+ cudnn.benchmark = False
35
+
36
+ data_path = '../DiseaseSpecific/processed_data/GNBR'
37
+ target_path = f'processed_data/target_{args.reasonable_rate}{args.init_mode}.pkl'
38
+ attack_path = f'processed_data/attack_edge_{args.model}_{args.reasonable_rate}{args.init_mode}.pkl'
39
+
40
+ # target_data = utils.load_data(target_path)
41
+ with open(target_path, 'rb') as fl:
42
+ Target_node_list = pkl.load(fl)
43
+ with open(attack_path, 'rb') as fl:
44
+ Attack_edge_list = pkl.load(fl)
45
+ attack_data = np.array(Attack_edge_list).reshape(-1, 3)
46
+ # assert target_data.shape == attack_data.shape
47
+ #%%
48
+
49
+ with open('../DiseaseSpecific/processed_data/GNBR/entities_reverse_dict.json') as fl:
50
+ id_to_meshid = json.load(fl)
51
+ with open(Parameters.GNBRfile+'entity_raw_name', 'rb') as fl:
52
+ entity_raw_name = pkl.load(fl)
53
+ with open(Parameters.GNBRfile+'retieve_sentence_through_edgetype', 'rb') as fl:
54
+ retieve_sentence_through_edgetype = pkl.load(fl)
55
+ with open(Parameters.GNBRfile+'raw_text_of_each_sentence', 'rb') as fl:
56
+ raw_text_sen = pkl.load(fl)
57
+
58
+ if args.mode == 'sentence':
59
+ import torch
60
+ from torch.nn.modules.loss import CrossEntropyLoss
61
+ from transformers import AutoTokenizer
62
+ from transformers import BioGptForCausalLM
63
+ criterion = CrossEntropyLoss(reduction="none")
64
+
65
+ print('Generating GPT input ...')
66
+
67
+ tokenizer = AutoTokenizer.from_pretrained('microsoft/biogpt')
68
+ tokenizer.pad_token = tokenizer.eos_token
69
+ model = BioGptForCausalLM.from_pretrained('microsoft/biogpt', pad_token_id=tokenizer.eos_token_id)
70
+ model.to(device)
71
+ model.eval()
72
+ GPT_batch_size = 24
73
+ single_sentence = {}
74
+ test_text = []
75
+ test_dp = []
76
+ test_parse = []
77
+ for i, (s, r, o) in enumerate(tqdm(attack_data)):
78
+
79
+ s = str(s)
80
+ r = str(r)
81
+ o = str(o)
82
+ if int(s) != -1:
83
+
84
+ dependency_sen_dict = retieve_sentence_through_edgetype[int(r)]['manual']
85
+ candidate_sen = []
86
+ Dp_path = []
87
+ L = len(dependency_sen_dict.keys())
88
+ bound = 500 // L
89
+ if bound == 0:
90
+ bound = 1
91
+ for dp_path, sen_list in dependency_sen_dict.items():
92
+ if len(sen_list) > bound:
93
+ index = np.random.choice(np.array(range(len(sen_list))), bound, replace=False)
94
+ sen_list = [sen_list[aa] for aa in index]
95
+ candidate_sen += sen_list
96
+ Dp_path += [dp_path] * len(sen_list)
97
+
98
+ text_s = entity_raw_name[id_to_meshid[s]]
99
+ text_o = entity_raw_name[id_to_meshid[o]]
100
+ candidate_text_sen = []
101
+ candidate_ori_sen = []
102
+ candidate_parse_sen = []
103
+
104
+ for paper_id, sen_id in candidate_sen:
105
+ sen = raw_text_sen[paper_id][sen_id]
106
+ text = sen['text']
107
+ candidate_ori_sen.append(text)
108
+ ss = sen['start_formatted']
109
+ oo = sen['end_formatted']
110
+ text = text.replace('-LRB-', '(')
111
+ text = text.replace('-RRB-', ')')
112
+ text = text.replace('-LSB-', '[')
113
+ text = text.replace('-RSB-', ']')
114
+ text = text.replace('-LCB-', '{')
115
+ text = text.replace('-RCB-', '}')
116
+ parse_text = text
117
+ parse_text = parse_text.replace(ss, text_s.replace(' ', '_'))
118
+ parse_text = parse_text.replace(oo, text_o.replace(' ', '_'))
119
+ text = text.replace(ss, text_s)
120
+ text = text.replace(oo, text_o)
121
+ text = text.replace('_', ' ')
122
+ candidate_text_sen.append(text)
123
+ candidate_parse_sen.append(parse_text)
124
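+ # Score every candidate sentence with BioGPT: mean per-token cross-entropy (length-normalised negative log-likelihood); the lowest-scoring, i.e. most fluent, candidate is kept below.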
+ tokens = tokenizer( candidate_text_sen,
125
+ truncation = True,
126
+ padding = True,
127
+ max_length = 300,
128
+ return_tensors="pt")
129
+ target_ids = tokens['input_ids'].to(device)
130
+ attention_mask = tokens['attention_mask'].to(device)
131
+
132
+ L = len(candidate_text_sen)
133
+ assert L > 0
134
+ ret_log_L = []
135
+ for l in range(0, L, GPT_batch_size):
136
+ R = min(L, l + GPT_batch_size)
137
+ target = target_ids[l:R, :]
138
+ attention = attention_mask[l:R, :]
139
+ outputs = model(input_ids = target,
140
+ attention_mask = attention,
141
+ labels = target)
142
+ logits = outputs.logits
143
+ shift_logits = logits[..., :-1, :].contiguous()
144
+ shift_labels = target[..., 1:].contiguous()
145
+ Loss = criterion(shift_logits.view(-1, shift_logits.shape[-1]), shift_labels.view(-1))
146
+ Loss = Loss.view(-1, shift_logits.shape[1])
147
+ attention = attention[..., 1:].contiguous()
148
+ log_Loss = (torch.mean(Loss * attention.float(), dim = 1) / torch.mean(attention.float(), dim = 1))
149
+ ret_log_L.append(log_Loss.detach())
150
+
151
+
152
+ ret_log_L = list(torch.cat(ret_log_L, -1).cpu().numpy())
153
+ sen_score = list(zip(candidate_text_sen, ret_log_L, candidate_ori_sen, Dp_path, candidate_parse_sen))
154
+ sen_score.sort(key = lambda x: x[1])
155
+ test_text.append(sen_score[0][2])
156
+ test_dp.append(sen_score[0][3])
157
+ test_parse.append(sen_score[0][4])
158
+ single_sentence.update({f'{s}_{r}_{o}_{i}': sen_score[0][0]})
159
+
160
+ else:
161
+ single_sentence.update({f'{s}_{r}_{o}_{i}': ''})
162
+
163
+ with open(f'generate_abstract/{args.init_mode}{args.reasonable_rate}_sentence.json', 'w') as fl:
164
+ json.dump(single_sentence, fl, indent=4)
165
+ # with open('generate_abstract/test.txt', 'w') as fl:
166
+ # fl.write('\n'.join(test_text))
167
+ # with open('generate_abstract/dp.txt', 'w') as fl:
168
+ # fl.write('\n'.join(test_dp))
169
+ with open (f'generate_abstract/path/{args.init_mode}{args.reasonable_rate}_path.json', 'w') as fl:
170
+ fl.write('\n'.join(test_dp))
171
+ with open (f'generate_abstract/path/{args.init_mode}{args.reasonable_rate}_temp.json', 'w') as fl:
172
+ fl.write('\n'.join(test_text))
173
+
174
+ elif args.mode == 'biogpt':
175
+ pass
176
+ # from biogpt_generate import GPT_eval
177
+ # import spacy
178
+
179
+ # model = GPT_eval(args.seed)
180
+
181
+ # nlp = spacy.load("en_core_web_sm")
182
+ # with open(f'generate_abstract/{args.target_split}_{args.reasonable_rate}_sentence.json', 'r') as fl:
183
+ # data = json.load(fl)
184
+
185
+ # KK = []
186
+ # input = []
187
+ # for i,(k, v) in enumerate(data.items()):
188
+ # KK.append(k)
189
+ # input.append(v)
190
+ # output = model.eval(input)
191
+
192
+ # ret = {}
193
+ # for i, o in enumerate(output):
194
+
195
+ # o = o.replace('<|abstract|>', '')
196
+ # doc = nlp(o)
197
+ # sen_list = []
198
+ # sen_set = set()
199
+ # for sen in doc.sents:
200
+ # txt = sen.text
201
+ # if not (txt.lower() in sen_set):
202
+ # sen_set.add(txt.lower())
203
+ # sen_list.append(txt)
204
+ # O = ' '.join(sen_list)
205
+ # ret[KK[i]] = {'in' : input[i], 'out' : O}
206
+
207
+ # with open(f'generate_abstract/{args.target_split}_{args.reasonable_rate}_biogpt.json', 'w') as fl:
208
+ # json.dump(ret, fl, indent=4)
209
+
210
+ elif args.mode == 'finetune':
211
+
212
+ import spacy
213
+ import pprint
214
+ from transformers import AutoModel, AutoTokenizer,BartForConditionalGeneration
215
+
216
+ print('Finetuning ...')
217
+
218
+ with open(f'generate_abstract/{args.init_mode}{args.reasonable_rate}_chat.json', 'r') as fl:
219
+ draft = json.load(fl)
220
+ with open (f'generate_abstract/path/{args.init_mode}{args.reasonable_rate}_path.json', 'r') as fl:
221
+ dpath = fl.readlines()
222
+
223
+ nlp = spacy.load("en_core_web_sm")
224
+ if os.path.exists(f'generate_abstract/bioBART/{args.init_mode}{args.reasonable_rate}{args.ratio}_candidates.json'):
225
+ with open(f'generate_abstract/bioBART/{args.init_mode}{args.reasonable_rate}{args.ratio}_candidates.json', 'r') as fl:
226
+ ret_candidates = json.load(fl)
227
+ else:
228
+
229
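+ # find_mini_span: locate the shortest window of words that still contains as many check_set strings as the full text does, mark that window as protected in vec (so it is never masked), and return it as the span.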
+ def find_mini_span(vec, words, check_set):
230
+
231
+
232
+ def cal(text, sset):
233
+ add = 0
234
+ for tt in sset:
235
+ if tt in text:
236
+ add += 1
237
+ return add
238
+ text = ' '.join(words)
239
+ max_add = cal(text, check_set)
240
+
241
+ minn = 10000000
242
+ span = ''
243
+ rc = None
244
+ for i in range(len(vec)):
245
+ if vec[i] == True:
246
+ p = -1
247
+ for j in range(i+1, len(vec)+1):
248
+ if vec[j-1] == True:
249
+ text = ' '.join(words[i:j])
250
+ if cal(text, check_set) == max_add:
251
+ p = j
252
+ break
253
+ if p > 0:
254
+ if (p-i) < minn:
255
+ minn = p-i
256
+ span = ' '.join(words[i:p])
257
+ rc = (i, p)
258
+ if rc:
259
+ for i in range(rc[0], rc[1]):
260
+ vec[i] = True
261
+ return vec, span
262
+
263
+ # def mask_func(tokenized_sen, position):
264
+
265
+ # if len(tokenized_sen) == 0:
266
+ # return []
267
+ # token_list = []
268
+ # # for sen in tokenized_sen:
269
+ # # for token in sen:
270
+ # # token_list.append(token)
271
+ # for sen in tokenized_sen:
272
+ # token_list += sen.text.split(' ')
273
+ # l_p = 0
274
+ # r_p = 1
275
+ # assert position == 'front' or position == 'back'
276
+ # if position == 'back':
277
+ # l_p, r_p = r_p, l_p
278
+ # P = np.linspace(start = l_p, stop = r_p, num = len(token_list))
279
+ # P = (P ** 3) * 0.4
280
+
281
+ # ret_list = []
282
+ # for t, p in zip(token_list, list(P)):
283
+ # if '.' in t or '(' in t or ')' in t or '[' in t or ']' in t:
284
+ # ret_list.append(t)
285
+ # else:
286
+ # if np.random.rand() < p:
287
+ # ret_list.append('<mask>')
288
+ # else:
289
+ # ret_list.append(t)
290
+ # return [' '.join(ret_list)]
291
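+ # mask_func: BART-style span corruption -- with probability P (args.ratio, default 0.3) replace a span of roughly Poisson(3) tokens by a single <mask>, never masking tokens containing '.', '(', ')', '[' or ']', and capping consecutive masks at 8.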
+ def mask_func(tokenized_sen):
292
+
293
+ if len(tokenized_sen) == 0:
294
+ return []
295
+ token_list = []
296
+ # for sen in tokenized_sen:
297
+ # for token in sen:
298
+ # token_list.append(token)
299
+ for sen in tokenized_sen:
300
+ token_list += sen.text.split(' ')
301
+ if args.ratio == '':
302
+ P = 0.3
303
+ else:
304
+ P = float(args.ratio)
305
+
306
+ ret_list = []
307
+ i = 0
308
+ mask_num = 0
309
+ while i < len(token_list):
310
+ t = token_list[i]
311
+ if '.' in t or '(' in t or ')' in t or '[' in t or ']' in t:
312
+ ret_list.append(t)
313
+ i += 1
314
+ mask_num = 0
315
+ else:
316
+ length = np.random.poisson(3)
317
+ if np.random.rand() < P and length > 0:
318
+ if mask_num < 8:
319
+ ret_list.append('<mask>')
320
+ mask_num += 1
321
+ i += length
322
+ else:
323
+ ret_list.append(t)
324
+ i += 1
325
+ mask_num = 0
326
+ return [' '.join(ret_list)]
327
+
328
+ model = BartForConditionalGeneration.from_pretrained('GanjinZero/biobart-large')
329
+ model.eval()
330
+ model.to(device)
331
+ tokenizer = AutoTokenizer.from_pretrained('GanjinZero/biobart-large')
332
+
333
+ ret_candidates = {}
334
+ dpath_i = 0
335
+
336
+ for i,(k, v) in enumerate(tqdm(draft.items())):
337
+
338
+ input = v['in'].replace('\n', '')
339
+ output = v['out'].replace('\n', '')
340
+ s, r, o = attack_data[i]
341
+ s = str(s)
342
+ o = str(o)
343
+ r = str(r)
344
+
345
+ if int(s) == -1:
346
+ ret_candidates[str(i)] = {'span': '', 'prompt' : '', 'out' : [], 'in': [], 'assist': []}
347
+ continue
348
+
349
+ path_text = dpath[dpath_i].replace('\n', '')
350
+ dpath_i += 1
351
+ text_s = entity_raw_name[id_to_meshid[s]]
352
+ text_o = entity_raw_name[id_to_meshid[o]]
353
+
354
+ doc = nlp(output)
355
+ words= input.split(' ')
356
+ tokenized_sens = [sen for sen in doc.sents]
357
+ sens = np.array([sen.text for sen in doc.sents])
358
+
359
+ checkset = set([text_s, text_o])
360
+ e_entity = set(['start_entity', 'end_entity'])
361
+ for path in path_text.split(' '):
362
+ a, b, c = path.split('|')
363
+ if a not in e_entity:
364
+ checkset.add(a)
365
+ if c not in e_entity:
366
+ checkset.add(c)
367
+ vec = []
368
+ l = 0
369
+ while(l < len(words)):
370
+ bo = False
371
+ for j in range(len(words), l, -1): # reversing is important !!!
372
+ cc = ' '.join(words[l:j])
373
+ if (cc in checkset):
374
+ vec += [True] * (j-l)
375
+ l = j
376
+ bo = True
377
+ break
378
+ if not bo:
379
+ vec.append(False)
380
+ l += 1
381
+ vec, span = find_mini_span(vec, words, checkset)
382
+ # vec = np.vectorize(lambda x: x in checkset)(words)
383
+ vec[-1] = True
384
+ prompt = []
385
+ mask_num = 0
386
+ for j, bo in enumerate(vec):
387
+ if not bo:
388
+ mask_num += 1
389
+ else:
390
+ if mask_num > 0:
391
+ # mask_num = mask_num // 3 # span length ~ poisson distribution (lambda = 3)
392
+ mask_num = max(mask_num, 1)
393
+ mask_num= min(8, mask_num)
394
+ prompt += ['<mask>'] * mask_num
395
+ prompt.append(words[j])
396
+ mask_num = 0
397
+ prompt = ' '.join(prompt)
398
+ Text = []
399
+ Assist = []
400
+
401
+ for j in range(len(sens)):
402
+ Bart_input = list(sens[:j]) + [prompt] +list(sens[j+1:])
403
+ assist = list(sens[:j]) + [input] +list(sens[j+1:])
404
+ Text.append(' '.join(Bart_input))
405
+ Assist.append(' '.join(assist))
406
+
407
+ for j in range(len(sens)):
408
+ Bart_input = mask_func(tokenized_sens[:j]) + [input] + mask_func(tokenized_sens[j+1:])
409
+ assist = list(sens[:j]) + [input] +list(sens[j+1:])
410
+ Text.append(' '.join(Bart_input))
411
+ Assist.append(' '.join(assist))
412
+
413
+ batch_size = len(Text) // 2
414
+ Outs = []
415
+ for l in range(2):
416
+ A = tokenizer(Text[batch_size * l:batch_size * (l+1)],
417
+ truncation = True,
418
+ padding = True,
419
+ max_length = 1024,
420
+ return_tensors="pt")
421
+ input_ids = A['input_ids'].to(device)
422
+ attention_mask = A['attention_mask'].to(device)
423
+ aaid = model.generate(input_ids, num_beams = 5, max_length = 1024)
424
+ outs = tokenizer.batch_decode(aaid, skip_special_tokens=True, clean_up_tokenization_spaces=False)
425
+ Outs += outs
426
+ ret_candidates[str(i)] = {'span': span, 'prompt' : prompt, 'out' : Outs, 'in': Text, 'assist': Assist}
427
+ with open(f'generate_abstract/bioBART/{args.init_mode}{args.reasonable_rate}{args.ratio}_candidates.json', 'w') as fl:
428
+ json.dump(ret_candidates, fl, indent = 4)
429
+
430
+ from torch.nn.modules.loss import CrossEntropyLoss
431
+ from transformers import BioGptForCausalLM
432
+ criterion = CrossEntropyLoss(reduction="none")
433
+
434
+ tokenizer = AutoTokenizer.from_pretrained('microsoft/biogpt')
435
+ tokenizer.pad_token = tokenizer.eos_token
436
+ model = BioGptForCausalLM.from_pretrained('microsoft/biogpt', pad_token_id=tokenizer.eos_token_id)
437
+ model.to(device)
438
+ model.eval()
439
+
440
+ scored = {}
441
+ ret = {}
442
+ case_study = {}
443
+ p_ret = {}
444
+ add = 0
445
+ dpath_i = 0
446
+ inner_better = 0
447
+ outter_better = 0
448
+ better_than_gpt = 0
449
+ for i,(k, v) in enumerate(tqdm(draft.items())):
450
+
451
+ span = ret_candidates[str(i)]['span']
452
+ prompt = ret_candidates[str(i)]['prompt']
453
+ sen_list = ret_candidates[str(i)]['out']
454
+ BART_in = ret_candidates[str(i)]['in']
455
+ Assist = ret_candidates[str(i)]['assist']
456
+
457
+ s, r, o = attack_data[i]
458
+ s = str(s)
459
+ r = str(r)
460
+ o = str(o)
461
+
462
+ if int(s) == -1:
463
+ ret[k] = {'prompt': '', 'in':'', 'out': ''}
464
+ p_ret[k] = {'prompt': '', 'in':'', 'out': ''}
465
+ continue
466
+
467
+ text_s = entity_raw_name[id_to_meshid[s]]
468
+ text_o = entity_raw_name[id_to_meshid[o]]
469
+
470
+ def process(text):
471
+
472
+ for i in range(ord('A'), ord('Z')+1):
473
+ text = text.replace(f'.{chr(i)}', f'. {chr(i)}')
474
+ return text
475
+
476
+ sen_list = [process(text) for text in sen_list]
477
+ path_text = dpath[dpath_i].replace('\n', '')
478
+ dpath_i += 1
479
+
480
+ checkset = set([text_s, text_o])
481
+ e_entity = set(['start_entity', 'end_entity'])
482
+ for path in path_text.split(' '):
483
+ a, b, c = path.split('|')
484
+ if a not in e_entity:
485
+ checkset.add(a)
486
+ if c not in e_entity:
487
+ checkset.add(c)
488
+
489
+ input = v['in'].replace('\n', '')
490
+ output = v['out'].replace('\n', '')
491
+
492
+ doc = nlp(output)
493
+ gpt_sens = [sen.text for sen in doc.sents]
494
+ assert len(gpt_sens) == len(sen_list) // 2
495
+
496
+ word_sets = []
497
+ for sen in gpt_sens:
498
+ word_sets.append(set(sen.split(' ')))
499
+
500
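+ # sen_align: align the BART rewrite against the original GPT sentences by >80% word overlap and return (l, r1, l, r2), the sentence indices bounding the region that actually changed; (-1, -1, -1, -1) means no differing region was found.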
+ def sen_align(word_sets, modified_word_sets):
501
+
502
+ l = 0
503
+ while(l < len(modified_word_sets)):
504
+ if len(word_sets[l].intersection(modified_word_sets[l])) > len(word_sets[l]) * 0.8:
505
+ l += 1
506
+ else:
507
+ break
508
+ if l == len(modified_word_sets):
509
+ return -1, -1, -1, -1
510
+ r = l + 1
511
+ r1 = None
512
+ r2 = None
513
+ for pos1 in range(r, len(word_sets)):
514
+ for pos2 in range(r, len(modified_word_sets)):
515
+ if len(word_sets[pos1].intersection(modified_word_sets[pos2])) > len(word_sets[pos1]) * 0.8:
516
+ r1 = pos1
517
+ r2 = pos2
518
+ break
519
+ if r1 is not None:
520
+ break
521
+ if r1 is None:
522
+ r1 = len(word_sets)
523
+ r2 = len(modified_word_sets)
524
+ return l, r1, l, r2
525
+
526
+ replace_sen_list = []
527
+ boundary = []
528
+ assert len(sen_list) % 2 == 0
529
+ for j in range(len(sen_list) // 2):
530
+ doc = nlp(sen_list[j])
531
+ sens = [sen.text for sen in doc.sents]
532
+ modified_word_sets = [set(sen.split(' ')) for sen in sens]
533
+ l1, r1, l2, r2 = sen_align(word_sets, modified_word_sets)
534
+ boundary.append((l1, r1, l2, r2))
535
+ if l1 == -1:
536
+ replace_sen_list.append(sen_list[j])
537
+ continue
538
+ check_text = ' '.join(sens[l2: r2])
539
+ replace_sen_list.append(' '.join(gpt_sens[:l1] + [check_text] + gpt_sens[r1:]))
540
+ sen_list = replace_sen_list + sen_list[len(sen_list) // 2:]
541
+
542
+ old_L = len(sen_list)
543
+ sen_list.append(output)
544
+ sen_list += Assist
545
+ tokens = tokenizer( sen_list,
546
+ truncation = True,
547
+ padding = True,
548
+ max_length = 1024,
549
+ return_tensors="pt")
550
+ target_ids = tokens['input_ids'].to(device)
551
+ attention_mask = tokens['attention_mask'].to(device)
552
+ L = len(sen_list)
553
+ ret_log_L = []
554
+ for l in range(0, L, 5):
555
+ R = min(L, l + 5)
556
+ target = target_ids[l:R, :]
557
+ attention = attention_mask[l:R, :]
558
+ outputs = model(input_ids = target,
559
+ attention_mask = attention,
560
+ labels = target)
561
+ logits = outputs.logits
562
+ shift_logits = logits[..., :-1, :].contiguous()
563
+ shift_labels = target[..., 1:].contiguous()
564
+ Loss = criterion(shift_logits.view(-1, shift_logits.shape[-1]), shift_labels.view(-1))
565
+ Loss = Loss.view(-1, shift_logits.shape[1])
566
+ attention = attention[..., 1:].contiguous()
567
+ log_Loss = (torch.mean(Loss * attention.float(), dim = 1) / torch.mean(attention.float(), dim = 1))
568
+ ret_log_L.append(log_Loss.detach())
569
+ log_Loss = torch.cat(ret_log_L, -1).cpu().numpy()
570
+
571
+ real_log_Loss = log_Loss.copy()
572
+
573
+ log_Loss = log_Loss[:old_L]
574
+ # sen_list = sen_list[:old_L]
575
+
576
+ # mini_span should be preserved
577
+ # for j in range(len(log_Loss)):
578
+ # doc = nlp(sen_list[j])
579
+ # sens = [sen.text for sen in doc.sents]
580
+ # Len = len(sen_list)
581
+ # check_text = ' '.join(sens[j : max(0,len(sens) - Len) + j + 1])
582
+ # if span not in check_text:
583
+ # log_Loss[j] += 1
584
+
585
+ p = np.argmin(log_Loss)
586
+ if p < old_L // 2:
587
+ inner_better += 1
588
+ else:
589
+ outter_better += 1
590
+ content = []
591
+ for i in range(len(real_log_Loss)):
592
+ content.append([sen_list[i], str(real_log_Loss[i])])
593
+ scored[k] = {'path':path_text, 'prompt': prompt, 'in':input, 's':text_s, 'o':text_o, 'out': content, 'bound': boundary}
594
+ p_p = p
595
+ # print('Old_L:', old_L)
596
+
597
+ if real_log_Loss[p] > real_log_Loss[p+1+old_L]:
598
+ p_p = p+1+old_L
599
+ if real_log_Loss[p] > real_log_Loss[p+1+old_L]:
600
+ add += 1
601
+
602
+ if real_log_Loss[p] < real_log_Loss[old_L]:
603
+ better_than_gpt += 1
604
+ else:
605
+ if real_log_Loss[p] > real_log_Loss[p+1+old_L]:
606
+ p = p+1+old_L
607
+ # case_study[k] = {'path':path_text, 'entity_0': text_s, 'entity_1': text_o, 'GPT_in': input, 'Prompt': prompt, 'GPT_out': {'text': output, 'perplexity': str(np.exp(real_log_Loss[old_L]))}, 'BART_in': BART_in[p], 'BART_out': {'text': sen_list[p], 'perplexity': str(np.exp(real_log_Loss[p]))}, 'Assist': {'text': Assist[p], 'perplexity': str(np.exp(real_log_Loss[p+1+old_L]))}}
608
+ ret[k] = {'prompt': prompt, 'in':input, 'out': sen_list[p]}
609
+ p_ret[k] = {'prompt': prompt, 'in':input, 'out': sen_list[p_p]}
610
+ print(add)
611
+ print('inner_better:', inner_better)
612
+ print('outter_better:', outter_better)
613
+ print('better_than_gpt:', better_than_gpt)
614
+ print('better_than_replace', add)
615
+ with open(f'generate_abstract/{args.init_mode}{args.reasonable_rate}{args.ratio}_bioBART_finetune.json', 'w') as fl:
616
+ json.dump(ret, fl, indent=4)
617
+ # with open(f'generate_abstract/bioBART/case_{args.target_split}_{args.reasonable_rate}_bioBART_finetune.json', 'w') as fl:
618
+ # json.dump(case_study, fl, indent=4)
619
+ with open(f'generate_abstract/bioBART/{args.init_mode}{args.reasonable_rate}{args.ratio}_scored.json', 'w') as fl:
620
+ json.dump(scored, fl, indent=4)
621
+ with open(f'generate_abstract/bioBART/{args.init_mode}{args.reasonable_rate}{args.ratio}_perplexity.json', 'w') as fl:
622
+ json.dump(p_ret, fl, indent=4)
623
+
624
+ # with open(Parameters.GNBRfile+'original_entity_raw_name', 'rb') as fl:
625
+ # full_entity_raw_name = pkl.load(fl)
626
+ # for k, v in entity_raw_name.items():
627
+ # assert v in full_entity_raw_name[k]
628
+
629
+ # nlp = spacy.load("en_core_web_sm")
630
+ # type_set = set()
631
+ # for aa in range(36):
632
+ # dependency_sen_dict = retieve_sentence_through_edgetype[aa]['manual']
633
+ # tmp_dict = retieve_sentence_through_edgetype[aa]['auto']
634
+ # dependencys = list(dependency_sen_dict.keys()) + list(tmp_dict.keys())
635
+ # for dependency in dependencys:
636
+ # dep_list = dependency.split(' ')
637
+ # for sub_dep in dep_list:
638
+ # sub_dep_list = sub_dep.split('|')
639
+ # assert(len(sub_dep_list) == 3)
640
+ # type_set.add(sub_dep_list[1])
641
+
642
+ # fine_dict = {}
643
+ # for k, v_dict in draft.items():
644
+
645
+ # input = v_dict['in']
646
+ # output = v_dict['out']
647
+ # fine_dict[k] = {'in':input, 'out': input + ' ' + output}
648
+
649
+ # with open(f'generate_abstract/{args.target_split}_{args.reasonable_rate}_sentence_finetune.json', 'w') as fl:
650
+ # json.dump(fine_dict, fl, indent=4)
651
+ else:
652
+ raise Exception('Wrong mode !!')
DiseaseAgnostic/evaluation.py ADDED
@@ -0,0 +1,219 @@
1
+ #%%
2
+ import logging
3
+ from symbol import parameters
4
+ from textwrap import indent
5
+ import os
6
+ import tempfile
7
+ import sys
8
+ from matplotlib import collections
9
+ import pandas as pd
10
+ import json
11
+ from glob import glob
12
+ from tqdm import tqdm
13
+ import numpy as np
14
+ from pprint import pprint
15
+ import torch
16
+ import pickle as pkl
17
+ from collections import Counter
18
+ # print(dir(collections))
19
+ import networkx as nx
20
+ from collections import Counter
21
+ import utils
22
+ from torch.nn import functional as F
23
+ sys.path.append("..")
24
+ import Parameters
25
+ from DiseaseSpecific.attack import calculate_edge_bound, get_model_loss_without_softmax
26
+
27
+ #%%
28
+ def load_data(file_name):
29
+ df = pd.read_csv(file_name, sep='\t', header=None, names=None, dtype=str)
30
+ df = df.drop_duplicates()
31
+ return df.values
32
+
33
+ parser = utils.get_argument_parser()
34
+ parser.add_argument('--reasonable-rate', type = float, default=0.7, help = 'The added edge\'s existence rank probability must be greater than this rate')
35
+ parser.add_argument('--mode', type = str, default='', help = ' "" or chat or bioBART')
36
+ parser.add_argument('--init-mode', type = str, default='random', help = 'How to select target nodes') # 'single' for case study
37
+ parser.add_argument('--added-edge-num', type = str, default = '', help = 'Added edge num')
38
+
39
+ args = parser.parse_args()
40
+ args = utils.set_hyperparams(args)
41
+ utils.seed_all(args.seed)
42
+ graph_edge_path = '../DiseaseSpecific/processed_data/GNBR/all.txt'
43
+ idtomeshid_path = '../DiseaseSpecific/processed_data/GNBR/entities_reverse_dict.json'
44
+ model_path = f'../DiseaseSpecific/saved_models/GNBR_{args.model}_128_0.2_0.3_0.3.model'
45
+ data_path = '../DiseaseSpecific/processed_data/GNBR'
46
+ target_path = f'processed_data/target_{args.reasonable_rate}{args.init_mode}.pkl'
47
+ attack_path = f'processed_data/attack_edge_{args.model}_{args.reasonable_rate}{args.init_mode}{args.added_edge_num}{args.mode}.pkl'
48
+
49
+ with open(Parameters.GNBRfile+'original_entity_raw_name', 'rb') as fl:
50
+ full_entity_raw_name = pkl.load(fl)
51
+
52
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
53
+
54
+ # device = torch.device("cpu")
55
+
56
+ args.device = device
57
+ n_ent, n_rel, ent_to_id, rel_to_id = utils.generate_dicts(data_path)
58
+ model = utils.load_model(model_path, args, n_ent, n_rel, args.device)
59
+
60
+ graph_edge = utils.load_data(graph_edge_path)
61
+ with open(idtomeshid_path, 'r') as fl:
62
+ idtomeshid = json.load(fl)
63
+ print(graph_edge.shape, len(idtomeshid))
64
+
65
+ divide_bound, data_mean, data_std = calculate_edge_bound(graph_edge, model, args.device, n_ent)
66
+ print('Defender ...')
67
+ print(divide_bound, data_mean, data_std)
68
+
69
+ meshids = list(idtomeshid.values())
70
+ cal = {
71
+ 'chemical' : 0,
72
+ 'disease' : 0,
73
+ 'gene' : 0
74
+ }
75
+ for meshid in meshids:
76
+ cal[meshid.split('_')[0]] += 1
77
+ # pprint(cal)
78
+
79
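+ # check_reasonable: convert the KGE model's loss for a candidate triple into a plausibility probability via a sigmoid around the learned divide_bound (after normalising with data_mean / data_std); the edge is accepted only if that probability exceeds 1 - reasonable_rate.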
+ def check_reasonable(s, r, o):
80
+
81
+ train_trip = np.asarray([[s, r, o]])
82
+ train_trip = torch.from_numpy(train_trip.astype('int64')).to(device)
83
+ edge_loss = get_model_loss_without_softmax(train_trip, model, device).squeeze()
84
+ # edge_losse_log_prob = torch.log(F.softmax(-edge_loss, dim = -1))
85
+
86
+ edge_loss = edge_loss.item()
87
+ edge_loss = (edge_loss - data_mean) / data_std
88
+ edge_losses_prob = 1 / ( 1 + np.exp(edge_loss - divide_bound) )
89
+ bound = 1 - args.reasonable_rate
90
+
91
+ return (edge_losses_prob > bound), edge_losses_prob
92
+
93
+ edgeid_to_edgetype = {}
94
+ edgeid_to_reversemask = {}
95
+ for k, id_list in Parameters.edge_type_to_id.items():
96
+ for iid, mask in zip(id_list, Parameters.reverse_mask[k]):
97
+ edgeid_to_edgetype[str(iid)] = k
98
+ edgeid_to_reversemask[str(iid)] = mask
99
+
100
+ with open(target_path, 'rb') as fl:
101
+ Target_node_list = pkl.load(fl)
102
+ with open(attack_path, 'rb') as fl:
103
+ Attack_edge_list = pkl.load(fl)
104
+ with open(Parameters.UMLSfile+'drug_term', 'rb') as fl:
105
+ drug_term = pkl.load(fl)
106
+ with open(Parameters.GNBRfile+'entity_raw_name', 'rb') as fl:
107
+ entity_raw_name = pkl.load(fl)
108
+ drug_meshid = []
109
+ for meshid, nm in entity_raw_name.items():
110
+ if nm.lower() in drug_term and meshid.split('_')[0] == 'chemical':
111
+ drug_meshid.append(meshid)
112
+ drug_meshid = set(drug_meshid)
113
+
114
+ if args.init_mode == 'single':
115
+ name_list = []
116
+ for target in Target_node_list:
117
+ name = entity_raw_name[idtomeshid[str(target)]]
118
+ name_list.append(name)
119
+ with open(f'results/name_list_{args.reasonable_rate}{args.init_mode}.txt', 'w') as fl:
120
+ fl.write('\n'.join(name_list))
121
+ # print(Target_node_list)
122
+ # # print(Attack_edge_list)
123
+ # addset = set()
124
+ # if args.added_edge_num == 1:
125
+ # for edge in Attack_edge_list:
126
+ # addset.add(edge[2])
127
+ # else:
128
+ # for edge_list in Attack_edge_list:
129
+ # for edge in edge_list:
130
+ # addset.add(edge[2])
131
+ # print(addset)
132
+ # print(len(addset))
133
+ # typeset = set()
134
+ # for iid in addset:
135
+ # typeset.add(idtomeshid[str(iid)].split('_')[0])
136
+ # print(typeset)
137
+ # raise Exception('done')
138
+
139
+ if args.init_mode == 'single':
140
+ Target_node_list = [[Target_node_list[i]] for i in range(len(Target_node_list))]
141
+ Attack_edge_list = [[Attack_edge_list[i]] for i in range(len(Attack_edge_list))]
142
+ else:
143
+ print(len(Attack_edge_list), len(Target_node_list))
144
+ tmp_target_node_list = []
145
+ tmp_attack_edge_list = []
146
+ for l in range(0,len(Target_node_list), 50):
147
+ r = min(l+50, len(Target_node_list))
148
+ tmp_target_node_list.append(Target_node_list[l:r])
149
+ tmp_attack_edge_list.append(Attack_edge_list[l:r])
150
+ Target_node_list = tmp_target_node_list
151
+ Attack_edge_list = tmp_attack_edge_list
152
+
153
+ # for i, init_p in enumerate([0.1, 0.3, 0.5, 0.7, 0.9]):
154
+
155
+ # target_node_list = Target_node_list[i]
156
+ # attack_edge_list = Attack_edge_list[i]
157
+ Init = []
158
+ After = []
159
+ # final_init = []
160
+ # final_after = []
161
+ for i, (target_node_list, attack_edge_list) in enumerate(zip(Target_node_list, Attack_edge_list)):
162
+
163
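+ # Evaluation: build the directed GNBR graph (edge direction flipped where reverse_mask is set), record each target drug's PageRank-based rank among chemical nodes before the attack, then add the plausible attack edges and recompute the rank.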
+ G = nx.DiGraph()
164
+ for s, r, o in graph_edge:
165
+ assert idtomeshid[s].split('_')[0] == edgeid_to_edgetype[r].split('-')[0]
166
+ if edgeid_to_reversemask[r] == 1:
167
+ G.add_edge(int(o), int(s))
168
+ else:
169
+ G.add_edge(int(s), int(o))
170
+
171
+ pagerank_value_1 = nx.pagerank(G, max_iter = 200, tol=1.0e-7)
172
+
173
+ for target, attack_list in tqdm(list(zip(target_node_list, attack_edge_list))):
174
+ pr = list(pagerank_value_1.items())
175
+ pr.sort(key = lambda x: x[1])
176
+ list_iid = []
177
+ for iid, score in pr:
178
+ tp = idtomeshid[str(iid)].split('_')[0]
179
+ if tp == 'chemical':
180
+ # if idtomeshid[str(iid)] in drug_meshid:
181
+ list_iid.append(iid)
182
+ init_rank = len(list_iid) - list_iid.index(target)
183
+ # init_rank = 1 - list_iid.index(target) / len(list_iid)
184
+ Init.append(init_rank)
185
+
186
+ for target, attack_list in tqdm(list(zip(target_node_list, attack_edge_list))):
187
+
188
+ if args.mode == '' and (args.added_edge_num == '' or int(args.added_edge_num) == 1):
189
+ if int(attack_list[0]) == -1:
190
+ attack_list = []
191
+ else:
192
+ attack_list = [attack_list]
193
+ if len(attack_list) > 0:
194
+ for s, r, o in attack_list:
195
+ bo, prob = check_reasonable(s, r, o)
196
+ if bo:
197
+ if edgeid_to_reversemask[str(r)] == 1:
198
+ G.add_edge(int(o), int(s))
199
+ else:
200
+ G.add_edge(int(s), int(o))
201
+ pagerank_value_1 = nx.pagerank(G, max_iter = 200, tol=1.0e-7)
202
+ for target, attack_list in tqdm(list(zip(target_node_list, attack_edge_list))):
203
+ pr = list(pagerank_value_1.items())
204
+ pr.sort(key = lambda x: x[1])
205
+ list_iid = []
206
+ for iid, score in pr:
207
+ tp = idtomeshid[str(iid)].split('_')[0]
208
+ if tp == 'chemical':
209
+ # if idtomeshid[str(iid)] in drug_meshid:
210
+ list_iid.append(iid)
211
+ after_rank = len(list_iid) - list_iid.index(target)
212
+ # after_rank = 1 - list_iid.index(target) / len(list_iid)
213
+ After.append(after_rank)
214
+ with open(f'results/Init_{args.reasonable_rate}{args.init_mode}.pkl', 'wb') as fl:
215
+ pkl.dump(Init, fl)
216
+ with open(f'results/After_{args.model}_{args.reasonable_rate}{args.init_mode}{args.added_edge_num}{args.mode}.pkl', 'wb') as fl:
217
+ pkl.dump(After, fl)
218
+ print(np.mean(Init), np.std(Init))
219
+ print(np.mean(After), np.std(After))
DiseaseAgnostic/generate_target_and_attack.py ADDED
@@ -0,0 +1,371 @@
1
+ #%%
2
+ import logging
3
+ from symbol import parameters
4
+ from textwrap import indent
5
+ import os
6
+ import tempfile
7
+ import sys
8
+ from matplotlib import collections
9
+ import pandas as pd
10
+ import json
11
+ from glob import glob
12
+ from tqdm import tqdm
13
+ import numpy as np
14
+ from pprint import pprint
15
+ import torch
16
+ import pickle as pkl
17
+ from collections import Counter
18
+ # print(dir(collections))
19
+ import networkx as nx
20
+ from collections import Counter
21
+ import utils
22
+ from torch.nn import functional as F
23
+ sys.path.append("..")
24
+ import Parameters
25
+ from DiseaseSpecific.attack import calculate_edge_bound, get_model_loss_without_softmax
26
+
27
+ #%%
28
+ def load_data(file_name):
29
+ df = pd.read_csv(file_name, sep='\t', header=None, names=None, dtype=str)
30
+ df = df.drop_duplicates()
31
+ return df.values
32
+
33
+ parser = utils.get_argument_parser()
34
+ parser.add_argument('--reasonable-rate', type = float, default=0.7, help = 'The added edge\'s existence rank probability must be greater than this rate')
35
+ parser.add_argument('--init-mode', type = str, default='single', help = 'How to select target nodes') # 'single' for case study
36
+ parser.add_argument('--added-edge-num', type = str, default = '', help = 'Added edge num')
37
+
38
+ args = parser.parse_args()
39
+ args = utils.set_hyperparams(args)
40
+ utils.seed_all(args.seed)
41
+ graph_edge_path = '../DiseaseSpecific/processed_data/GNBR/all.txt'
42
+ idtomeshid_path = '../DiseaseSpecific/processed_data/GNBR/entities_reverse_dict.json'
43
+ model_path = f'../DiseaseSpecific/saved_models/GNBR_{args.model}_128_0.2_0.3_0.3.model'
44
+ data_path = '../DiseaseSpecific/processed_data/GNBR'
45
+ with open(Parameters.GNBRfile+'original_entity_raw_name', 'rb') as fl:
46
+ full_entity_raw_name = pkl.load(fl)
47
+
48
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
49
+ args.device = device
50
+ n_ent, n_rel, ent_to_id, rel_to_id = utils.generate_dicts(data_path)
51
+ model = utils.load_model(model_path, args, n_ent, n_rel, args.device)
52
+ print(device)
53
+
54
+ graph_edge = utils.load_data(graph_edge_path)
55
+ with open(idtomeshid_path, 'r') as fl:
56
+ idtomeshid = json.load(fl)
57
+ print(graph_edge.shape, len(idtomeshid))
58
+
59
+ divide_bound, data_mean, data_std = calculate_edge_bound(graph_edge, model, args.device, n_ent)
60
+ print('Defender ...')
61
+ print(divide_bound, data_mean, data_std)
62
+
63
+ meshids = list(idtomeshid.values())
64
+ cal = {
65
+ 'chemical' : 0,
66
+ 'disease' : 0,
67
+ 'gene' : 0
68
+ }
69
+ for meshid in meshids:
70
+ cal[meshid.split('_')[0]] += 1
71
+ # pprint(cal)
72
+
73
+ def check_reasonable(s, r, o):
74
+
75
+ train_trip = np.asarray([[s, r, o]])
76
+ train_trip = torch.from_numpy(train_trip.astype('int64')).to(device)
77
+ edge_loss = get_model_loss_without_softmax(train_trip, model, device).squeeze()
78
+ # edge_losse_log_prob = torch.log(F.softmax(-edge_loss, dim = -1))
79
+
80
+ edge_loss = edge_loss.item()
81
+ edge_loss = (edge_loss - data_mean) / data_std
82
+ edge_losses_prob = 1 / ( 1 + np.exp(edge_loss - divide_bound) )
83
+ bound = 1 - args.reasonable_rate
84
+
85
+ return (edge_losses_prob > bound), edge_losses_prob
86
+
87
+ edgeid_to_edgetype = {}
88
+ edgeid_to_reversemask = {}
89
+ for k, id_list in Parameters.edge_type_to_id.items():
90
+ for iid, mask in zip(id_list, Parameters.reverse_mask[k]):
91
+ edgeid_to_edgetype[str(iid)] = k
92
+ edgeid_to_reversemask[str(iid)] = mask
93
+ reverse_tot = 0
94
+ G = nx.DiGraph()
95
+ for s, r, o in graph_edge:
96
+ assert idtomeshid[s].split('_')[0] == edgeid_to_edgetype[r].split('-')[0]
97
+ if edgeid_to_reversemask[r] == 1:
98
+ reverse_tot += 1
99
+ G.add_edge(int(o), int(s))
100
+ else:
101
+ G.add_edge(int(s), int(o))
102
+ # print(reverse_tot)
103
+ print('Edge num:', G.number_of_edges(), 'Node num:', G.number_of_nodes())
104
+ pagerank_value_1 = nx.pagerank(G, max_iter = 200, tol=1.0e-7)
105
+
106
+ #%%
107
+ with open(Parameters.UMLSfile+'drug_term', 'rb') as fl:
108
+ drug_term = pkl.load(fl)
109
+ with open(Parameters.GNBRfile+'entity_raw_name', 'rb') as fl:
110
+ entity_raw_name = pkl.load(fl)
111
+ drug_meshid = []
112
+ for meshid, nm in entity_raw_name.items():
113
+ if nm.lower() in drug_term and meshid.split('_')[0] == 'chemical':
114
+ drug_meshid.append(meshid)
115
+ drug_meshid = set(drug_meshid)
116
+ pr = list(pagerank_value_1.items())
117
+ pr.sort(key = lambda x: x[1])
118
+ sorted_rank = { 'chemical' : [],
119
+ 'gene' : [],
120
+ 'disease': [],
121
+ 'merged' : []}
122
+ for iid, score in pr:
123
+ tp = idtomeshid[str(iid)].split('_')[0]
124
+ if tp == 'chemical':
125
+ if idtomeshid[str(iid)] in drug_meshid:
126
+ sorted_rank[tp].append((iid, score))
127
+ else:
128
+ sorted_rank[tp].append((iid, score))
129
+ sorted_rank['merged'].append((iid, score))
130
+ llen = len(sorted_rank['merged'])
131
+ sorted_rank['merged'] = sorted_rank['merged'][llen * 3 // 4 : ]
132
+ print(len(sorted_rank['chemical']))
133
+ print(len(sorted_rank['gene']), len(sorted_rank['disease']), len(sorted_rank['merged']))
134
+
135
+ #%%
136
+ Target_node_list = []
137
+ Attack_edge_list = []
138
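+ # Two target-selection modes: with init_mode == '' targets are sampled around fixed PageRank
+ # percentiles of the drug list; 'random' draws 400 drugs, and 'single' re-selects the 10
+ # case-study drugs from a previous 'random' run.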
+ if args.init_mode == '':
139
+
140
+ if args.added_edge_num != '' and args.added_edge_num != '1':
141
+ raise Exception("added_edge_num must be 1 when init_mode is ''")
142
+ for init_p in [0.1, 0.3, 0.5, 0.7, 0.9]:
143
+
144
+ p = len(sorted_rank['chemical']) * init_p
145
+ print('Init p:', init_p)
146
+ target_node_list = []
147
+ attack_edge_list = []
148
+ num_max_eq = 0
149
+ mean_rank_of_total_max = 0
150
+ for pp in tqdm(range(int(p)-10, int(p)+10)):
151
+ target = sorted_rank['chemical'][pp][0]
152
+ target_node_list.append(target)
153
+
154
+ candidate_list = []
155
+ score_list = []
156
+ loss_list = []
157
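+ # For every high-PageRank gene/disease not yet linked to the target, try all type-compatible
+ # relations, keep the least-loss one that passes the defender check, and record a structural
+ # score (PageRank weighted by 1/out-degree) together with its KGE loss.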
+ for iid, score in sorted_rank['merged']:
158
+ a = G.number_of_edges(iid, target) + 1
159
+ if a != 1:
160
+ continue
161
+ b = G.out_degree(iid) + 1
162
+ tp = idtomeshid[str(iid)].split('_')[0]
163
+ edge_losses = []
164
+ r_list = []
165
+ for r in range(len(edgeid_to_edgetype)):
166
+ r_tp = edgeid_to_edgetype[str(r)]
167
+ if (edgeid_to_reversemask[str(r)] == 0 and r_tp.split('-')[0] == tp and r_tp.split('-')[1] == 'chemical'):
168
+ train_trip = np.array([[iid, r, target]])
169
+ train_trip = torch.from_numpy(train_trip.astype('int64')).to(device)
170
+ edge_loss = get_model_loss_without_softmax(train_trip, model, device).squeeze()
171
+ edge_losses.append(edge_loss.unsqueeze(0).detach())
172
+ r_list.append(r)
173
+ elif(edgeid_to_reversemask[str(r)] == 1 and r_tp.split('-')[0] == 'chemical' and r_tp.split('-')[1] == tp):
174
+ train_trip = np.array([[iid, r, target]]) # add batch dim
175
+ train_trip = torch.from_numpy(train_trip.astype('int64')).to(device)
176
+ edge_loss = get_model_loss_without_softmax(train_trip, model, device).squeeze()
177
+ edge_losses.append(edge_loss.unsqueeze(0).detach())
178
+ r_list.append(r)
179
+ if len(edge_losses)==0:
180
+ continue
181
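+ # Keep the relation the embedding model finds most plausible (lowest loss) for this pair.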
+ min_index = torch.argmin(torch.cat(edge_losses, dim = 0))
182
+ r = r_list[min_index]
183
+ r_tp = edgeid_to_edgetype[str(r)]
184
+
185
+ if (edgeid_to_reversemask[str(r)] == 0):
186
+ bo, prob = check_reasonable(iid, r, target)
187
+ if bo:
188
+ candidate_list.append((iid, r, target))
189
+ score_list.append(score * a / b)
190
+ loss_list.append(edge_losses[min_index].item())
191
+ if (edgeid_to_reversemask[str(r)] == 1):
192
+ bo, prob = check_reasonable(target, r, iid)
193
+ if bo:
194
+ candidate_list.append((target, r, iid))
195
+ score_list.append(score * a / b)
196
+ loss_list.append(edge_losses[min_index].item())
197
+
198
+ if len(candidate_list) == 0:
199
+ attack_edge_list.append((-1, -1, -1))
200
+ continue
201
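+ # Combine the normalised structural scores with a softmax over negative KGE losses;
+ # the candidate edge maximising the product is the one added for this target.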
+ norm_score = np.array(score_list) / np.sum(score_list)
202
+ norm_loss = np.exp(-np.array(loss_list)) / np.sum(np.exp(-np.array(loss_list)))
203
+
204
+ total_score = norm_score * norm_loss
205
+ max_index = np.argmax(total_score)
206
+ attack_edge_list.append(candidate_list[max_index])
207
+
208
+ score_max_index = np.argmax(norm_score)
209
+ if score_max_index == max_index:
210
+ num_max_eq += 1
211
+
212
+ score_index_list = list(zip(list(range(len(norm_score))), norm_score))
213
+ score_index_list.sort(key = lambda x: x[1], reverse = True)
214
+ max_index_in_score = score_index_list.index((max_index, norm_score[max_index]))
215
+ mean_rank_of_total_max += max_index_in_score / len(norm_score)
216
+ print('num_max_eq:', num_max_eq)
217
+ print('mean_rank_of_total_max:', mean_rank_of_total_max / 20)
218
+ Target_node_list.append(target_node_list)
219
+ Attack_edge_list.append(attack_edge_list)
220
+ else:
221
+ assert args.init_mode == 'random' or args.init_mode == 'single'
222
+ print(f'Init mode : {args.init_mode}')
223
+ utils.seed_all(args.seed)
224
+
225
+ if args.init_mode == 'random':
226
+ index = np.random.choice(len(sorted_rank['chemical']), 400, replace = False)
227
+ else:
228
+ # index = [5807, 6314, 5799, 5831, 3954, 5654, 5649, 5624, 2412, 2407]
229
+
230
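+ # Case study: re-draw the same 400 random drugs (same seed as the 'random' run) and keep the
+ # 10 with the largest relative change between the saved Init and After result files.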
+ index = np.random.choice(len(sorted_rank['chemical']), 400, replace = False)
231
+ with open(f'../pagerank/results/After_distmult_0.7random10.pkl', 'rb') as fl:
232
+ edge = pkl.load(fl)
233
+ with open('../pagerank/results/Init_0.7random.pkl', 'rb') as fl:
234
+ init = pkl.load(fl)
235
+ increase = (np.array(init) - np.array(edge)) / np.array(init)
236
+ increase = increase.reshape(-1)
237
+ selected_index = np.argsort(increase)[::-1][:10]
238
+ # print(selected_index)
239
+ # print(increase[selected_index])
240
+ # print(np.array(init)[selected_index])
241
+ # print(np.array(edge)[selected_index])
242
+ index = [index[i] for i in selected_index]
243
+ # llen = len(sorted_rank['chemical'])
244
+ # index = np.random.choice(range(llen//4, llen), 4, replace = False)
245
+ # index = selected_index + list(index)
246
+ # for i in index:
247
+ # ii = str(sorted_rank['chemical'][i][0])
248
+ # nm = entity_raw_name[idtomeshid[ii]]
249
+ # nmset = full_entity_raw_name[idtomeshid[ii]]
250
+ # print('**'*10)
251
+ # print(i)
252
+ # print(nm)
253
+ # print(nmset)
254
+ # raise Exception('stop')
255
+ target_node_list = []
256
+ attack_edge_list = []
257
+ num_max_eq = 0
258
+ mean_rank_of_total_max = 0
259
+
260
+ for pp in tqdm(index):
261
+ target = sorted_rank['chemical'][pp][0]
262
+ target_node_list.append(target)
263
+
264
+ print('Target:', entity_raw_name[idtomeshid[str(target)]])
265
+
266
+ candidate_list = []
267
+ score_list = []
268
+ loss_list = []
269
+ main_dict = {}
270
+ for iid, score in sorted_rank['merged']:
271
+ a = G.number_of_edges(iid, target) + 1
272
+ if a != 1:
273
+ continue
274
+ b = G.out_degree(iid) + 1
275
+ tp = idtomeshid[str(iid)].split('_')[0]
276
+ edge_losses = []
277
+ r_list = []
278
+ for r in range(len(edgeid_to_edgetype)):
279
+ r_tp = edgeid_to_edgetype[str(r)]
280
+ if (edgeid_to_reversemask[str(r)] == 0 and r_tp.split('-')[0] == tp and r_tp.split('-')[1] == 'chemical'):
281
+ train_trip = np.array([[iid, r, target]])
282
+ train_trip = torch.from_numpy(train_trip.astype('int64')).to(device)
283
+ edge_loss = get_model_loss_without_softmax(train_trip, model, device).squeeze()
284
+ edge_losses.append(edge_loss.unsqueeze(0).detach())
285
+ r_list.append(r)
286
+ elif(edgeid_to_reversemask[str(r)] == 1 and r_tp.split('-')[0] == 'chemical' and r_tp.split('-')[1] == tp):
287
+ train_trip = np.array([[iid, r, target]]) # add batch dim
288
+ train_trip = torch.from_numpy(train_trip.astype('int64')).to(device)
289
+ edge_loss = get_model_loss_without_softmax(train_trip, model, device).squeeze()
290
+ edge_losses.append(edge_loss.unsqueeze(0).detach())
291
+ r_list.append(r)
292
+ if len(edge_losses)==0:
293
+ continue
294
+ min_index = torch.argmin(torch.cat(edge_losses, dim = 0))
295
+ r = r_list[min_index]
296
+ r_tp = edgeid_to_edgetype[str(r)]
297
+
298
+
299
+ old_len = len(candidate_list)
300
+ if (edgeid_to_reversemask[str(r)] == 0):
301
+ bo, prob = check_reasonable(iid, r, target)
302
+ if bo:
303
+ candidate_list.append((iid, r, target))
304
+ score_list.append(score * a / b)
305
+ loss_list.append(edge_losses[min_index].item())
306
+ if (edgeid_to_reversemask[str(r)] == 1):
307
+ bo, prob = check_reasonable(target, r, iid)
308
+ if bo:
309
+ candidate_list.append((target, r, iid))
310
+ score_list.append(score * a / b)
311
+ loss_list.append(edge_losses[min_index].item())
312
+
313
+ if len(candidate_list) != old_len:
314
+ if int(iid) in main_iid:
315
+ main_dict[iid] = len(candidate_list) - 1
316
+
317
+ if len(candidate_list) == 0:
318
+ if args.added_edge_num == '' or int(args.added_edge_num) == 1:
319
+ attack_edge_list.append((-1,-1,-1))
320
+ else:
321
+ attack_edge_list.append([])
322
+ continue
323
+ norm_score = np.array(score_list) / np.sum(score_list)
324
+ norm_loss = np.exp(-np.array(loss_list)) / np.sum(np.exp(-np.array(loss_list)))
325
+
326
+ total_score = norm_score * norm_loss
327
+ total_score_index = list(zip(range(len(total_score)), total_score))
328
+ total_score_index.sort(key = lambda x: x[1], reverse = True)
329
+
330
+ norm_score_index = np.argsort(norm_score)[::-1]
331
+ norm_loss_index = np.argsort(norm_loss)[::-1]
332
+ total_index = np.argsort(total_score)[::-1]
333
+ assert total_index[0] == total_score_index[0][0]
334
+ # find rank of main index
335
+ for k, v in main_dict.items():
336
+ k = int(k)
337
+ index = v
338
+ print(f'score rank of {entity_raw_name[idtomeshid[str(k)]]}: ', norm_score_index.tolist().index(index))
339
+ print(f'loss rank of {entity_raw_name[idtomeshid[str(k)]]}: ', norm_loss_index.tolist().index(index))
340
+ print(f'total rank of {entity_raw_name[idtomeshid[str(k)]]}: ', total_index.tolist().index(index))
341
+
342
+ max_index = np.argmax(total_score)
343
+ assert max_index == total_score_index[0][0]
344
+
345
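+ # With --added-edge-num > 1, keep the top-k candidates by combined score instead of only the best one.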
+ tmp_add = []
346
+ add_num = 1
347
+ if args.added_edge_num == '' or int(args.added_edge_num) == 1:
348
+ attack_edge_list.append(candidate_list[max_index])
349
+ else:
350
+ add_num = int(args.added_edge_num)
351
+ for i in range(add_num):
352
+ tmp_add.append(candidate_list[total_score_index[i][0]])
353
+ attack_edge_list.append(tmp_add)
354
+
355
+ score_max_index = np.argmax(norm_score)
356
+ if score_max_index == max_index:
357
+ num_max_eq += 1
358
+ score_index_list = list(zip(list(range(len(norm_score))), norm_score))
359
+ score_index_list.sort(key = lambda x: x[1], reverse = True)
360
+ max_index_in_score = score_index_list.index((max_index, norm_score[max_index]))
361
+ mean_rank_of_total_max += max_index_in_score / len(norm_score)
362
+ print('num_max_eq:', num_max_eq)
363
+ print('mean_rank_of_total_max:', mean_rank_of_total_max / 400)
364
+ Target_node_list = target_node_list
365
+ Attack_edge_list = attack_edge_list
366
+ print(np.array(Target_node_list).shape)
367
+ print(np.array(Attack_edge_list).shape)
368
+ # with open(f'processed_data/target_{args.reasonable_rate}{args.init_mode}.pkl', 'wb') as fl:
369
+ # pkl.dump(Target_node_list, fl)
370
+ # with open(f'processed_data/attack_edge_{args.model}_{args.reasonable_rate}{args.init_mode}{args.added_edge_num}.pkl', 'wb') as fl:
371
+ # pkl.dump(Attack_edge_list, fl)
DiseaseAgnostic/model.py ADDED
@@ -0,0 +1,520 @@
1
+ import torch
2
+ from torch.nn import functional as F, Parameter
3
+ from torch.autograd import Variable
4
+ from torch.nn.init import xavier_normal_, xavier_uniform_
5
+ from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
6
+
7
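+ # DistMult scores a triple (s, r, o) as sum_k e_s[k] * w_r[k] * e_o[k];
+ # forward() scores an (s, r) pair against every entity at once for ranking.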
+ class Distmult(torch.nn.Module):
8
+ def __init__(self, args, num_entities, num_relations):
9
+ super(Distmult, self).__init__()
10
+
11
+ if args.max_norm:
12
+ self.emb_e = torch.nn.Embedding(num_entities, args.embedding_dim, max_norm=1.0)
13
+ self.emb_rel = torch.nn.Embedding(num_relations, args.embedding_dim)
14
+ else:
15
+ self.emb_e = torch.nn.Embedding(num_entities, args.embedding_dim, padding_idx=None)
16
+ self.emb_rel = torch.nn.Embedding(num_relations, args.embedding_dim, padding_idx=None)
17
+
18
+ self.inp_drop = torch.nn.Dropout(args.input_drop)
19
+ self.loss = torch.nn.CrossEntropyLoss()
20
+
21
+ self.init()
22
+
23
+ def init(self):
24
+ xavier_normal_(self.emb_e.weight)
25
+ xavier_normal_(self.emb_rel.weight)
26
+
27
+ def score_sr(self, sub, rel, sigmoid = False):
28
+ sub_emb = self.emb_e(sub).squeeze(dim=1)
29
+ rel_emb = self.emb_rel(rel).squeeze(dim=1)
30
+
31
+ #sub_emb = self.inp_drop(sub_emb)
32
+ #rel_emb = self.inp_drop(rel_emb)
33
+
34
+ pred = torch.mm(sub_emb*rel_emb, self.emb_e.weight.transpose(1,0))
35
+ if sigmoid:
36
+ pred = torch.sigmoid(pred)
37
+ return pred
38
+
39
+ def score_or(self, obj, rel, sigmoid = False):
40
+ obj_emb = self.emb_e(obj).squeeze(dim=1)
41
+ rel_emb = self.emb_rel(rel).squeeze(dim=1)
42
+
43
+ #obj_emb = self.inp_drop(obj_emb)
44
+ #rel_emb = self.inp_drop(rel_emb)
45
+
46
+ pred = torch.mm(obj_emb*rel_emb, self.emb_e.weight.transpose(1,0))
47
+ if sigmoid:
48
+ pred = torch.sigmoid(pred)
49
+ return pred
50
+
51
+
52
+ def forward(self, sub_emb, rel_emb, mode='rhs', sigmoid=False):
53
+ '''
54
+ When mode is 'rhs' we expect (s,r); for 'lhs', we expect (o,r)
55
+ For distmult, computations for both modes are equivalent, so we do not need if-else block
56
+ '''
57
+ sub_emb = self.inp_drop(sub_emb)
58
+ rel_emb = self.inp_drop(rel_emb)
59
+
60
+ pred = torch.mm(sub_emb*rel_emb, self.emb_e.weight.transpose(1,0))
61
+
62
+ if sigmoid:
63
+ pred = torch.sigmoid(pred)
64
+
65
+ return pred
66
+
67
+ def score_triples(self, sub, rel, obj, sigmoid=False):
68
+ '''
69
+ Inputs - subject, relation, object
70
+ Return - score
71
+ '''
72
+ sub_emb = self.emb_e(sub).squeeze(dim=1)
73
+ rel_emb = self.emb_rel(rel).squeeze(dim=1)
74
+ obj_emb = self.emb_e(obj).squeeze(dim=1)
75
+
76
+ pred = torch.sum(sub_emb*rel_emb*obj_emb, dim=-1)
77
+
78
+ if sigmoid:
79
+ pred = torch.sigmoid(pred)
80
+
81
+ return pred
82
+
83
+ def score_emb(self, emb_s, emb_r, emb_o, sigmoid=False):
84
+ '''
85
+ Inputs - embeddings of subject, relation, object
86
+ Return - score
87
+ '''
88
+ pred = torch.sum(emb_s*emb_r*emb_o, dim=-1)
89
+
90
+ if sigmoid:
91
+ pred = torch.sigmoid(pred)
92
+
93
+ return pred
94
+
95
+ def score_triples_vec(self, sub, rel, obj, sigmoid=False):
96
+ '''
97
+ Inputs - subject, relation, object
98
+ Return - a vector score for the triple instead of reducing over the embedding dimension
99
+ '''
100
+ sub_emb = self.emb_e(sub).squeeze(dim=1)
101
+ rel_emb = self.emb_rel(rel).squeeze(dim=1)
102
+ obj_emb = self.emb_e(obj).squeeze(dim=1)
103
+
104
+ pred = sub_emb*rel_emb*obj_emb
105
+
106
+ if sigmoid:
107
+ pred = torch.sigmoid(pred)
108
+
109
+ return pred
110
+
111
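+ # ComplEx splits each embedding into real and imaginary halves and scores a triple as
+ # Re(<e_s, w_r, conj(e_o)>), which lets it model asymmetric relations.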
+ class Complex(torch.nn.Module):
112
+ def __init__(self, args, num_entities, num_relations):
113
+ super(Complex, self).__init__()
114
+
115
+ if args.max_norm:
116
+ self.emb_e = torch.nn.Embedding(num_entities, 2*args.embedding_dim, max_norm=1.0)
117
+ self.emb_rel = torch.nn.Embedding(num_relations, 2*args.embedding_dim)
118
+ else:
119
+ self.emb_e = torch.nn.Embedding(num_entities, 2*args.embedding_dim, padding_idx=None)
120
+ self.emb_rel = torch.nn.Embedding(num_relations, 2*args.embedding_dim, padding_idx=None)
121
+
122
+ self.inp_drop = torch.nn.Dropout(args.input_drop)
123
+ self.loss = torch.nn.CrossEntropyLoss()
124
+
125
+ self.init()
126
+
127
+ def init(self):
128
+ xavier_normal_(self.emb_e.weight)
129
+ xavier_normal_(self.emb_rel.weight)
130
+
131
+ def score_sr(self, sub, rel, sigmoid = False):
132
+ sub_emb = self.emb_e(sub).squeeze(dim=1)
133
+ rel_emb = self.emb_rel(rel).squeeze(dim=1)
134
+
135
+ s_real, s_img = torch.chunk(rel_emb, 2, dim=-1)
136
+ rel_real, rel_img = torch.chunk(sub_emb, 2, dim=-1)
137
+ emb_e_real, emb_e_img = torch.chunk(self.emb_e.weight, 2, dim=-1)
138
+
139
+ #s_real = self.inp_drop(s_real)
140
+ #s_img = self.inp_drop(s_img)
141
+ #rel_real = self.inp_drop(rel_real)
142
+ #rel_img = self.inp_drop(rel_img)
143
+
144
+ # complex space bilinear product (equivalent to HolE)
145
+ # realrealreal = torch.mm(s_real*rel_real, emb_e_real.transpose(1,0))
146
+ # realimgimg = torch.mm(s_real*rel_img, emb_e_img.transpose(1,0))
147
+ # imgrealimg = torch.mm(s_img*rel_real, emb_e_img.transpose(1,0))
148
+ # imgimgreal = torch.mm(s_img*rel_img, emb_e_real.transpose(1,0))
149
+ # pred = realrealreal + realimgimg + imgrealimg - imgimgreal
150
+
151
+ realo_realreal = s_real*rel_real
152
+ realo_imgimg = s_img*rel_img
153
+ realo = realo_realreal - realo_imgimg
154
+ real = torch.mm(realo, emb_e_real.transpose(1,0))
155
+
156
+ imgo_realimg = s_real*rel_img
157
+ imgo_imgreal = s_img*rel_real
158
+ imgo = imgo_realimg + imgo_imgreal
159
+ img = torch.mm(imgo, emb_e_img.transpose(1,0))
160
+
161
+ pred = real + img
162
+
163
+ if sigmoid:
164
+ pred = torch.sigmoid(pred)
165
+ return pred
166
+
167
+
168
+ def score_or(self, obj, rel, sigmoid = False):
169
+ obj_emb = self.emb_e(obj).squeeze(dim=1)
170
+ rel_emb = self.emb_rel(rel).squeeze(dim=1)
171
+
172
+ rel_real, rel_img = torch.chunk(rel_emb, 2, dim=-1)
173
+ o_real, o_img = torch.chunk(obj_emb, 2, dim=-1)
174
+ emb_e_real, emb_e_img = torch.chunk(self.emb_e.weight, 2, dim=-1)
175
+
176
+ #rel_real = self.inp_drop(rel_real)
177
+ #rel_img = self.inp_drop(rel_img)
178
+ #o_real = self.inp_drop(o_real)
179
+ #o_img = self.inp_drop(o_img)
180
+
181
+ # complex space bilinear product (equivalent to HolE)
182
+ # realrealreal = torch.mm(rel_real*o_real, emb_e_real.transpose(1,0))
183
+ # realimgimg = torch.mm(rel_img*o_img, emb_e_real.transpose(1,0))
184
+ # imgrealimg = torch.mm(rel_real*o_img, emb_e_img.transpose(1,0))
185
+ # imgimgreal = torch.mm(rel_img*o_real, emb_e_img.transpose(1,0))
186
+ # pred = realrealreal + realimgimg + imgrealimg - imgimgreal
187
+
188
+ reals_realreal = rel_real*o_real
189
+ reals_imgimg = rel_img*o_img
190
+ reals = reals_realreal + reals_imgimg
191
+ real = torch.mm(reals, emb_e_real.transpose(1,0))
192
+
193
+ imgs_realimg = rel_real*o_img
194
+ imgs_imgreal = rel_img*o_real
195
+ imgs = imgs_realimg - imgs_imgreal
196
+ img = torch.mm(imgs, emb_e_img.transpose(1,0))
197
+
198
+ pred = real + img
199
+
200
+ if sigmoid:
201
+ pred = torch.sigmoid(pred)
202
+ return pred
203
+
204
+
205
+ def forward(self, sub_emb, rel_emb, mode='rhs', sigmoid=False):
206
+ '''
207
+ When mode is 'rhs' we expect (s,r); for 'lhs', we expect (o,r)
208
+
209
+ '''
210
+ if mode == 'lhs':
211
+ rel_real, rel_img = torch.chunk(rel_emb, 2, dim=-1)
212
+ o_real, o_img = torch.chunk(sub_emb, 2, dim=-1)
213
+ emb_e_real, emb_e_img = torch.chunk(self.emb_e.weight, 2, dim=-1)
214
+
215
+ rel_real = self.inp_drop(rel_real)
216
+ rel_img = self.inp_drop(rel_img)
217
+ o_real = self.inp_drop(o_real)
218
+ o_img = self.inp_drop(o_img)
219
+
220
+ # complex space bilinear product (equivalent to HolE)
221
+ # realrealreal = torch.mm(rel_real*o_real, emb_e_real.transpose(1,0))
222
+ # realimgimg = torch.mm(rel_img*o_img, emb_e_real.transpose(1,0))
223
+ # imgrealimg = torch.mm(rel_real*o_img, emb_e_img.transpose(1,0))
224
+ # imgimgreal = torch.mm(rel_img*o_real, emb_e_img.transpose(1,0))
225
+ # pred = realrealreal + realimgimg + imgrealimg - imgimgreal
226
+ reals_realreal = rel_real*o_real
227
+ reals_imgimg = rel_img*o_img
228
+ reals = reals_realreal + reals_imgimg
229
+ real = torch.mm(reals, emb_e_real.transpose(1,0))
230
+
231
+ imgs_realimg = rel_real*o_img
232
+ imgs_imgreal = rel_img*o_real
233
+ imgs = imgs_realimg - imgs_imgreal
234
+ img = torch.mm(imgs, emb_e_img.transpose(1,0))
235
+
236
+ pred = real + img
237
+
238
+ else:
239
+ s_real, s_img = torch.chunk(rel_emb, 2, dim=-1)
240
+ rel_real, rel_img = torch.chunk(sub_emb, 2, dim=-1)
241
+ emb_e_real, emb_e_img = torch.chunk(self.emb_e.weight, 2, dim=-1)
242
+
243
+ s_real = self.inp_drop(s_real)
244
+ s_img = self.inp_drop(s_img)
245
+ rel_real = self.inp_drop(rel_real)
246
+ rel_img = self.inp_drop(rel_img)
247
+
248
+ # complex space bilinear product (equivalent to HolE)
249
+ # realrealreal = torch.mm(s_real*rel_real, emb_e_real.transpose(1,0))
250
+ # realimgimg = torch.mm(s_real*rel_img, emb_e_img.transpose(1,0))
251
+ # imgrealimg = torch.mm(s_img*rel_real, emb_e_img.transpose(1,0))
252
+ # imgimgreal = torch.mm(s_img*rel_img, emb_e_real.transpose(1,0))
253
+ # pred = realrealreal + realimgimg + imgrealimg - imgimgreal
254
+
255
+ realo_realreal = s_real*rel_real
256
+ realo_imgimg = s_img*rel_img
257
+ realo = realo_realreal - realo_imgimg
258
+ real = torch.mm(realo, emb_e_real.transpose(1,0))
259
+
260
+ imgo_realimg = s_real*rel_img
261
+ imgo_imgreal = s_img*rel_real
262
+ imgo = imgo_realimg + imgo_imgreal
263
+ img = torch.mm(imgo, emb_e_img.transpose(1,0))
264
+
265
+ pred = real + img
266
+
267
+ if sigmoid:
268
+ pred = torch.sigmoid(pred)
269
+
270
+ return pred
271
+
272
+ def score_triples(self, sub, rel, obj, sigmoid=False):
273
+ '''
274
+ Inputs - subject, relation, object
275
+ Return - score
276
+ '''
277
+ sub_emb = self.emb_e(sub).squeeze(dim=1)
278
+ rel_emb = self.emb_rel(rel).squeeze(dim=1)
279
+ obj_emb = self.emb_e(obj).squeeze(dim=1)
280
+
281
+ s_real, s_img = torch.chunk(sub_emb, 2, dim=-1)
282
+ rel_real, rel_img = torch.chunk(rel_emb, 2, dim=-1)
283
+ o_real, o_img = torch.chunk(obj_emb, 2, dim=-1)
284
+
285
+ realrealreal = torch.sum(s_real*rel_real*o_real, dim=-1)
286
+ realimgimg = torch.sum(s_real*rel_img*o_img, axis=-1)
287
+ imgrealimg = torch.sum(s_img*rel_real*o_img, axis=-1)
288
+ imgimgreal = torch.sum(s_img*rel_img*o_real, axis=-1)
289
+
290
+ pred = realrealreal + realimgimg + imgrealimg - imgimgreal
291
+
292
+ if sigmoid:
293
+ pred = torch.sigmoid(pred)
294
+
295
+ return pred
296
+
297
+ def score_emb(self, emb_s, emb_r, emb_o, sigmoid=False):
298
+ '''
299
+ Inputs - embeddings of subject, relation, object
300
+ Return - score
301
+ '''
302
+
303
+ s_real, s_img = torch.chunk(emb_s, 2, dim=-1)
304
+ rel_real, rel_img = torch.chunk(emb_r, 2, dim=-1)
305
+ o_real, o_img = torch.chunk(emb_o, 2, dim=-1)
306
+
307
+ realrealreal = torch.sum(s_real*rel_real*o_real, dim=-1)
308
+ realimgimg = torch.sum(s_real*rel_img*o_img, axis=-1)
309
+ imgrealimg = torch.sum(s_img*rel_real*o_img, axis=-1)
310
+ imgimgreal = torch.sum(s_img*rel_img*o_real, axis=-1)
311
+
312
+ pred = realrealreal + realimgimg + imgrealimg - imgimgreal
313
+
314
+ if sigmoid:
315
+ pred = torch.sigmoid(pred)
316
+
317
+ return pred
318
+
319
+ def score_triples_vec(self, sub, rel, obj, sigmoid=False):
320
+ '''
321
+ Inputs - subject, relation, object
322
+ Return - a vector score for the triple instead of reducing over the embedding dimension
323
+ '''
324
+ sub_emb = self.emb_e(sub).squeeze(dim=1)
325
+ rel_emb = self.emb_rel(rel).squeeze(dim=1)
326
+ obj_emb = self.emb_e(obj).squeeze(dim=1)
327
+
328
+ s_real, s_img = torch.chunk(sub_emb, 2, dim=-1)
329
+ rel_real, rel_img = torch.chunk(rel_emb, 2, dim=-1)
330
+ o_real, o_img = torch.chunk(obj_emb, 2, dim=-1)
331
+
332
+ realrealreal = s_real*rel_real*o_real
333
+ realimgimg = s_real*rel_img*o_img
334
+ imgrealimg = s_img*rel_real*o_img
335
+ imgimgreal = s_img*rel_img*o_real
336
+
337
+ pred = realrealreal + realimgimg + imgrealimg - imgimgreal
338
+
339
+ if sigmoid:
340
+ pred = torch.sigmoid(pred)
341
+
342
+ return pred
343
+
344
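+ # ConvE reshapes the subject and relation embeddings into a 2D stack, applies a 2D convolution
+ # and a fully connected projection, and scores the result against the entity embeddings.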
+ class Conve(torch.nn.Module):
345
+
346
+ #Too slow !!!!
347
+
348
+ def __init__(self, args, num_entities, num_relations):
349
+ super(Conve, self).__init__()
350
+
351
+ if args.max_norm:
352
+ self.emb_e = torch.nn.Embedding(num_entities, args.embedding_dim, max_norm=1.0)
353
+ self.emb_rel = torch.nn.Embedding(num_relations, args.embedding_dim)
354
+ else:
355
+ self.emb_e = torch.nn.Embedding(num_entities, args.embedding_dim, padding_idx=None)
356
+ self.emb_rel = torch.nn.Embedding(num_relations, args.embedding_dim, padding_idx=None)
357
+
358
+ self.inp_drop = torch.nn.Dropout(args.input_drop)
359
+ self.hidden_drop = torch.nn.Dropout(args.hidden_drop)
360
+ self.feature_drop = torch.nn.Dropout2d(args.feat_drop)
361
+
362
+ self.embedding_dim = args.embedding_dim #default is 200
363
+ self.num_filters = args.num_filters # default is 32
364
+ self.kernel_size = args.kernel_size # default is 3
365
+ self.stack_width = args.stack_width # default is 20
366
+ self.stack_height = args.embedding_dim // self.stack_width
367
+
368
+ self.bn0 = torch.nn.BatchNorm2d(1)
369
+ self.bn1 = torch.nn.BatchNorm2d(self.num_filters)
370
+ self.bn2 = torch.nn.BatchNorm1d(args.embedding_dim)
371
+
372
+ self.conv1 = torch.nn.Conv2d(1, out_channels=self.num_filters,
373
+ kernel_size=(self.kernel_size, self.kernel_size),
374
+ stride=1, padding=0, bias=args.use_bias)
375
+ #self.conv1 = torch.nn.Conv2d(1, 32, (3, 3), 1, 0, bias=args.use_bias) # <-- default
376
+
377
+ flat_sz_h = int(2*self.stack_width) - self.kernel_size + 1
378
+ flat_sz_w = self.stack_height - self.kernel_size + 1
379
+ self.flat_sz = flat_sz_h*flat_sz_w*self.num_filters
380
+ self.fc = torch.nn.Linear(self.flat_sz, args.embedding_dim)
381
+
382
+ self.register_parameter('b', Parameter(torch.zeros(num_entities)))
383
+ self.loss = torch.nn.CrossEntropyLoss()
384
+
385
+ self.init()
386
+
387
+ def init(self):
388
+ xavier_normal_(self.emb_e.weight)
389
+ xavier_normal_(self.emb_rel.weight)
390
+
391
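+ # 'plain' stacks the two reshaped embeddings vertically into one 2D map;
+ # 'alternate' interleaves their rows before the convolution.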
+ def concat(self, e1_embed, rel_embed, form='plain'):
392
+ if form == 'plain':
393
+ e1_embed = e1_embed.view(-1, 1, self.stack_width, self.stack_height)
394
+ rel_embed = rel_embed.view(-1, 1, self.stack_width, self.stack_height)
395
+ stack_inp = torch.cat([e1_embed, rel_embed], 2)
396
+
397
+ elif form == 'alternate':
398
+ e1_embed = e1_embed.view(-1, 1, self.embedding_dim)
399
+ rel_embed = rel_embed.view(-1, 1, self.embedding_dim)
400
+ stack_inp = torch.cat([e1_embed, rel_embed], 1)
401
+ stack_inp = torch.transpose(stack_inp, 2, 1).reshape((-1, 1, 2*self.stack_width, self.stack_height))
402
+
403
+ else: raise NotImplementedError
404
+ return stack_inp
405
+
406
+ def conve_architecture(self, sub_emb, rel_emb):
407
+ stacked_inputs = self.concat(sub_emb, rel_emb)
408
+ stacked_inputs = self.bn0(stacked_inputs)
409
+ x = self.inp_drop(stacked_inputs)
410
+ x = self.conv1(x)
411
+ x = self.bn1(x)
412
+ x = F.relu(x)
413
+ x = self.feature_drop(x)
414
+ #x = x.view(x.shape[0], -1)
415
+ x = x.view(-1, self.flat_sz)
416
+ x = self.fc(x)
417
+ x = self.hidden_drop(x)
418
+ x = self.bn2(x)
419
+ x = F.relu(x)
420
+
421
+ return x
422
+
423
+ def score_sr(self, sub, rel, sigmoid = False):
424
+ sub_emb = self.emb_e(sub)
425
+ rel_emb = self.emb_rel(rel)
426
+
427
+ x = self.conve_architecture(sub_emb, rel_emb)
428
+
429
+ pred = torch.mm(x, self.emb_e.weight.transpose(1,0))
430
+ pred += self.b.expand_as(pred)
431
+
432
+ if sigmoid:
433
+ pred = torch.sigmoid(pred)
434
+ return pred
435
+
436
+ def score_or(self, obj, rel, sigmoid = False):
437
+ obj_emb = self.emb_e(obj)
438
+ rel_emb = self.emb_rel(rel)
439
+
440
+ x = self.conve_architecture(obj_emb, rel_emb)
441
+ pred = torch.mm(x, self.emb_e.weight.transpose(1,0))
442
+ pred += self.b.expand_as(pred)
443
+
444
+ if sigmoid:
445
+ pred = torch.sigmoid(pred)
446
+ return pred
447
+
448
+
449
+ def forward(self, sub_emb, rel_emb, mode='rhs', sigmoid=False):
450
+ '''
451
+ When mode is 'rhs' we expect (s,r); for 'lhs', we expect (o,r)
452
+ For conve, computations for both modes are equivalent, so we do not need if-else block
453
+ '''
454
+ x = self.conve_architecture(sub_emb, rel_emb)
455
+
456
+ pred = torch.mm(x, self.emb_e.weight.transpose(1,0))
457
+ pred += self.b.expand_as(pred)
458
+
459
+ if sigmoid:
460
+ pred = torch.sigmoid(pred)
461
+
462
+ return pred
463
+
464
+ def score_triples(self, sub, rel, obj, sigmoid=False):
465
+ '''
466
+ Inputs - subject, relation, object
467
+ Return - score
468
+ '''
469
+ sub_emb = self.emb_e(sub)
470
+ rel_emb = self.emb_rel(rel)
471
+ obj_emb = self.emb_e(obj)
472
+ x = self.conve_architecture(sub_emb, rel_emb)
473
+
474
+ pred = torch.mm(x, obj_emb.transpose(1,0))
475
+ #print(pred.shape)
476
+ pred += self.b[obj].expand_as(pred) #taking the bias value for object embedding
477
+ # above works fine for single input triples;
478
+ # but if input is batch of triples, then this is a matrix of (num_trip x num_trip) where diagonal is scores
479
+ # so use torch.diagonal() after calling this function
480
+ pred = torch.diagonal(pred)
481
+ # or could have used : pred= torch.sum(x*obj_emb, dim=-1)
482
+
483
+ if sigmoid:
484
+ pred = torch.sigmoid(pred)
485
+
486
+ return pred
487
+
488
+ def score_emb(self, emb_s, emb_r, emb_o, sigmoid=False):
489
+ '''
490
+ Inputs - embeddings of subject, relation, object
491
+ Return - score
492
+ '''
493
+ x = self.conve_architecture(emb_s, emb_r)
494
+
495
+ pred = torch.mm(x, emb_o.transpose(1,0))
496
+
497
+ pred = torch.diagonal(pred)
498
+
499
+ if sigmoid:
500
+ pred = torch.sigmoid(pred)
501
+
502
+ return pred
503
+
504
+ def score_triples_vec(self, sub, rel, obj, sigmoid=False):
505
+ '''
506
+ Inputs - subject, relation, object
507
+ Return - a vector score for the triple instead of reducing over the embedding dimension
508
+ '''
509
+ sub_emb = self.emb_e(sub)
510
+ rel_emb = self.emb_rel(rel)
511
+ obj_emb = self.emb_e(obj)
512
+
513
+ x = self.conve_architecture(sub_emb, rel_emb)
514
+
515
+ pred = x*obj_emb
516
+
517
+ if sigmoid:
518
+ pred = torch.sigmoid(pred)
519
+
520
+ return pred
DiseaseAgnostic/utils.py ADDED
@@ -0,0 +1,187 @@
1
+ '''
2
+ A file modified on https://github.com/PeruBhardwaj/AttributionAttack/blob/main/KGEAttack/ConvE/utils.py
3
+ '''
4
+ #%%
5
+ import logging
6
+ import time
7
+ from tqdm import tqdm
8
+ import io
9
+ import pandas as pd
10
+ import numpy as np
11
+ import os
12
+ import json
13
+
14
+ import argparse
15
+ import torch
16
+ import random
17
+
18
+ from yaml import parse
19
+
20
+ from model import Conve, Distmult, Complex
21
+
22
+ logger = logging.getLogger(__name__)
23
+ #%%
24
+ def generate_dicts(data_path):
25
+ with open (os.path.join(data_path, 'entities_dict.json'), 'r') as f:
26
+ ent_to_id = json.load(f)
27
+ with open (os.path.join(data_path, 'relations_dict.json'), 'r') as f:
28
+ rel_to_id = json.load(f)
29
+ n_ent = len(list(ent_to_id.keys()))
30
+ n_rel = len(list(rel_to_id.keys()))
31
+
32
+ return n_ent, n_rel, ent_to_id, rel_to_id
33
+
34
+ def save_data(file_name, data):
35
+ with open(file_name, 'w') as fl:
36
+ for item in data:
37
+ fl.write("%s\n" % "\t".join(map(str, item)))
38
+
39
+ def load_data(file_name):
40
+ df = pd.read_csv(file_name, sep='\t', header=None, names=None, dtype=str)
41
+ df = df.drop_duplicates()
42
+ return df.values
43
+
44
+ def seed_all(seed=1):
45
+ random.seed(seed)
46
+ np.random.seed(seed)
47
+ torch.manual_seed(seed)
48
+ torch.cuda.manual_seed_all(seed)
49
+ os.environ['PYTHONHASHSEED'] = str(seed)
50
+ torch.backends.cudnn.deterministic = True
51
+
52
+ def add_model(args, n_ent, n_rel):
53
+ if args.model is None:
54
+ model = Distmult(args, n_ent, n_rel)
55
+ elif args.model == 'distmult':
56
+ model = Distmult(args, n_ent, n_rel)
57
+ elif args.model == 'complex':
58
+ model = Complex(args, n_ent, n_rel)
59
+ elif args.model == 'conve':
60
+ model = Conve(args, n_ent, n_rel)
61
+ else:
62
+ raise Exception("Unknown model!")
63
+
64
+ return model
65
+
66
+ def load_model(model_path, args, n_ent, n_rel, device):
67
+ # add a model and load the pre-trained params
68
+ model = add_model(args, n_ent, n_rel)
69
+ model.to(device)
70
+ logger.info('Loading saved model from {0}'.format(model_path))
71
+ state = torch.load(model_path)
72
+ model_params = state['state_dict']
73
+ params = [(key, value.size(), value.numel()) for key, value in model_params.items()]
74
+ for key, size, count in params:
75
+ logger.info('Key:{0}, Size:{1}, Count:{2}'.format(key, size, count))
76
+
77
+ model.load_state_dict(model_params)
78
+ model.eval()
79
+ logger.info(model)
80
+
81
+ return model
82
+
83
+ def add_eval_parameters(parser):
84
+
85
+ parser.add_argument('--eval-mode', type = str, default = 'all', help = 'Method to evaluate the attack performance. Default: all. (all or single)')
86
+ parser.add_argument('--cuda-name', type = str, required = True, help = 'Start a main thread on each CUDA device.')
87
+ parser.add_argument('--direct', action='store_true', help = 'Directly add edge or not.')
88
+ parser.add_argument('--seperate', action='store_true', help = 'Evaluate separately or not')
89
+ return parser
90
+
91
+ def add_attack_parameters(parser):
92
+
93
+ # parser.add_argument('--target-split', type=str, default='0_100_1', help='Ranks to use for target set. Values are 0 for ranks==1; 1 for ranks <=10; 2 for ranks>10 and ranks<=100. Default: 1')
94
+ parser.add_argument('--target-split', type=str, default='min', help='Methods for target triple selection. Default: min. (min or top_?, top means top_0.1)')
95
+ parser.add_argument('--target-size', type=int, default=50, help='Number of target triples. Default: 50')
96
+ parser.add_argument('--target-existed', action='store_true', help='Whether the targeted s_?_o already exists.')
97
+
98
+ # parser.add_argument('--budget', type=int, default=1, help='Budget for each target triple for each corruption side')
99
+
100
+ parser.add_argument('--attack-goal', type = str, default='single', help='Attack goal. Default: single. (single or global)')
101
+ parser.add_argument('--neighbor-num', type = int, default=20, help='Max neighbor num for each side. Default: 20')
102
+ parser.add_argument('--candidate-mode', type = str, default='quadratic', help = 'The method to generate candidate edge. Default: quadratic. (quadratic or linear)')
103
+ parser.add_argument('--reasonable-rate', type = float, default=0.7, help = 'The added edge\'s existence rank probability must be greater than this rate')
104
+ # parser.add_argument('--neighbor-num', type = int, default=200, help='Max neighbor num for each side. Default: 200')
105
+ # parser.add_argument('--candidate-mode', type = str, default='linear', help = 'The method to generate candidate edge. Default: quadratic. (quadratic or linear)')
106
+ parser.add_argument('--attack-batch-size', type=int, default=256, help='Batch size for processing neighbours of target')
107
+ parser.add_argument('--template-mode', type=str, default = 'manual', help = 'Template mode for transforming an edge into a single sentence. Default: manual. (manual or auto)')
108
+
109
+ parser.add_argument('--update-lissa', action='store_true', help = 'Update lissa cache or not.')
110
+
111
+ parser.add_argument('--GPT-batch-size', type=int, default = 64, help = 'Batch size for GPT2 when calculating LM score. Default: 64')
112
+ parser.add_argument('--LM-softmax', action='store_true', help = 'Use a softmax head on LM prob or not.')
113
+ parser.add_argument('--LMprob-mode', type=str, default='relative', help = 'Use the absolute LM score or the destruction score when the target word is replaced. Default: relative. (absolute or relative)')
114
+
115
+ return parser
116
+
117
+ def get_argument_parser():
118
+ '''Generate an argument parser'''
119
+ parser = argparse.ArgumentParser(description='Graph embedding')
120
+
121
+ parser.add_argument('--seed', type=int, default=1, metavar='S', help='Random seed (default: 1)')
122
+
123
+ parser.add_argument('--data', type=str, default='GNBR', help='Dataset to use: { GNBR }')
124
+ parser.add_argument('--model', type=str, default='distmult', help='Choose from: {distmult, complex, transe, conve}')
125
+
126
+ parser.add_argument('--transe-margin', type=float, default=0.0, help='Margin value for TransE scoring function. Default:0.0')
127
+ parser.add_argument('--transe-norm', type=int, default=2, help='P-norm value for TransE scoring function. Default:2')
128
+
129
+ parser.add_argument('--epochs', type=int, default=100, help='Number of epochs to train (default: 100)')
130
+ parser.add_argument('--lr', type=float, default=0.001, help='Learning rate (default: 0.001)')
131
+ parser.add_argument('--lr-decay', type=float, default=0.0, help='Weight decay value to use in the optimizer. Default: 0.0')
132
+ parser.add_argument('--max-norm', action='store_true', help='Option to add unit max norm constraint to entity embeddings')
133
+
134
+ parser.add_argument('--train-batch-size', type=int, default=64, help='Batch size for train split (default: 64)')
135
+ parser.add_argument('--test-batch-size', type=int, default=128, help='Batch size for test split (default: 128)')
136
+ parser.add_argument('--valid-batch-size', type=int, default=128, help='Batch size for valid split (default: 128)')
137
+ parser.add_argument('--KG-valid-rate', type = float, default=0.1, help='Validation rate during KG embedding training. (default: 0.1)')
138
+
139
+ parser.add_argument('--save-influence-map', action='store_true', help='Save the influence map during training for gradient rollback.')
140
+ parser.add_argument('--add-reciprocals', action='store_true')
141
+
142
+ parser.add_argument('--embedding-dim', type=int, default=128, help='The embedding dimension (1D). Default: 128')
143
+ parser.add_argument('--stack-width', type=int, default=16, help='The first dimension of the reshaped/stacked 2D embedding. Second dimension is inferred. Default: 16')
144
+ #parser.add_argument('--stack_height', type=int, default=10, help='The second dimension of the reshaped/stacked 2D embedding. Default: 10')
145
+ parser.add_argument('--hidden-drop', type=float, default=0.3, help='Dropout for the hidden layer. Default: 0.3.')
146
+ parser.add_argument('--input-drop', type=float, default=0.2, help='Dropout for the input embeddings. Default: 0.2.')
147
+ parser.add_argument('--feat-drop', type=float, default=0.3, help='Dropout for the convolutional features. Default: 0.3.')
148
+ parser.add_argument('-num-filters', default=32, type=int, help='Number of filters for convolution')
149
+ parser.add_argument('-kernel-size', default=3, type=int, help='Kernel Size for convolution')
150
+
151
+ parser.add_argument('--use-bias', action='store_true', help='Use a bias in the convolutional layer. Default: False')
152
+
153
+ parser.add_argument('--reg-weight', type=float, default=5e-2, help='Weight for regularization. Default: 5e-2')
154
+ parser.add_argument('--reg-norm', type=int, default=3, help='Norm for regularization. Default: 3')
155
+ # parser.add_argument('--resume', action='store_true', help='Restore a saved model.')
156
+ # parser.add_argument('--resume-split', type=str, default='test', help='Split to evaluate a restored model')
157
+ # parser.add_argument('--reproduce-results', action='store_true', help='Use the hyperparameters to reproduce the results.')
158
+ # parser.add_argument('--original-data', type=str, default='FB15k-237', help='Dataset to use; this option is needed to set the hyperparams to reproduce the results for training after attack, default: FB15k-237')
159
+ return parser
160
+
161
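+ # Overrides the CLI defaults with the per-model hyperparameters used for the GNBR embeddings,
+ # and fixes the LiSSA-related constants (damping, depth, scale, batch size).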
+ def set_hyperparams(args):
162
+ if args.model == 'distmult':
163
+ args.lr = 0.005
164
+ args.train_batch_size = 1024
165
+ args.reg_norm = 3
166
+ elif args.model == 'complex':
167
+ args.lr = 0.005
168
+ args.reg_norm = 3
169
+ args.input_drop = 0.4
170
+ args.train_batch_size = 1024
171
+ elif args.model == 'conve':
172
+ args.lr = 0.005
173
+ args.train_batch_size = 1024
174
+ args.reg_weight = 0.0
175
+
176
+ # args.damping = 0.01
177
+ # args.lissa_repeat = 1
178
+ # args.lissa_depth = 1
179
+ # args.scale = 500
180
+ # args.lissa_batch_size = 100
181
+
182
+ args.damping = 0.01
183
+ args.lissa_repeat = 1
184
+ args.lissa_depth = 1
185
+ args.scale = 400
186
+ args.lissa_batch_size = 300
187
+ return args