yjwtheonly committed
Commit 2d06d0e (1 parent: 6ecb301)

Modification

DiseaseSpecific/__pycache__/attack.cpython-38.pyc CHANGED
Binary files a/DiseaseSpecific/__pycache__/attack.cpython-38.pyc and b/DiseaseSpecific/__pycache__/attack.cpython-38.pyc differ
 
DiseaseSpecific/__pycache__/model.cpython-38.pyc CHANGED
Binary files a/DiseaseSpecific/__pycache__/model.cpython-38.pyc and b/DiseaseSpecific/__pycache__/model.cpython-38.pyc differ
 
DiseaseSpecific/__pycache__/utils.cpython-38.pyc CHANGED
Binary files a/DiseaseSpecific/__pycache__/utils.cpython-38.pyc and b/DiseaseSpecific/__pycache__/utils.cpython-38.pyc differ
 
Openai/__pycache__/chat.cpython-38.pyc CHANGED
Binary files a/Openai/__pycache__/chat.cpython-38.pyc and b/Openai/__pycache__/chat.cpython-38.pyc differ
 
Parameters.py CHANGED
@@ -1,9 +1,9 @@
  from audioop import reverse
 
- GNBRfile = '../GNBRdata/'
- PubTatorfile = '../pubtator/'
- UMLSfile = '../umls/META/'
- Pubmedfile = '../pubmed/'
+ GNBRfile = 'GNBRdata/'
+ PubTatorfile = 'pubtator/'
+ UMLSfile = 'umls/META/'
+ Pubmedfile = 'pubmed/'
 
  edge_type_dict = {
      'chemical-gene':(['A+', 'A-', 'B', 'E+', 'E-', 'E', 'N', 'O', 'K', 'Z'],
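
Note on the path change above: the data directories are now resolved relative to the working directory instead of its parent, so scripts must be launched from the repository root. A minimal sanity check, assuming the directory layout implied by Parameters.py (this sketch is not part of the commit):

    # Hedged sketch: verify the data directories resolve from the current
    # working directory, i.e. the repository root after this commit.
    import os
    import Parameters

    for path in (Parameters.GNBRfile, Parameters.PubTatorfile,
                 Parameters.UMLSfile, Parameters.Pubmedfile):
        assert os.path.isdir(path), f"launch from the repo root; missing {path}"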
__pycache__/Parameters.cpython-38.pyc ADDED
Binary file (3.23 kB).
 
__pycache__/model.cpython-38.pyc ADDED
Binary file (11.4 kB).
 
__pycache__/server.cpython-38.pyc ADDED
Binary file (18.8 kB).
 
__pycache__/utils.cpython-38.pyc ADDED
Binary file (7.81 kB).
 
model.py ADDED
@@ -0,0 +1,504 @@
+ import torch
+ from torch.nn import functional as F, Parameter
+ from torch.autograd import Variable
+ from torch.nn.init import xavier_normal_, xavier_uniform_
+ from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
+
+ class Distmult(torch.nn.Module):
+     def __init__(self, args, num_entities, num_relations):
+         super(Distmult, self).__init__()
+
+         if args.max_norm:
+             self.emb_e = torch.nn.Embedding(num_entities, args.embedding_dim, max_norm=1.0)
+             self.emb_rel = torch.nn.Embedding(num_relations, args.embedding_dim)
+         else:
+             self.emb_e = torch.nn.Embedding(num_entities, args.embedding_dim, padding_idx=None)
+             self.emb_rel = torch.nn.Embedding(num_relations, args.embedding_dim, padding_idx=None)
+
+         self.inp_drop = torch.nn.Dropout(args.input_drop)
+         self.loss = torch.nn.CrossEntropyLoss()
+
+         self.init()
+
+     def init(self):
+         xavier_normal_(self.emb_e.weight)
+         xavier_normal_(self.emb_rel.weight)
+
+     def score_sr(self, sub, rel, sigmoid=False):
+         sub_emb = self.emb_e(sub).squeeze(dim=1)
+         rel_emb = self.emb_rel(rel).squeeze(dim=1)
+
+         #sub_emb = self.inp_drop(sub_emb)
+         #rel_emb = self.inp_drop(rel_emb)
+
+         pred = torch.mm(sub_emb*rel_emb, self.emb_e.weight.transpose(1,0))
+         if sigmoid:
+             pred = torch.sigmoid(pred)
+         return pred
+
+     def score_or(self, obj, rel, sigmoid=False):
+         obj_emb = self.emb_e(obj).squeeze(dim=1)
+         rel_emb = self.emb_rel(rel).squeeze(dim=1)
+
+         #obj_emb = self.inp_drop(obj_emb)
+         #rel_emb = self.inp_drop(rel_emb)
+
+         pred = torch.mm(obj_emb*rel_emb, self.emb_e.weight.transpose(1,0))
+         if sigmoid:
+             pred = torch.sigmoid(pred)
+         return pred
+
+     def forward(self, sub_emb, rel_emb, mode='rhs', sigmoid=False):
+         '''
+         When mode is 'rhs' we expect (s,r); for 'lhs', we expect (o,r).
+         For DistMult the computation is identical in both modes, so no if-else block is needed.
+         '''
+         sub_emb = self.inp_drop(sub_emb)
+         rel_emb = self.inp_drop(rel_emb)
+
+         pred = torch.mm(sub_emb*rel_emb, self.emb_e.weight.transpose(1,0))
+
+         if sigmoid:
+             pred = torch.sigmoid(pred)
+
+         return pred
+
+     def score_triples(self, sub, rel, obj, sigmoid=False):
+         '''
+         Inputs - subject, relation, object
+         Return - score
+         '''
+         sub_emb = self.emb_e(sub).squeeze(dim=1)
+         rel_emb = self.emb_rel(rel).squeeze(dim=1)
+         obj_emb = self.emb_e(obj).squeeze(dim=1)
+
+         pred = torch.sum(sub_emb*rel_emb*obj_emb, dim=-1)
+
+         if sigmoid:
+             pred = torch.sigmoid(pred)
+
+         return pred
+
+     def score_emb(self, emb_s, emb_r, emb_o, sigmoid=False):
+         '''
+         Inputs - embeddings of subject, relation, object
+         Return - score
+         '''
+         pred = torch.sum(emb_s*emb_r*emb_o, dim=-1)
+
+         if sigmoid:
+             pred = torch.sigmoid(pred)
+
+         return pred
+
+     def score_triples_vec(self, sub, rel, obj, sigmoid=False):
+         '''
+         Inputs - subject, relation, object
+         Return - a vector score for the triple instead of reducing over the embedding dimension
+         '''
+         sub_emb = self.emb_e(sub).squeeze(dim=1)
+         rel_emb = self.emb_rel(rel).squeeze(dim=1)
+         obj_emb = self.emb_e(obj).squeeze(dim=1)
+
+         pred = sub_emb*rel_emb*obj_emb
+
+         if sigmoid:
+             pred = torch.sigmoid(pred)
+
+         return pred
+
+ class Complex(torch.nn.Module):
+     def __init__(self, args, num_entities, num_relations):
+         super(Complex, self).__init__()
+
+         if args.max_norm:
+             self.emb_e = torch.nn.Embedding(num_entities, 2*args.embedding_dim, max_norm=1.0)
+             self.emb_rel = torch.nn.Embedding(num_relations, 2*args.embedding_dim)
+         else:
+             self.emb_e = torch.nn.Embedding(num_entities, 2*args.embedding_dim, padding_idx=None)
+             self.emb_rel = torch.nn.Embedding(num_relations, 2*args.embedding_dim, padding_idx=None)
+
+         self.inp_drop = torch.nn.Dropout(args.input_drop)
+         self.loss = torch.nn.CrossEntropyLoss()
+
+         self.init()
+
+     def init(self):
+         xavier_normal_(self.emb_e.weight)
+         xavier_normal_(self.emb_rel.weight)
+
+     def score_sr(self, sub, rel, sigmoid=False):
+         sub_emb = self.emb_e(sub).squeeze(dim=1)
+         rel_emb = self.emb_rel(rel).squeeze(dim=1)
+
+         s_real, s_img = torch.chunk(sub_emb, 2, dim=-1)
+         rel_real, rel_img = torch.chunk(rel_emb, 2, dim=-1)
+         emb_e_real, emb_e_img = torch.chunk(self.emb_e.weight, 2, dim=-1)
+
+         realo_realreal = s_real*rel_real
+         realo_imgimg = s_img*rel_img
+         realo = realo_realreal - realo_imgimg
+         real = torch.mm(realo, emb_e_real.transpose(1,0))
+
+         imgo_realimg = s_real*rel_img
+         imgo_imgreal = s_img*rel_real
+         imgo = imgo_realimg + imgo_imgreal
+         img = torch.mm(imgo, emb_e_img.transpose(1,0))
+
+         pred = real + img
+
+         if sigmoid:
+             pred = torch.sigmoid(pred)
+         return pred
+
+     def score_or(self, obj, rel, sigmoid=False):
+         obj_emb = self.emb_e(obj).squeeze(dim=1)
+         rel_emb = self.emb_rel(rel).squeeze(dim=1)
+
+         rel_real, rel_img = torch.chunk(rel_emb, 2, dim=-1)
+         o_real, o_img = torch.chunk(obj_emb, 2, dim=-1)
+         emb_e_real, emb_e_img = torch.chunk(self.emb_e.weight, 2, dim=-1)
+
+         #rel_real = self.inp_drop(rel_real)
+         #rel_img = self.inp_drop(rel_img)
+         #o_real = self.inp_drop(o_real)
+         #o_img = self.inp_drop(o_img)
+
+         # complex space bilinear product (equivalent to HolE)
+         # realrealreal = torch.mm(rel_real*o_real, emb_e_real.transpose(1,0))
+         # realimgimg = torch.mm(rel_img*o_img, emb_e_real.transpose(1,0))
+         # imgrealimg = torch.mm(rel_real*o_img, emb_e_img.transpose(1,0))
+         # imgimgreal = torch.mm(rel_img*o_real, emb_e_img.transpose(1,0))
+         # pred = realrealreal + realimgimg + imgrealimg - imgimgreal
+
+         reals_realreal = rel_real*o_real
+         reals_imgimg = rel_img*o_img
+         reals = reals_realreal + reals_imgimg
+         real = torch.mm(reals, emb_e_real.transpose(1,0))
+
+         imgs_realimg = rel_real*o_img
+         imgs_imgreal = rel_img*o_real
+         imgs = imgs_realimg - imgs_imgreal
+         img = torch.mm(imgs, emb_e_img.transpose(1,0))
+
+         pred = real + img
+
+         if sigmoid:
+             pred = torch.sigmoid(pred)
+         return pred
+
+     def forward(self, sub_emb, rel_emb, mode='rhs', sigmoid=False):
+         '''
+         When mode is 'rhs' we expect (s,r); for 'lhs', we expect (o,r).
+         '''
+         if mode == 'lhs':
+             rel_real, rel_img = torch.chunk(rel_emb, 2, dim=-1)
+             o_real, o_img = torch.chunk(sub_emb, 2, dim=-1)
+             emb_e_real, emb_e_img = torch.chunk(self.emb_e.weight, 2, dim=-1)
+
+             rel_real = self.inp_drop(rel_real)
+             rel_img = self.inp_drop(rel_img)
+             o_real = self.inp_drop(o_real)
+             o_img = self.inp_drop(o_img)
+
+             reals_realreal = rel_real*o_real
+             reals_imgimg = rel_img*o_img
+             reals = reals_realreal + reals_imgimg
+             real = torch.mm(reals, emb_e_real.transpose(1,0))
+
+             imgs_realimg = rel_real*o_img
+             imgs_imgreal = rel_img*o_real
+             imgs = imgs_realimg - imgs_imgreal
+             img = torch.mm(imgs, emb_e_img.transpose(1,0))
+
+             pred = real + img
+
+         else:
+             s_real, s_img = torch.chunk(sub_emb, 2, dim=-1)
+             rel_real, rel_img = torch.chunk(rel_emb, 2, dim=-1)
+             emb_e_real, emb_e_img = torch.chunk(self.emb_e.weight, 2, dim=-1)
+
+             s_real = self.inp_drop(s_real)
+             s_img = self.inp_drop(s_img)
+             rel_real = self.inp_drop(rel_real)
+             rel_img = self.inp_drop(rel_img)
+
+             realo_realreal = s_real*rel_real
+             realo_imgimg = s_img*rel_img
+             realo = realo_realreal - realo_imgimg
+             real = torch.mm(realo, emb_e_real.transpose(1,0))
+
+             imgo_realimg = s_real*rel_img
+             imgo_imgreal = s_img*rel_real
+             imgo = imgo_realimg + imgo_imgreal
+             img = torch.mm(imgo, emb_e_img.transpose(1,0))
+
+             pred = real + img
+
+         if sigmoid:
+             pred = torch.sigmoid(pred)
+
+         return pred
+
+     def score_triples(self, sub, rel, obj, sigmoid=False):
+         '''
+         Inputs - subject, relation, object
+         Return - score
+         '''
+         sub_emb = self.emb_e(sub).squeeze(dim=1)
+         rel_emb = self.emb_rel(rel).squeeze(dim=1)
+         obj_emb = self.emb_e(obj).squeeze(dim=1)
+
+         s_real, s_img = torch.chunk(sub_emb, 2, dim=-1)
+         rel_real, rel_img = torch.chunk(rel_emb, 2, dim=-1)
+         o_real, o_img = torch.chunk(obj_emb, 2, dim=-1)
+
+         realrealreal = torch.sum(s_real*rel_real*o_real, dim=-1)
+         realimgimg = torch.sum(s_real*rel_img*o_img, dim=-1)
+         imgrealimg = torch.sum(s_img*rel_real*o_img, dim=-1)
+         imgimgreal = torch.sum(s_img*rel_img*o_real, dim=-1)
+
+         pred = realrealreal + realimgimg + imgrealimg - imgimgreal
+
+         if sigmoid:
+             pred = torch.sigmoid(pred)
+
+         return pred
+
+     def score_emb(self, emb_s, emb_r, emb_o, sigmoid=False):
+         '''
+         Inputs - embeddings of subject, relation, object
+         Return - score
+         '''
+         s_real, s_img = torch.chunk(emb_s, 2, dim=-1)
+         rel_real, rel_img = torch.chunk(emb_r, 2, dim=-1)
+         o_real, o_img = torch.chunk(emb_o, 2, dim=-1)
+
+         realrealreal = torch.sum(s_real*rel_real*o_real, dim=-1)
+         realimgimg = torch.sum(s_real*rel_img*o_img, dim=-1)
+         imgrealimg = torch.sum(s_img*rel_real*o_img, dim=-1)
+         imgimgreal = torch.sum(s_img*rel_img*o_real, dim=-1)
+
+         pred = realrealreal + realimgimg + imgrealimg - imgimgreal
+
+         if sigmoid:
+             pred = torch.sigmoid(pred)
+
+         return pred
+
+     def score_triples_vec(self, sub, rel, obj, sigmoid=False):
+         '''
+         Inputs - subject, relation, object
+         Return - a vector score for the triple instead of reducing over the embedding dimension
+         '''
+         sub_emb = self.emb_e(sub).squeeze(dim=1)
+         rel_emb = self.emb_rel(rel).squeeze(dim=1)
+         obj_emb = self.emb_e(obj).squeeze(dim=1)
+
+         s_real, s_img = torch.chunk(sub_emb, 2, dim=-1)
+         rel_real, rel_img = torch.chunk(rel_emb, 2, dim=-1)
+         o_real, o_img = torch.chunk(obj_emb, 2, dim=-1)
+
+         realrealreal = s_real*rel_real*o_real
+         realimgimg = s_real*rel_img*o_img
+         imgrealimg = s_img*rel_real*o_img
+         imgimgreal = s_img*rel_img*o_real
+
+         pred = realrealreal + realimgimg + imgrealimg - imgimgreal
+
+         if sigmoid:
+             pred = torch.sigmoid(pred)
+
+         return pred
+
+ class Conve(torch.nn.Module):
+
+     # Too slow !!!!
+
+     def __init__(self, args, num_entities, num_relations):
+         super(Conve, self).__init__()
+
+         if args.max_norm:
+             self.emb_e = torch.nn.Embedding(num_entities, args.embedding_dim, max_norm=1.0)
+             self.emb_rel = torch.nn.Embedding(num_relations, args.embedding_dim)
+         else:
+             self.emb_e = torch.nn.Embedding(num_entities, args.embedding_dim, padding_idx=None)
+             self.emb_rel = torch.nn.Embedding(num_relations, args.embedding_dim, padding_idx=None)
+
+         self.inp_drop = torch.nn.Dropout(args.input_drop)
+         self.hidden_drop = torch.nn.Dropout(args.hidden_drop)
+         self.feature_drop = torch.nn.Dropout2d(args.feat_drop)
+
+         self.embedding_dim = args.embedding_dim # default is 200
+         self.num_filters = args.num_filters # default is 32
+         self.kernel_size = args.kernel_size # default is 3
+         self.stack_width = args.stack_width # default is 20
+         self.stack_height = args.embedding_dim // self.stack_width
+
+         self.bn0 = torch.nn.BatchNorm2d(1)
+         self.bn1 = torch.nn.BatchNorm2d(self.num_filters)
+         self.bn2 = torch.nn.BatchNorm1d(args.embedding_dim)
+
+         self.conv1 = torch.nn.Conv2d(1, out_channels=self.num_filters,
+                                      kernel_size=(self.kernel_size, self.kernel_size),
+                                      stride=1, padding=0, bias=args.use_bias)
+         #self.conv1 = torch.nn.Conv2d(1, 32, (3, 3), 1, 0, bias=args.use_bias) # <-- default
+
+         flat_sz_h = int(2*self.stack_width) - self.kernel_size + 1
+         flat_sz_w = self.stack_height - self.kernel_size + 1
+         self.flat_sz = flat_sz_h*flat_sz_w*self.num_filters
+         self.fc = torch.nn.Linear(self.flat_sz, args.embedding_dim)
+
+         self.register_parameter('b', Parameter(torch.zeros(num_entities)))
+         self.loss = torch.nn.CrossEntropyLoss()
+
+         self.init()
+
+     def init(self):
+         xavier_normal_(self.emb_e.weight)
+         xavier_normal_(self.emb_rel.weight)
+
+     def concat(self, e1_embed, rel_embed, form='plain'):
+         if form == 'plain':
+             e1_embed = e1_embed.view(-1, 1, self.stack_width, self.stack_height)
+             rel_embed = rel_embed.view(-1, 1, self.stack_width, self.stack_height)
+             stack_inp = torch.cat([e1_embed, rel_embed], 2)
+
+         elif form == 'alternate':
+             e1_embed = e1_embed.view(-1, 1, self.embedding_dim)
+             rel_embed = rel_embed.view(-1, 1, self.embedding_dim)
+             stack_inp = torch.cat([e1_embed, rel_embed], 1)
+             stack_inp = torch.transpose(stack_inp, 2, 1).reshape((-1, 1, 2*self.stack_width, self.stack_height))
+
+         else:
+             raise NotImplementedError
+         return stack_inp
+
+     def conve_architecture(self, sub_emb, rel_emb):
+         stacked_inputs = self.concat(sub_emb, rel_emb)
+         stacked_inputs = self.bn0(stacked_inputs)
+         x = self.inp_drop(stacked_inputs)
+         x = self.conv1(x)
+         x = self.bn1(x)
+         x = F.relu(x)
+         x = self.feature_drop(x)
+         #x = x.view(x.shape[0], -1)
+         x = x.view(-1, self.flat_sz)
+         x = self.fc(x)
+         x = self.hidden_drop(x)
+         x = self.bn2(x)
+         x = F.relu(x)
+
+         return x
+
+     def score_sr(self, sub, rel, sigmoid=False):
+         sub_emb = self.emb_e(sub)
+         rel_emb = self.emb_rel(rel)
+
+         x = self.conve_architecture(sub_emb, rel_emb)
+
+         pred = torch.mm(x, self.emb_e.weight.transpose(1,0))
+         pred += self.b.expand_as(pred)
+
+         if sigmoid:
+             pred = torch.sigmoid(pred)
+         return pred
+
+     def score_or(self, obj, rel, sigmoid=False):
+         obj_emb = self.emb_e(obj)
+         rel_emb = self.emb_rel(rel)
+
+         x = self.conve_architecture(obj_emb, rel_emb)
+         pred = torch.mm(x, self.emb_e.weight.transpose(1,0))
+         pred += self.b.expand_as(pred)
+
+         if sigmoid:
+             pred = torch.sigmoid(pred)
+         return pred
+
+     def forward(self, sub_emb, rel_emb, mode='rhs', sigmoid=False):
+         '''
+         When mode is 'rhs' we expect (s,r); for 'lhs', we expect (o,r).
+         For ConvE the computation is identical in both modes, so no if-else block is needed.
+         '''
+         x = self.conve_architecture(sub_emb, rel_emb)
+
+         pred = torch.mm(x, self.emb_e.weight.transpose(1,0))
+         pred += self.b.expand_as(pred)
+
+         if sigmoid:
+             pred = torch.sigmoid(pred)
+
+         return pred
+
+     def score_triples(self, sub, rel, obj, sigmoid=False):
+         '''
+         Inputs - subject, relation, object
+         Return - score
+         '''
+         sub_emb = self.emb_e(sub)
+         rel_emb = self.emb_rel(rel)
+         obj_emb = self.emb_e(obj)
+         x = self.conve_architecture(sub_emb, rel_emb)
+
+         pred = torch.mm(x, obj_emb.transpose(1,0))
+         #print(pred.shape)
+         pred += self.b[obj].expand_as(pred) # taking the bias value for the object embedding
+         # The above works fine for a single input triple, but for a batch of
+         # triples it yields a (num_trip x num_trip) matrix whose diagonal
+         # holds the scores, so take torch.diagonal() afterwards.
+         pred = torch.diagonal(pred)
+         # or could have used: pred = torch.sum(x*obj_emb, dim=-1)
+
+         if sigmoid:
+             pred = torch.sigmoid(pred)
+
+         return pred
+
+     def score_emb(self, emb_s, emb_r, emb_o, sigmoid=False):
+         '''
+         Inputs - embeddings of subject, relation, object
+         Return - score
+         '''
+         x = self.conve_architecture(emb_s, emb_r)
+
+         pred = torch.mm(x, emb_o.transpose(1,0))
+         #pred += self.b[obj].expand_as(pred) # taking the bias value for the object embedding - don't know which obj
+         # As above, a batch of triples yields a (num_trip x num_trip) matrix
+         # whose diagonal holds the scores, so take torch.diagonal() afterwards.
+         pred = torch.diagonal(pred)
+         # or could have used: pred = torch.sum(x*emb_o, dim=-1)
+
+         if sigmoid:
+             pred = torch.sigmoid(pred)
+
+         return pred
+
+     def score_triples_vec(self, sub, rel, obj, sigmoid=False):
+         '''
+         Inputs - subject, relation, object
+         Return - a vector score for the triple instead of reducing over the embedding dimension
+         '''
+         sub_emb = self.emb_e(sub)
+         rel_emb = self.emb_rel(rel)
+         obj_emb = self.emb_e(obj)
+
+         x = self.conve_architecture(sub_emb, rel_emb)
+
+         #pred = torch.mm(x, obj_emb.transpose(1,0))
+         pred = x*obj_emb
+         #print(pred.shape, self.b[obj].shape) # shapes are [7,200] and [7]
+         #pred += self.b[obj].expand_as(pred) # taking the bias value for the object embedding - can't add scalar to vector
+
+         #pred = sub_emb*rel_emb*obj_emb
+
+         if sigmoid:
+             pred = torch.sigmoid(pred)
+
+         return pred
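
For orientation, a minimal usage sketch for the scoring classes above; the args fields mirror the options defined in utils.py, and the entity/relation counts are illustrative dummies, not values fixed by this commit:

    # Hedged sketch (not part of the commit): score a batch of triples with Distmult.
    import argparse
    import torch
    from model import Distmult

    args = argparse.Namespace(max_norm=False, embedding_dim=128, input_drop=0.2)
    model = Distmult(args, num_entities=1000, num_relations=36)
    model.eval()

    s = torch.tensor([[1], [2]])  # subject ids, shape (batch, 1)
    r = torch.tensor([[0], [5]])  # relation ids
    o = torch.tensor([[3], [4]])  # object ids
    with torch.no_grad():
        scores = model.score_triples(s, r, o, sigmoid=True)  # shape (batch,)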
server/server.py → server.py RENAMED
@@ -9,7 +9,7 @@ import numpy as np
  import json
  import networkx as nx
  import spacy
- os.system("python -m spacy download en-core-web-sm")
+ # os.system("python -m spacy download en-core-web-sm")
  import pickle as pkl
  #%%
 
@@ -17,14 +17,12 @@ from torch.nn.modules.loss import CrossEntropyLoss
  from transformers import AutoTokenizer
  from transformers import BioGptForCausalLM, BartForConditionalGeneration
 
- import server_utils
+ from server import server_utils
 
- sys.path.append("..")
  import Parameters
  from Openai.chat import generate_abstract
- sys.path.append("../DiseaseSpecific")
- import utils, attack
- from attack import calculate_edge_bound, get_model_loss_without_softmax
+ from DiseaseSpecific import utils, attack
+ from DiseaseSpecific.attack import calculate_edge_bound, get_model_loss_without_softmax
 
 
  specific_model = None
@@ -51,8 +49,8 @@ np.set_printoptions(precision=5)
  cudnn.benchmark = False
 
  model_name = '{0}_{1}_{2}_{3}_{4}'.format(args.model, args.embedding_dim, args.input_drop, args.hidden_drop, args.feat_drop)
- model_path = '../DiseaseSpecific/saved_models/{0}_{1}.model'.format(args.data, model_name)
- data_path = os.path.join('../DiseaseSpecific/processed_data', args.data)
+ model_path = 'DiseaseSpecific/saved_models/{0}_{1}.model'.format(args.data, model_name)
+ data_path = os.path.join('DiseaseSpecific/processed_data', args.data)
  data = utils.load_data(os.path.join(data_path, 'all.txt'))
 
  n_ent, n_rel, ent_to_id, rel_to_id = utils.generate_dicts(data_path)
@@ -596,11 +594,11 @@ def specific_func(start_entity, end_entity):
  o_name = entity_raw_name[id_to_entity[str(o)]]
  attack_data = np.array([[s, r, o]])
  path_list = []
- with open(f'../DiseaseSpecific/generate_abstract/path/random_{args.reasonable_rate}_path.json', 'r') as fl:
+ with open(f'DiseaseSpecific/generate_abstract/path/random_{args.reasonable_rate}_path.json', 'r') as fl:
      for line in fl.readlines():
          line.replace('\n', '')
          path_list.append(line)
- with open(f'../DiseaseSpecific/generate_abstract/random_{args.reasonable_rate}_sentence.json', 'r') as fl:
+ with open(f'DiseaseSpecific/generate_abstract/random_{args.reasonable_rate}_sentence.json', 'r') as fl:
      sentence_dict = json.load(fl)
  dpath = []
  for k, v in sentence_dict.items():
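
With the sys.path.append calls gone, server.py depends on DiseaseSpecific, Openai, and server being importable as packages from the repository root. A sketch of the layout these rewritten imports assume (the __init__.py files outside server/ are assumptions, not shown in this commit):

    # repo/                      <- working directory when launching server.py
    #   Parameters.py
    #   server.py                <- moved here from server/server.py
    #   server/__init__.py       <- added by this commit
    #   server/server_utils.py
    #   DiseaseSpecific/utils.py, DiseaseSpecific/attack.py
    #   Openai/chat.py
    # With this layout, `python server.py` run from repo/ resolves every
    # import without mutating sys.path.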
server/__init__.py ADDED
File without changes
server/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (137 Bytes).
 
utils.py ADDED
@@ -0,0 +1,195 @@
+ '''
+ A file modified from https://github.com/PeruBhardwaj/AttributionAttack/blob/main/KGEAttack/ConvE/utils.py
+ '''
+ #%%
+ import logging
+ import time
+ from tqdm import tqdm
+ import io
+ import pandas as pd
+ import numpy as np
+ import os
+ import json
+
+ import argparse
+ import torch
+ import random
+
+ from yaml import parse
+
+ from model import Conve, Distmult, Complex
+
+ logger = logging.getLogger(__name__)
+ #%%
+ def generate_dicts(data_path):
+     with open(os.path.join(data_path, 'entities_dict.json'), 'r') as f:
+         ent_to_id = json.load(f)
+     with open(os.path.join(data_path, 'relations_dict.json'), 'r') as f:
+         rel_to_id = json.load(f)
+     n_ent = len(list(ent_to_id.keys()))
+     n_rel = len(list(rel_to_id.keys()))
+
+     return n_ent, n_rel, ent_to_id, rel_to_id
+
+ def save_data(file_name, data):
+     with open(file_name, 'w') as fl:
+         for item in data:
+             fl.write("%s\n" % "\t".join(map(str, item)))
+
+ def load_data(file_name, drop=True):
+     df = pd.read_csv(file_name, sep='\t', header=None, names=None, dtype=str)
+     if drop:
+         df = df.drop_duplicates()
+     else:
+         pass
+     return df.values
+
+ def seed_all(seed=1):
+     random.seed(seed)
+     np.random.seed(seed)
+     torch.manual_seed(seed)
+     torch.cuda.manual_seed_all(seed)
+     os.environ['PYTHONHASHSEED'] = str(seed)
+     torch.backends.cudnn.deterministic = True
+
+ def add_model(args, n_ent, n_rel):
+     if args.model is None:
+         model = Distmult(args, n_ent, n_rel)
+     elif args.model == 'distmult':
+         model = Distmult(args, n_ent, n_rel)
+     elif args.model == 'complex':
+         model = Complex(args, n_ent, n_rel)
+     elif args.model == 'conve':
+         model = Conve(args, n_ent, n_rel)
+     else:
+         raise Exception("Unknown model!")
+
+     return model
+
+ def load_model(model_path, args, n_ent, n_rel, device):
+     # add a model and load the pre-trained params
+     model = add_model(args, n_ent, n_rel)
+     model.to(device)
+     logger.info('Loading saved model from {0}'.format(model_path))
+     state = torch.load(model_path)
+     model_params = state['state_dict']
+     params = [(key, value.size(), value.numel()) for key, value in model_params.items()]
+     for key, size, count in params:
+         logger.info('Key:{0}, Size:{1}, Count:{2}'.format(key, size, count))
+
+     model.load_state_dict(model_params)
+     model.eval()
+     logger.info(model)
+
+     return model
+
+ def add_eval_parameters(parser):
+
+     # parser.add_argument('--eval-mode', type=str, default='all', help='Method to evaluate the attack performance. Default: all. (all or single)')
+     parser.add_argument('--cuda-name', type=str, required=True, help='Start a main thread on each cuda.')
+     parser.add_argument('--direct', action='store_true', help='Directly add edge or not.')
+     parser.add_argument('--seperate', action='store_true', help='Evaluate separately or not')
+     parser.add_argument('--mode', type=str, default='', help="'' or ''")
+     parser.add_argument('--mask-ratio', type=str, default='', help='Mask ratio for Fig4b')
+     return parser
+
+ def add_attack_parameters(parser):
+
+     # parser.add_argument('--target-split', type=str, default='0_100_1', help='Ranks to use for target set. Values are 0 for ranks==1; 1 for ranks <=10; 2 for ranks>10 and ranks<=100. Default: 1')
+     parser.add_argument('--target-split', type=str, default='min', help='Method for target triple selection. Default: min. (min or top_?, top means top_0.1)')
+     parser.add_argument('--target-size', type=int, default=50, help='Number of target triples. Default: 50')
+     parser.add_argument('--target-existed', action='store_true', help='Whether the targeted s_?_o already exists.')
+
+     # parser.add_argument('--budget', type=int, default=1, help='Budget for each target triple for each corruption side')
+
+     parser.add_argument('--attack-goal', type=str, default='single', help='Attack goal. Default: single. (single or global)')
+     parser.add_argument('--neighbor-num', type=int, default=20, help='Max neighbor num for each side. Default: 20')
+     parser.add_argument('--candidate-mode', type=str, default='quadratic', help='The method to generate candidate edges. Default: quadratic. (quadratic or linear)')
+     parser.add_argument('--reasonable-rate', type=float, default=0.7, help='The added edge\'s existence rank prob must be greater than this rate')
+     parser.add_argument('--added-edge-num', type=str, default='', help='How many edges to add for each target edge. Default: \'\' means 1.')
+     # parser.add_argument('--neighbor-num', type=int, default=200, help='Max neighbor num for each side. Default: 200')
+     # parser.add_argument('--candidate-mode', type=str, default='linear', help='The method to generate candidate edges. Default: quadratic. (quadratic or linear)')
+     parser.add_argument('--attack-batch-size', type=int, default=256, help='Batch size for processing neighbours of target')
+     parser.add_argument('--template-mode', type=str, default='manual', help='Template mode for transforming an edge into a single sentence. Default: manual. (manual or auto)')
+
+     parser.add_argument('--update-lissa', action='store_true', help='Update lissa cache or not.')
+
+     parser.add_argument('--GPT-batch-size', type=int, default=64, help='Batch size for GPT2 when calculating LM score. Default: 64')
+     parser.add_argument('--LM-softmax', action='store_true', help='Use a softmax head on LM prob or not.')
+     parser.add_argument('--LMprob-mode', type=str, default='relative', help='Use the absolute LM score or calculate the destruction score when the target word is replaced. Default: relative. (absolute or relative)')
+
+     parser.add_argument('--load-existed', action='store_true', help='Use cached intermediate results or not; when only --reasonable-rate changed, set this param to True')
+
+     return parser
+
+ def get_argument_parser():
+     '''Generate an argument parser'''
+     parser = argparse.ArgumentParser(description='Graph embedding')
+
+     parser.add_argument('--seed', type=int, default=1, metavar='S', help='Random seed (default: 1)')
+
+     parser.add_argument('--data', type=str, default='GNBR', help='Dataset to use: { GNBR }')
+     parser.add_argument('--model', type=str, default='distmult', help='Choose from: {distmult, conve, complex}')
+
+     parser.add_argument('--transe-margin', type=float, default=0.0, help='Margin value for TransE scoring function. Default: 0.0')
+     parser.add_argument('--transe-norm', type=int, default=2, help='P-norm value for TransE scoring function. Default: 2')
+
+     parser.add_argument('--epochs', type=int, default=100, help='Number of epochs to train (default: 100)')
+     parser.add_argument('--lr', type=float, default=0.001, help='Learning rate (default: 0.001)')
+     parser.add_argument('--lr-decay', type=float, default=0.0, help='Weight decay value to use in the optimizer. Default: 0.0')
+     parser.add_argument('--max-norm', action='store_true', help='Option to add unit max norm constraint to entity embeddings')
+
+     parser.add_argument('--train-batch-size', type=int, default=64, help='Batch size for train split (default: 64)')
+     parser.add_argument('--test-batch-size', type=int, default=128, help='Batch size for test split (default: 128)')
+     parser.add_argument('--valid-batch-size', type=int, default=128, help='Batch size for valid split (default: 128)')
+     parser.add_argument('--KG-valid-rate', type=float, default=0.1, help='Validation rate during KG embedding training. (default: 0.1)')
+
+     parser.add_argument('--save-influence-map', action='store_true', help='Save the influence map during training for gradient rollback.')
+     parser.add_argument('--add-reciprocals', action='store_true')
+
+     parser.add_argument('--embedding-dim', type=int, default=128, help='The embedding dimension (1D). Default: 128')
+     parser.add_argument('--stack-width', type=int, default=16, help='The first dimension of the reshaped/stacked 2D embedding. Second dimension is inferred. Default: 16')
+     #parser.add_argument('--stack_height', type=int, default=10, help='The second dimension of the reshaped/stacked 2D embedding. Default: 10')
+     parser.add_argument('--hidden-drop', type=float, default=0.3, help='Dropout for the hidden layer. Default: 0.3.')
+     parser.add_argument('--input-drop', type=float, default=0.2, help='Dropout for the input embeddings. Default: 0.2.')
+     parser.add_argument('--feat-drop', type=float, default=0.3, help='Dropout for the convolutional features. Default: 0.3.')
+     parser.add_argument('-num-filters', default=32, type=int, help='Number of filters for convolution')
+     parser.add_argument('-kernel-size', default=3, type=int, help='Kernel size for convolution')
+
+     parser.add_argument('--use-bias', action='store_true', help='Use a bias in the convolutional layer.')
+
+     parser.add_argument('--reg-weight', type=float, default=5e-2, help='Weight for regularization. Default: 5e-2')
+     parser.add_argument('--reg-norm', type=int, default=3, help='Norm for regularization. Default: 3')
+     # parser.add_argument('--resume', action='store_true', help='Restore a saved model.')
+     # parser.add_argument('--resume-split', type=str, default='test', help='Split to evaluate a restored model')
+     # parser.add_argument('--reproduce-results', action='store_true', help='Use the hyperparameters to reproduce the results.')
+     # parser.add_argument('--original-data', type=str, default='FB15k-237', help='Dataset to use; this option is needed to set the hyperparams to reproduce the results for training after attack, default: FB15k-237')
+     return parser
+
+ def set_hyperparams(args):
+     if args.model == 'distmult':
+         args.lr = 0.005
+         args.train_batch_size = 1024
+         args.reg_norm = 3
+     elif args.model == 'complex':
+         args.lr = 0.005
+         args.reg_norm = 3
+         args.input_drop = 0.4
+         args.train_batch_size = 1024
+     elif args.model == 'conve':
+         args.lr = 0.005
+         args.train_batch_size = 1024
+         args.reg_weight = 0.0
+
+     # args.damping = 0.01
+     # args.lissa_repeat = 1
+     # args.lissa_depth = 1
+     # args.scale = 500
+     # args.lissa_batch_size = 100
+
+     args.damping = 0.01
+     args.lissa_repeat = 1
+     args.lissa_depth = 1
+     args.scale = 400
+     args.lissa_batch_size = 300
+     return args
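
A hedged end-to-end sketch of how the helpers above compose; the entity/relation counts are illustrative dummies, not values taken from the commit:

    # Hedged sketch (not part of the commit): build args, apply per-model
    # hyperparameters, and construct a model via add_model().
    parser = get_argument_parser()
    parser = add_attack_parameters(parser)
    args = parser.parse_args(['--model', 'distmult'])
    args = set_hyperparams(args)  # sets lr, train_batch_size, lissa constants
    seed_all(args.seed)
    model = add_model(args, n_ent=1000, n_rel=36)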