# Source: Scorpius_HF / DiseaseSpecific / KG_extractor.py
# Hugging Face file-viewer residue: author yjwtheonly, commit "specific" (ac7c391),
# raw / history / blame links, file size 20.5 kB.
#%%
import torch
import numpy as np
from torch.autograd import Variable
from sklearn import metrics
import datetime
from typing import Dict, Tuple, List
import logging
import os
import utils
import pickle as pkl
import json
import torch.backends.cudnn as cudnn
from tqdm import tqdm
import sys
sys.path.append("..")
import Parameters
# ---- Command line ---------------------------------------------------------
# Start from the project's shared parser, add the attack options, then the
# three options specific to this script.
parser = utils.add_attack_parameters(utils.get_argument_parser())
parser.add_argument('--mode', type=str, default='sentence', help='sentence, finetune, biogpt, bioBART')
parser.add_argument('--action', type=str, default='parse', help='parse or extract')
parser.add_argument('--ratio', type = str, default='', help='ratio of the number of changed words')
args = utils.set_hyperparams(parser.parse_args())

# ---- Runtime setup --------------------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
utils.seed_all(args.seed)
np.set_printoptions(precision=5)
cudnn.benchmark = False

# ---- Input / output locations ---------------------------------------------
data_path = os.path.join('processed_data', args.data)
_exists_tag = 'exists:' + str(args.target_existed)
target_path = os.path.join(
    data_path,
    f'DD_target_{args.model}_{args.data}_{args.target_split}'
    f'_{args.target_size}_{_exists_tag}_{args.attack_goal}.txt',
)

# Both attack files share the same "cos_..." stem; the modified one only
# appends the generation mode right before the extension.
_attack_stem = 'cos_' + '_'.join(str(part) for part in (
    args.model,
    args.target_split,
    args.target_size,
    _exists_tag,
    args.neighbor_num,
    args.candidate_mode,
    args.attack_goal,
    args.reasonable_rate,
))
attack_path = os.path.join('attack_results', args.data, _attack_stem + '.txt')
modified_attack_path = os.path.join('attack_results', args.data, f'{_attack_stem}{args.mode}.txt')

attack_data = utils.load_data(attack_path, drop=False)
#%%
def _read_json(path):
    """Return the parsed content of the JSON file at ``path``."""
    with open(path, 'r') as handle:
        return json.load(handle)


def _read_gnbr_pickle(name):
    """Unpickle the file ``Parameters.GNBRfile + name``.

    NOTE(review): pickle executes code embedded in the file; only point
    ``Parameters.GNBRfile`` at trusted data.
    """
    with open(Parameters.GNBRfile + name, 'rb') as handle:
        return pkl.load(handle)


# Entity id <-> key dictionaries shipped with the processed dataset.
id_to_meshid = _read_json(os.path.join(data_path, 'entities_reverse_dict.json'))
meshid_to_id = _read_json(os.path.join(data_path, 'entities_dict.json'))

# Pickled GNBR resources (file names, including their spelling, are fixed on disk).
entity_raw_name = _read_gnbr_pickle('entity_raw_name')
retieve_sentence_through_edgetype = _read_gnbr_pickle('retieve_sentence_through_edgetype')
raw_text_sen = _read_gnbr_pickle('raw_text_of_each_sentence')
full_entity_raw_name = _read_gnbr_pickle('original_entity_raw_name')

# Sanity check: the single name kept per entity must be one of that
# entity's names in the full table.
for entity_key, kept_name in entity_raw_name.items():
    assert kept_name in full_entity_raw_name[entity_key]
# Build the name-based lookup tables used when tagging generated text.
#
# `once_set` / `twice_set` belong to a uniqueness filter (keep only names that
# occur for exactly one entity) that is currently disabled, so `good_name`
# simply contains every name of every entity.
once_set = set()
twice_set = set()

with open('generate_abstract/valid_entity.json', 'r') as fl:
    valid_entity = set(json.load(fl))

good_name = set()
for alias_group in full_entity_raw_name.values():
    good_name.update(alias_group)

# name -> entity type (the part of the entity key before the first '_')
# name -> full entity key
# When several entities share a name, the entity visited last wins.
name_to_type = {}
name_to_meshid = {}
for entity_key, alias_group in full_entity_raw_name.items():
    for alias in alias_group:
        if alias not in good_name:
            continue
        name_to_type[alias] = entity_key.split('_')[0]
        name_to_meshid[alias] = entity_key
import spacy
import networkx as nx
import pprint
def check(p, s):
    """Return True when position ``p`` of ``s`` can act as a word boundary.

    A position is a boundary when it lies outside the string (before the
    first or after the last character) or when the character found there is
    not an ASCII letter or digit.

    Args:
        p: Index to inspect; may be -1 or ``len(s)`` for "just outside".
        s: The text being scanned.

    Returns:
        bool: True for a boundary, False when ``s[p]`` is ``[A-Za-z0-9]``.
    """
    # Only indices outside [0, len(s)) are "outside the string". The guard
    # must be `p < 0` (not `p < 1`): index 0 holds a real character and has to
    # be inspected, otherwise a match starting at index 1 is accepted even
    # when it is glued to a preceding letter.
    if p < 0 or p >= len(s):
        return True
    ch = s[p]
    return not (('a' <= ch <= 'z') or ('A' <= ch <= 'Z') or ('0' <= ch <= '9'))
def raw_to_format(sen):
    """Join multi-word entity mentions in ``sen`` into single tokens.

    The text is scanned left to right. At every non-space position the
    longest substring that is a known name (in ``good_name`` or
    ``valid_entity``) and is delimited on both sides according to ``check``
    is emitted with its spaces replaced by underscores, and scanning resumes
    right after it. Any other character is copied unchanged.

    Args:
        sen: Sentence text.

    Returns:
        str: The sentence with recognised mentions rewritten as
        ``word_word`` tokens.
    """
    text = sen
    total = len(text)
    ret = []
    l = 0
    while l < total:
        match_end = None
        # The left boundary does not depend on where the candidate ends, so it
        # is tested once per start position instead of once per candidate.
        if text[l] != ' ' and check(l - 1, text):
            # Longest candidate first, so the longest name wins.
            for i in range(total, l, -1):
                # Cheap boundary test before slicing and hashing a substring.
                if not check(i, text):
                    continue
                cc = text[l:i]
                if cc in good_name or cc in valid_entity:
                    match_end = i
                    break
        if match_end is None:
            ret.append(text[l])
            l += 1
        else:
            ret.append(text[l:match_end].replace(' ', '_'))
            l = match_end
    return ''.join(ret)
# Load the generated abstracts ("draft") for the selected generation mode.
# Every file lives in generate_abstract/ and starts with
# "<target_split>_<reasonable_rate>"; only the suffix depends on the mode
# (bioBART additionally embeds --ratio).
_draft_suffix_by_mode = {
    'sentence': '_chat.json',
    'finetune': '_sentence_finetune.json',
    'bioBART': f'{args.ratio}_bioBART_finetune.json',
    'biogpt': '_biogpt.json',
}
if args.mode not in _draft_suffix_by_mode:
    raise Exception('No!!!')
_draft_file = (
    f'generate_abstract/{args.target_split}_{args.reasonable_rate}'
    + _draft_suffix_by_mode[args.mode]
)
with open(_draft_file, 'r') as fl:
    draft = json.load(fl)
# spaCy pipeline; the rest of the script uses it for sentence splitting.
nlp = spacy.load("en_core_web_sm")

# Collect every dependency label that occurs in the stored paths.
# A path is a space-separated sequence of "x|label|y" triples; the middle
# field is the label. Edge-type ids 0..35 are hard-coded here.
type_set = set()
for edge_type_id in range(36):
    paths_by_source = retieve_sentence_through_edgetype[edge_type_id]
    manual_paths = list(paths_by_source['manual'].keys())
    auto_paths = list(paths_by_source['auto'].keys())
    for dependency in manual_paths + auto_paths:
        for sub_dep in dependency.split(' '):
            fields = sub_dep.split('|')
            assert(len(fields) == 3)
            type_set.add(fields[1])
# File that caches the dependency-path -> edge-type-id table between runs.
_DEPENDENCY_CACHE = 'generate_abstract/dependency_to_type_id.pickle'


def _write_parse_input():
    """Write one formatted sentence per line to the ``*_parsein.txt`` file.

    Every draft entry with a non-empty ``'in'`` field has its ``'out'`` text
    split into sentences by spaCy; each sentence goes through
    ``raw_to_format`` so entity mentions become single tokens.
    """
    formatted = []
    for v_dict in draft.values():
        source_text = v_dict['in']
        generated = v_dict['out']
        if source_text == '':
            continue
        doc = nlp(generated.replace('\n', ' '))
        formatted.extend(raw_to_format(sen.text) + '\n' for sen in doc.sents)
    with open(f'generate_abstract/{args.target_split}_{args.reasonable_rate}_{args.mode}_parsein.txt', 'w') as fl:
        fl.write(''.join(formatted))


def _load_dependency_to_type_id():
    """Return ``{entity-pair key: {lower-cased dependency path: edge type id}}``.

    The table is read from ``_DEPENDENCY_CACHE`` when that file exists;
    otherwise it is rebuilt from the ``part-i-*-path-theme-distributions.txt``
    files under ``../GNBRdata`` and written to the cache.
    """
    if os.path.exists(_DEPENDENCY_CACHE):
        # NOTE(review): pickle runs code embedded in the file - only use a
        # cache file produced by this script.
        with open(_DEPENDENCY_CACHE, 'rb') as fl:
            return pkl.load(fl)

    dependency_to_type_id = {}
    print('Loading path data ...')
    for pair_key, inner_edge_type_to_id in Parameters.edge_type_to_id.items():
        start, end = pair_key.split('-')
        dependency_to_type_id[pair_key] = {}
        inner_edge_type_dict = Parameters.edge_type_dict[pair_key]
        # How many paths have been assigned to each theme through its flag.
        cal_manual_num = [0] * len(inner_edge_type_to_id)
        with open('../GNBRdata/part-i-'+start+'-'+end+'-path-theme-distributions.txt', 'r') as fl:
            lines = fl.readlines()
        for row, line in enumerate(tqdm(lines)):
            tmp = line.split('\t')
            # Column 0 is the path; after it, odd columns hold theme scores
            # and even columns hold the matching 0/1 flags.
            if row == 0:
                head = tmp[1::2]
                assert ' '.join(head) == ' '.join(inner_edge_type_dict[0])
                continue
            probability = [float(x) for x in tmp[1::2]]
            flag_list = [int(x) for x in tmp[2::2]]
            indices = np.where(np.asarray(flag_list) == 1)[0]
            if len(indices) >= 1:
                # Among the flagged themes pick the one used least so far.
                tmp_p = [cal_manual_num[i] for i in indices]
                p = indices[np.argmin(tmp_p)]
                cal_manual_num[p] += 1
            else:
                p = np.argmax(probability)
            # NOTE(review): this looks at the OUTER dict (keyed by entity
            # pair), so it can never detect a duplicated path; the intended
            # target is probably dependency_to_type_id[pair_key]. Left as is
            # because the stricter check could start failing on existing
            # data - confirm before tightening.
            assert tmp[0].lower() not in dependency_to_type_id
            dependency_to_type_id[pair_key][tmp[0].lower()] = inner_edge_type_to_id[p]
    with open(_DEPENDENCY_CACHE, 'wb') as fl:
        pkl.dump(dependency_to_type_id, fl)
    return dependency_to_type_id


def _read_parse_records(path):
    """Group the parser output at ``path`` into per-sentence records.

    Sections are runs of lines longer than one character, closed by a shorter
    line. Three consecutive sections form one record; the code below expects
    the second one to start with ``(ROOT`` and treats the third as the list
    of dependency lines. When the third section is missing (the next section
    is a single line whose first token contains ``/``), an empty list is
    stored in its place.

    Raises:
        Exception: if the second section of a record does not start with
            ``(ROOT``.
    """
    record = []
    sections = []
    current = []
    with open(path, 'r') as fl:
        for line in fl:
            line = line.replace('\n', '')
            if len(line) > 1:
                current.append(line)
                continue
            # A short line closes the section collected so far.
            if (len(sections) == 2 and len(current) == 1
                    and '/' in current[0].split(' ')[0]):
                # The previous sentence produced no dependency section.
                sections.append([])
                record.append(sections)
                sections = []
            sections.append(current)
            if len(sections) == 2 and not (current and current[0][:5] == '(ROOT'):
                # Show the last complete record, if any, to help locate the
                # problem. Guarded so an empty section or an empty `record`
                # cannot mask this error with an IndexError.
                if record:
                    print(record[-1][2])
                raise Exception('??')
            current = []
            if len(sections) == 3:
                record.append(sections)
                sections = []
    return record


def _sentence_dependency_paths(data, draft_index, sen):
    """Return the typed dependency paths between entity pairs of one sentence.

    Args:
        data: One record from ``_read_parse_records``; ``data[2]`` holds
            lines shaped like ``label(token-i, token-j)``.
        draft_index: Index of the draft entry (diagnostics only).
        sen: The spaCy sentence matched with this record (diagnostics only).

    Returns:
        list of ``(pair key, head entity key, tail entity key, path)`` where
        ``pair key`` is ``"<head type>-<tail type>"`` and ``path`` is the
        lower-cased, space-joined sequence of ``a|label|b`` hops along the
        shortest path between the two mentions.
    """
    dp_dict = {}
    ordered_nodes = {}  # dict used as an insertion-ordered set
    edges_list = []
    for line in data[2]:
        # Split at the FIRST '(' only, so a token that itself contains '('
        # does not truncate the argument list.
        ttp, sep, tmp = line.partition('(')
        if not sep:
            print(draft_index)
            print(sen)
            print(data)
            raise Exception(f'Malformed dependency line: {line!r}')
        assert tmp[-1] == ')'
        tmp = tmp[:-1]
        e1, e2 = tmp.split(', ')
        # Reduce sub-typed labels ("nmod:of") to their base label unless the
        # full label is one of the known ones.
        if ttp not in type_set and ':' in ttp:
            ttp = ttp.split(':')[0]
        dp_dict[f'{e1}_x_{e2}'] = [e1, ttp, e2]
        dp_dict[f'{e2}_x_{e1}'] = [e1, ttp, e2]
        ordered_nodes[e1] = None
        ordered_nodes[e2] = None
        edges_list.append((e1, e2))

    # First-seen order (instead of list(set(...))) keeps results independent
    # of the string hash seed.
    nodes_list = list(ordered_nodes)
    # Drop the trailing "-<index>" of each token and undo the underscore
    # joining applied by raw_to_format.
    pure_name = [('-'.join(name.split('-')[:-1])).replace('_', ' ') for name in nodes_list]
    graph = nx.Graph(edges_list)
    type_list = [name_to_type[name] if name in good_name else '' for name in pure_name]

    found = []
    for i, head_type in enumerate(type_list):
        if head_type == '':
            continue
        for j, tail_type in enumerate(type_list):
            if i == j or tail_type == '':
                continue
            pair_key = f'{head_type}-{tail_type}'
            if pair_key not in Parameters.edge_type_to_id:
                continue
            sp = nx.shortest_path(graph, source=nodes_list[i], target=nodes_list[j])
            start, end = sp[0], sp[-1]
            ret_path = []
            for left, right in zip(sp, sp[1:]):
                e1, ttp, e2 = dp_dict[f'{left}_x_{right}']
                # Replace the two endpoints with placeholders. They carry a
                # dummy "-x" index so the suffix stripping below applies to
                # them as well.
                if e1 == start:
                    e1 = 'start_entity-x'
                if e2 == start:
                    e2 = 'start_entity-x'
                if e1 == end:
                    e1 = 'end_entity-x'
                if e2 == end:
                    e2 = 'end_entity-x'
                ret_path.append(f'{"-".join(e1.split("-")[:-1])}|{ttp}|{"-".join(e2.split("-")[:-1])}'.lower())
            found.append((pair_key,
                          name_to_meshid[pure_name[i]],
                          name_to_meshid[pure_name[j]],
                          ' '.join(ret_path)))
    return found


def _extract():
    """Turn parsed generated text into knowledge-graph triples.

    For every draft entry, the typed dependency paths of its sentences are
    mapped to edge type ids via the dependency table; the resulting
    ``[head id, edge type id, tail id]`` triples are collected per entry and
    the whole list is pickled to ``modified_attack_path``. Also prints how
    many entries reproduced their target triple ``attack_data[ii]``.
    """
    dependency_to_type_id = _load_dependency_to_type_id()
    record = _read_parse_records(
        f'generate_abstract/{args.target_split}_{args.reasonable_rate}_{args.mode}_parseout.txt')
    with open(f'generate_abstract/{args.target_split}_{args.reasonable_rate}_{args.mode}_parsein.txt', 'r') as fl:
        parsin = fl.readlines()
    print('Record len', len(record), 'Parsin len:', len(parsin))

    record_index = 0
    add = 0
    Attack = []
    for ii, (draft_key, v_dict) in enumerate(tqdm(draft.items())):
        source_text = v_dict['in']
        output = v_dict['out'].replace('\n', ' ')
        s, r, o = attack_data[ii]
        # Draft keys are expected to end in "_<index>" matching attack_data.
        assert ii == int(draft_key.split('_')[-1])

        DP_list = []
        if source_text != '':
            # Re-split into sentences exactly as the parse step did so that
            # sentences and parser records are consumed in lock step.
            for sen in nlp(output).sents:
                if record_index >= len(record):
                    # Parser output exhausted; remaining sentences are
                    # skipped (compare 'End record_index' with 'Record len').
                    break
                data = record[record_index]
                record_index += 1
                DP_list.extend(_sentence_dependency_paths(data, ii, sen))

        boo = False
        modified_attack = []
        for pair_key, ss, tt, dp in DP_list:
            type_by_path = dependency_to_type_id[pair_key]
            if dp not in type_by_path:
                continue
            tp = str(type_by_path[dp])
            id_ss = str(meshid_to_id[ss])
            id_tt = str(meshid_to_id[tt])
            modified_attack.append(f'{id_ss}*{tp}*{id_tt}')
            if (int(type_by_path[dp]) == int(r)
                    and id_to_meshid[s] == ss and id_to_meshid[o] == tt):
                boo = True
        # De-duplicate while keeping first-seen order (deterministic output).
        modified_attack = [triple.split('*') for triple in dict.fromkeys(modified_attack)]
        if boo:
            add += 1
        Attack.append(modified_attack)

    print(add)
    print('End record_index:', record_index)
    # NOTE: the payload is a pickle even though the path ends in ".txt".
    with open(modified_attack_path, 'wb') as fl:
        pkl.dump(Attack, fl)


if args.action == 'parse':
    _write_parse_input()
elif args.action == 'extract':
    _extract()
else:
    raise Exception('Wrong action !!')