#%%
import torch
import numpy as np
from torch.autograd import Variable
from sklearn import metrics
import datetime
from typing import Dict, Tuple, List
import logging
import os
import utils
import pickle as pkl
import json
import torch.backends.cudnn as cudnn
from tqdm import tqdm
import sys
sys.path.append("..")
import Parameters
parser = utils.get_argument_parser()
parser.add_argument('--reasonable-rate', type=float, default=0.7, help='The added edge\'s existence rank probability must be greater than this rate')
parser.add_argument('--mode', type=str, default='sentence', help='sentence, finetune, biogpt or bioBART')
parser.add_argument('--action', type=str, default='parse', help='parse or extract')
parser.add_argument('--init-mode', type=str, default='random', help='How to select target nodes')
parser.add_argument('--ratio', type=str, default='', help='Ratio of the number of changed words')
args = parser.parse_args()
args = utils.set_hyperparams(args)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
utils.seed_all(args.seed)
np.set_printoptions(precision=5)
cudnn.benchmark = False

data_path = '../DiseaseSpecific/processed_data/GNBR'
target_path = f'processed_data/target_{args.reasonable_rate}{args.init_mode}.pkl'
attack_path = f'processed_data/attack_edge_{args.model}_{args.reasonable_rate}{args.init_mode}.pkl'
modified_attack_path = f'processed_data/attack_edge_{args.model}_{args.reasonable_rate}{args.init_mode}{args.mode}.pkl'
with open(attack_path, 'rb') as fl:
    Attack_edge_list = pkl.load(fl)
attack_data = np.array(Attack_edge_list).reshape(-1, 3)
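# attack_data is a flat (N, 3) array of (head_id, relation_id, tail_id) triples,
# one per adversarial edge; the generated abstracts loaded into `draft` below are
# assumed to follow the same order (checked by the assert on `k` in the extract
# branch).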
#%%
with open(os.path.join(data_path, 'entities_reverse_dict.json')) as fl:
    id_to_meshid = json.load(fl)
with open(os.path.join(data_path, 'entities_dict.json'), 'r') as fl:
    meshid_to_id = json.load(fl)
with open(Parameters.GNBRfile+'entity_raw_name', 'rb') as fl:
    entity_raw_name = pkl.load(fl)
with open(Parameters.GNBRfile+'retieve_sentence_through_edgetype', 'rb') as fl:
    retieve_sentence_through_edgetype = pkl.load(fl)
with open(Parameters.GNBRfile+'raw_text_of_each_sentence', 'rb') as fl:
    raw_text_sen = pkl.load(fl)
with open(Parameters.GNBRfile+'original_entity_raw_name', 'rb') as fl:
    full_entity_raw_name = pkl.load(fl)
for k, v in entity_raw_name.items():
    assert v in full_entity_raw_name[k]
# Build the set of usable surface names for entity matching.
with open('../DiseaseSpecific/generate_abstract/valid_entity.json', 'r') as fl:
    valid_entity = json.load(fl)
valid_entity = set(valid_entity)
good_name = set()
for k, v in full_entity_raw_name.items():
    for name in v:
        good_name.add(name)
name_to_type = {}
name_to_meshid = {}
for k, v in full_entity_raw_name.items():
    for name in v:
        if name in good_name:
            name_to_type[name] = k.split('_')[0]
            name_to_meshid[name] = k
import spacy
import networkx as nx
def check(p, s):
    # True iff position p of string s is a word boundary:
    # out of range, or not a letter/digit.
    if p < 1 or p >= len(s):
        return True
    return not ((s[p] >= 'a' and s[p] <= 'z') or (s[p] >= 'A' and s[p] <= 'Z') or (s[p] >= '0' and s[p] <= '9'))
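# Illustrative behavior (examples added for clarity, not from the original source):
#   check(0, "abc")  -> True   (p < 1 counts as a boundary)
#   check(1, "a-b")  -> True   ('-' is not alphanumeric)
#   check(1, "ab")   -> False  ('b' is alphanumeric)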
def raw_to_format(sen):
    # Greedily replace the longest known entity name starting at each position
    # with its underscore-joined form, so multi-word names survive tokenization.
    text = sen
    l = 0
    ret = []
    while l < len(text):
        bo = False
        if text[l] != ' ':
            for i in range(len(text), l, -1):  # longest candidate first -- reversing is important!
                cc = text[l:i]
                if (cc in good_name or cc in valid_entity) and check(l - 1, text) and check(i, text):
                    ret.append(cc.replace(' ', '_'))
                    l = i
                    bo = True
                    break
        if not bo:
            ret.append(text[l])
            l += 1
    return ''.join(ret)
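# Illustrative example (assuming "breast cancer" is in good_name):
#   raw_to_format("breast cancer cells") -> "breast_cancer cells"
# Candidate spans are scanned longest-first, so a full multi-word name wins
# over any shorter prefix such as "breast".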
if args.mode == 'sentence':
    with open(f'generate_abstract/{args.init_mode}{args.reasonable_rate}_chat.json', 'r') as fl:
        draft = json.load(fl)
elif args.mode == 'finetune':
    with open(f'generate_abstract/{args.init_mode}{args.reasonable_rate}_sentence_finetune.json', 'r') as fl:
        draft = json.load(fl)
elif args.mode == 'bioBART':
    with open(f'generate_abstract/{args.init_mode}{args.reasonable_rate}{args.ratio}_bioBART_finetune.json', 'r') as fl:
        draft = json.load(fl)
elif args.mode == 'biogpt':
    with open(f'generate_abstract/{args.init_mode}{args.reasonable_rate}_biogpt.json', 'r') as fl:
        draft = json.load(fl)
else:
    raise Exception(f'Unknown mode: {args.mode}')
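# Each draft entry is assumed to map a key ending in the attack index
# (e.g. '..._17') to {'in': <prompt text>, 'out': <generated abstract>};
# an empty 'in' marks a skipped abstract.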
nlp = spacy.load("en_core_web_sm")
type_set = set()
for aa in range(36):
    dependency_sen_dict = retieve_sentence_through_edgetype[aa]['manual']
    tmp_dict = retieve_sentence_through_edgetype[aa]['auto']
    dependencys = list(dependency_sen_dict.keys()) + list(tmp_dict.keys())
    for dependency in dependencys:
        dep_list = dependency.split(' ')
        for sub_dep in dep_list:
            sub_dep_list = sub_dep.split('|')
            assert len(sub_dep_list) == 3
            type_set.add(sub_dep_list[1])
# print('Type:', type_set)
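# type_set now holds every dependency-relation label (the middle field of each
# 'word|relation|word' segment of a GNBR path); it is used below to decide when
# a parser label such as 'nmod:of' should be truncated to its bare form 'nmod'.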
if args.action == 'parse':
    # Split each generated abstract into sentences, normalize entity mentions,
    # and write one sentence per line as parser input.
    out = ''
    for k, v_dict in draft.items():
        input = v_dict['in']
        output = v_dict['out']
        if input == '':
            continue
        output = output.replace('\n', ' ')
        doc = nlp(output)
        for sen in doc.sents:
            out += raw_to_format(sen.text) + '\n'
    with open(f'generate_abstract/{args.init_mode}{args.reasonable_rate}_{args.mode}_parsein.txt', 'w') as fl:
        fl.write(out)
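    # The *_parsein.txt written above (one normalized sentence per line) is meant
    # for an external parser; the extract branch reads its output back from
    # *_parseout.txt, whose '(ROOT' blocks suggest Stanford-CoreNLP-style parses.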
elif args.action == 'extract':
    if os.path.exists('generate_abstract/dependency_to_type_id.pickle'):
        with open('generate_abstract/dependency_to_type_id.pickle', 'rb') as fl:
            dependency_to_type_id = pkl.load(fl)
    else:
        dependency_to_type_id = {}
        print('Loading path data ...')
        for k in Parameters.edge_type_to_id.keys():
            start, end = k.split('-')
            dependency_to_type_id[k] = {}
            inner_edge_type_to_id = Parameters.edge_type_to_id[k]
            inner_edge_type_dict = Parameters.edge_type_dict[k]
            cal_manual_num = [0] * len(inner_edge_type_to_id)
            with open('../GNBRdata/part-i-' + start + '-' + end + '-path-theme-distributions.txt', 'r') as fl:
                for i, line in tqdm(list(enumerate(fl.readlines()))):
                    tmp = line.split('\t')
                    if i == 0:
                        # Header row: theme names sit in every other column.
                        head = [tmp[j] for j in range(1, len(tmp), 2)]
                        assert ' '.join(head) == ' '.join(inner_edge_type_dict[0])
                        continue
                    probability = [float(tmp[j]) for j in range(1, len(tmp), 2)]
                    flag_list = [int(tmp[j]) for j in range(2, len(tmp), 2)]
                    indices = np.where(np.asarray(flag_list) == 1)[0]
                    if len(indices) >= 1:
                        # Among manually flagged themes, pick the least-used one
                        # so far to keep the per-theme counts balanced.
                        tmp_p = [cal_manual_num[j] for j in indices]
                        p = indices[np.argmin(tmp_p)]
                        cal_manual_num[p] += 1
                    else:
                        p = np.argmax(probability)
                    # Each dependency path maps to exactly one type id within its pair type.
                    assert tmp[0].lower() not in dependency_to_type_id[k]
                    dependency_to_type_id[k][tmp[0].lower()] = inner_edge_type_to_id[p]
        with open('generate_abstract/dependency_to_type_id.pickle', 'wb') as fl:
            pkl.dump(dependency_to_type_id, fl)
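    # dependency_to_type_id[pair][path] maps a lowercase dependency path, under an
    # entity-pair type such as 'chemical-disease', to a GNBR relation type id:
    # manually flagged themes are preferred (kept balanced via cal_manual_num),
    # otherwise the highest-probability theme is chosen.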
    record = []
    with open(f'generate_abstract/{args.init_mode}{args.reasonable_rate}_{args.mode}_parseout.txt', 'r') as fl:
        # Parser output arrives as blank-line-separated blocks; group every three
        # blocks back into one per-sentence record.
        Tmp = []
        tmp = []
        for i, line in enumerate(fl.readlines()):
            line = line.replace('\n', '')
            if len(line) > 1:
                tmp.append(line)
            else:
                if len(Tmp) == 2:
                    # A new tokenized sentence arrived while the current record still
                    # lacks its dependency block: close the record with an empty one.
                    if len(tmp) == 1 and '/' in tmp[0].split(' ')[0]:
                        Tmp.append([])
                        record.append(Tmp)
                        Tmp = []
                Tmp.append(tmp)
                if len(Tmp) == 2 and tmp[0][:5] != '(ROOT':
                    print(record[-1][2])
                    raise Exception('Expected a constituency parse starting with "(ROOT"')
                tmp = []
                if len(Tmp) == 3:
                    record.append(Tmp)
                    Tmp = []
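    # Each record is expected to contain three blocks for one sentence: [0] the
    # tokenized sentence (likely token/POS pairs), [1] a constituency parse
    # beginning with '(ROOT', and [2] dependency triples such as
    # 'relation(head-3, dependent-5)'; exact shapes depend on the external parser.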
    with open(f'generate_abstract/{args.init_mode}{args.reasonable_rate}_{args.mode}_parsein.txt', 'r') as fl:
        parsin = fl.readlines()
    print('Record len:', len(record), 'Parsein len:', len(parsin))

    record_index = 0
    add = 0
    Attack = []
    for ii, (k, v_dict) in enumerate(tqdm(draft.items())):
        input = v_dict['in']
        output = v_dict['out']
        output = output.replace('\n', ' ')
        s, r, o = attack_data[ii]
        s = str(s)
        r = str(r)
        o = str(o)
        assert ii == int(k.split('_')[-1])
        DP_list = []
        if input != '':
            dependency_sen_dict = retieve_sentence_through_edgetype[int(r)]['manual']
            target_dp = set()
            for dp_path, sen_list in dependency_sen_dict.items():
                target_dp.add(dp_path)
            doc = nlp(output)
            for sen in doc.sents:
                dp_dict = {}
                if record_index >= len(record):
                    break
                data = record[record_index]
                record_index += 1
                dp_paths = data[2]
                nodes_list = []
                edges_list = []
                for line in dp_paths:
                    aa = line.split('(')
                    if len(aa) == 1:
                        print(ii)
                        print(sen)
                        print(data)
                        raise Exception('Malformed dependency line: no "("')
                    ttp, tmp = aa[0], aa[1]
                    assert tmp[-1] == ')'
                    tmp = tmp[:-1]
                    e1, e2 = tmp.split(', ')
                    if ttp not in type_set and ':' in ttp:
                        ttp = ttp.split(':')[0]
                    dp_dict[f'{e1}_x_{e2}'] = [e1, ttp, e2]
                    dp_dict[f'{e2}_x_{e1}'] = [e1, ttp, e2]
                    nodes_list.append(e1)
                    nodes_list.append(e2)
                    edges_list.append((e1, e2))
                nodes_list = list(set(nodes_list))
                # Parser tokens look like 'word-index'; strip the index and restore spaces.
                pure_name = [('-'.join(name.split('-')[:-1])).replace('_', ' ') for name in nodes_list]
                graph = nx.Graph(edges_list)
                type_list = [name_to_type[name] if name in good_name else '' for name in pure_name]
                # For every ordered pair of typed entities, take the shortest
                # dependency path between them and normalize its endpoints.
                for i in range(len(nodes_list)):
                    if type_list[i] != '':
                        for j in range(len(nodes_list)):
                            if i != j and type_list[j] != '':
                                if f'{type_list[i]}-{type_list[j]}' in Parameters.edge_type_to_id.keys():
                                    ret_path = []
                                    sp = nx.shortest_path(graph, source=nodes_list[i], target=nodes_list[j])
                                    start = sp[0]
                                    end = sp[-1]
                                    for step in range(len(sp) - 1):
                                        e1, ttp, e2 = dp_dict[f'{sp[step]}_x_{sp[step+1]}']
                                        if e1 == start:
                                            e1 = 'start_entity-x'
                                        if e2 == start:
                                            e2 = 'start_entity-x'
                                        if e1 == end:
                                            e1 = 'end_entity-x'
                                        if e2 == end:
                                            e2 = 'end_entity-x'
                                        ret_path.append(f'{"-".join(e1.split("-")[:-1])}|{ttp}|{"-".join(e2.split("-")[:-1])}'.lower())
                                    dependency_P = ' '.join(ret_path)
                                    DP_list.append((f'{type_list[i]}-{type_list[j]}',
                                                    name_to_meshid[pure_name[i]],
                                                    name_to_meshid[pure_name[j]],
                                                    dependency_P))
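        # Each DP_list entry holds (entity-pair type, head MeSH id, tail MeSH id,
        # lowercase dependency path with its endpoints renamed to
        # start_entity/end_entity), matching the key format of dependency_to_type_id.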
        boo = False
        modified_attack = []
        for pair, ss, tt, dp in DP_list:
            if dp in dependency_to_type_id[pair].keys():
                tp = str(dependency_to_type_id[pair][dp])
                id_ss = str(meshid_to_id[ss])
                id_tt = str(meshid_to_id[tt])
                modified_attack.append(f'{id_ss}*{tp}*{id_tt}')
                if int(dependency_to_type_id[pair][dp]) == int(r):
                    if id_to_meshid[s] == ss and id_to_meshid[o] == tt:
                        boo = True
        modified_attack = list(set(modified_attack))
        modified_attack = [m.split('*') for m in modified_attack]
        if boo:
            add += 1
        Attack.append(modified_attack)
    print('Recovered target relations:', add)
    print('End record_index:', record_index)
    final_Attack = Attack
    print('Len of Attack:', len(Attack))
    with open(modified_attack_path, 'wb') as fl:
        pkl.dump(final_Attack, fl)
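    # final_Attack holds, for each attack triple, the de-duplicated
    # [head_id, type_id, tail_id] string triples recovered from its generated
    # abstract; `add` counts abstracts where the intended relation was recovered
    # between the intended head and tail entities.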
else:
    raise Exception(f'Unknown action: {args.action} (expected "parse" or "extract")')