# Spaces: Runtime error
#%% | |
import logging | |
from symbol import parameters | |
from textwrap import indent | |
import os | |
import tempfile | |
import sys | |
from matplotlib import collections | |
import pandas as pd | |
import json | |
from glob import glob | |
from tqdm import tqdm | |
import numpy as np | |
from pprint import pprint | |
import torch | |
import pickle as pkl | |
from collections import Counter | |
# print(dir(collections)) | |
import networkx as nx | |
from collections import Counter | |
import utils | |
from torch.nn import functional as F | |
sys.path.append("..") | |
import Parameters | |
from DiseaseSpecific.attack import calculate_edge_bound, get_model_loss_without_softmax | |
#%% | |
def load_data(file_name):
    """Load a headerless, tab-separated file as a NumPy array of strings.

    Every column is read as ``str`` and exact duplicate rows are removed
    (the first occurrence of each row is kept).

    Args:
        file_name: path of the TSV file to read.

    Returns:
        numpy.ndarray holding the de-duplicated rows.
    """
    frame = pd.read_csv(file_name, sep='\t', header=None, names=None, dtype=str)
    unique_rows = frame.drop_duplicates()
    return unique_rows.values
# Extend the project's shared argument parser with the options this script needs.
# --init-mode selects how target drugs are chosen below: '' sweeps PageRank
# percentiles, 'random' samples 400 drugs, 'single' is the case-study subset.
parser = utils.get_argument_parser()
parser.add_argument(
    '--reasonable-rate',
    type=float,
    default=0.7,
    help="The added edge's existance rank prob greater than this rate",
)
parser.add_argument(
    '--init-mode',
    type=str,
    default='single',  # 'single' for case study
    help='How to select target nodes',
)
parser.add_argument(
    '--added-edge-num',
    type=str,
    default='',
    help='Added edge num',
)
args = utils.set_hyperparams(parser.parse_args())
utils.seed_all(args.seed)
# Locations of the processed GNBR data and the trained link-prediction model.
data_path = '../DiseaseSpecific/processed_data/GNBR'
graph_edge_path = f'{data_path}/all.txt'
idtomeshid_path = f'{data_path}/entities_reverse_dict.json'
model_path = f'../DiseaseSpecific/saved_models/GNBR_{args.model}_128_0.2_0.3_0.3.model'

# Full raw-name mapping (only referenced by exploratory code further down).
with open(Parameters.GNBRfile + 'original_entity_raw_name', 'rb') as fh:
    full_entity_raw_name = pkl.load(fh)

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
args.device = device
n_ent, n_rel, ent_to_id, rel_to_id = utils.generate_dicts(data_path)
model = utils.load_model(model_path, args, n_ent, n_rel, args.device)
print(device)

# Knowledge-graph triples plus the entity-id -> MeSH-id lookup.
graph_edge = utils.load_data(graph_edge_path)
with open(idtomeshid_path, 'r') as fh:
    idtomeshid = json.load(fh)
print(graph_edge.shape, len(idtomeshid))

# Statistics later used by check_reasonable to standardise model losses.
divide_bound, data_mean, data_std = calculate_edge_bound(graph_edge, model, args.device, n_ent)
print('Defender ...')
print(divide_bound, data_mean, data_std)
# Count entities per type; the type is the MeSH-id prefix before the first '_'.
# An unexpected prefix raises KeyError because only these three keys exist.
meshids = list(idtomeshid.values())
cal = dict.fromkeys(('chemical', 'disease', 'gene'), 0)
for entity_type in (m.split('_')[0] for m in meshids):
    cal[entity_type] += 1
# Inspect with: pprint(cal)
def check_reasonable(s, r, o):
    """Decide whether the triple ``(s, r, o)`` is plausible enough to add.

    The model loss of the triple is standardised with the global
    ``data_mean`` / ``data_std`` and mapped through
    ``1 / (1 + exp(z - divide_bound))``. The triple is accepted when that
    value exceeds ``1 - args.reasonable_rate``.

    Returns:
        ``(accepted, prob)``: the comparison result and the mapped value.
    """
    batch = np.asarray([[s, r, o]]).astype('int64')  # one-row batch
    raw = get_model_loss_without_softmax(torch.from_numpy(batch).to(device), model, device)
    z = (raw.squeeze().item() - data_mean) / data_std
    prob = 1 / (1 + np.exp(z - divide_bound))
    threshold = 1 - args.reasonable_rate
    return prob > threshold, prob
# Flatten Parameters.edge_type_to_id into per-relation-id lookups (keys are
# stringified ids): relation id -> edge type, and relation id -> reverse mask.
edgeid_to_edgetype = {}
edgeid_to_reversemask = {}
for edge_type, edge_ids in Parameters.edge_type_to_id.items():
    masks = Parameters.reverse_mask[edge_type]
    for edge_id, is_reversed in zip(edge_ids, masks):
        key = str(edge_id)
        edgeid_to_edgetype[key] = edge_type
        edgeid_to_reversemask[key] = is_reversed
# Build a directed graph from the triples. Reverse-masked relations are
# inserted tail -> head, all others head -> tail.
reverse_tot = 0
G = nx.DiGraph()
for head, rel, tail in graph_edge:
    # Sanity check: the head's entity type matches the relation's first type.
    assert idtomeshid[head].split('_')[0] == edgeid_to_edgetype[rel].split('-')[0]
    if edgeid_to_reversemask[rel] == 1:
        reverse_tot += 1
        src, dst = int(tail), int(head)
    else:
        src, dst = int(head), int(tail)
    G.add_edge(src, dst)
# reverse_tot holds how many reverse-masked edges were flipped.
print('Edge num:', G.number_of_edges(), 'Node num:', G.number_of_nodes())
pagerank_value_1 = nx.pagerank(G, max_iter=200, tol=1.0e-7)
#%% | |
# Load UMLS drug terms and GNBR raw entity names, then keep the chemical
# entities whose lower-cased raw name is a known drug term.
with open(Parameters.UMLSfile + 'drug_term', 'rb') as fh:
    drug_term = pkl.load(fh)
with open(Parameters.GNBRfile + 'entity_raw_name', 'rb') as fh:
    entity_raw_name = pkl.load(fh)
drug_meshid = {
    mesh_id
    for mesh_id, raw_name in entity_raw_name.items()
    if raw_name.lower() in drug_term and mesh_id.split('_')[0] == 'chemical'
}
# Nodes in ascending PageRank order, bucketed by entity type. Chemicals are
# kept only if they are drugs; 'merged' receives every node and is then cut
# down to its top quarter (highest PageRank).
pr = sorted(pagerank_value_1.items(), key=lambda item: item[1])
sorted_rank = {'chemical': [], 'gene': [], 'disease': [], 'merged': []}
for iid, score in pr:
    mesh_id = idtomeshid[str(iid)]
    tp = mesh_id.split('_')[0]
    keep = tp != 'chemical' or mesh_id in drug_meshid
    if keep:
        sorted_rank[tp].append((iid, score))
    sorted_rank['merged'].append((iid, score))
llen = len(sorted_rank['merged'])
sorted_rank['merged'] = sorted_rank['merged'][llen * 3 // 4:]
print(len(sorted_rank['chemical']))
print(len(sorted_rank['gene']), len(sorted_rank['disease']), len(sorted_rank['merged']))
#%% | |
def _best_relation(iid, target, tp):
    """Return ``(r, loss)`` for the lowest-loss relation between ``iid`` and ``target``.

    Only relations linking type ``tp`` (the type of ``iid``) with 'chemical'
    are tried: forward relations typed ``tp-chemical`` and reverse-masked
    relations typed ``chemical-tp``. Returns ``None`` if none qualifies.
    """
    edge_losses = []
    r_list = []
    for r in range(len(edgeid_to_edgetype)):
        r_tp = edgeid_to_edgetype[str(r)].split('-')
        mask = edgeid_to_reversemask[str(r)]
        if not ((mask == 0 and r_tp[0] == tp and r_tp[1] == 'chemical')
                or (mask == 1 and r_tp[0] == 'chemical' and r_tp[1] == tp)):
            continue
        # NOTE(review): for reverse-masked relations the edge finally proposed
        # is (target, r, iid), but -- as in the original script -- the loss is
        # computed on (iid, r, target). Confirm this orientation is intended.
        train_trip = torch.from_numpy(np.array([[iid, r, target]]).astype('int64')).to(device)
        edge_loss = get_model_loss_without_softmax(train_trip, model, device).squeeze()
        edge_losses.append(edge_loss.unsqueeze(0).detach())
        r_list.append(r)
    if not edge_losses:
        return None
    min_index = int(torch.argmin(torch.cat(edge_losses, dim=0)))
    return r_list[min_index], edge_losses[min_index].item()


def _collect_candidates(target, main_ids=()):
    """Gather plausible new edges into ``target`` from high-PageRank nodes.

    Walks ``sorted_rank['merged']`` (top quarter by PageRank), skips nodes
    that already have an edge to ``target``, picks each node's lowest-loss
    relation and keeps the triple if ``check_reasonable`` accepts it.

    Returns:
        ``(candidate_list, score_list, loss_list, main_dict)``: ``score_list``
        holds PageRank / (out-degree + 1), ``loss_list`` the chosen relation's
        loss, and ``main_dict`` maps node ids contained in ``main_ids`` to
        their position in ``candidate_list``.
    """
    candidate_list, score_list, loss_list = [], [], []
    main_dict = {}
    for iid, score in sorted_rank['merged']:
        if G.number_of_edges(iid, target) != 0:
            continue  # iid already points at target
        out_degree = G.out_degree(iid) + 1
        tp = idtomeshid[str(iid)].split('_')[0]
        best = _best_relation(iid, target, tp)
        if best is None:
            continue
        r, loss = best
        # Triples keep the relation's first type as head, so for reverse-masked
        # ('chemical-tp') relations the chemical target is the head.
        if edgeid_to_reversemask[str(r)] == 0:
            triple = (iid, r, target)
        else:
            triple = (target, r, iid)
        accepted, _ = check_reasonable(*triple)
        if not accepted:
            continue
        candidate_list.append(triple)
        score_list.append(score / out_degree)
        loss_list.append(loss)
        if int(iid) in main_ids:
            main_dict[iid] = len(candidate_list) - 1
    return candidate_list, score_list, loss_list, main_dict


def _combined_scores(score_list, loss_list):
    """Return ``(norm_score, norm_loss, total_score)`` for the candidates.

    ``norm_score`` is the structural score normalised to sum to 1,
    ``norm_loss`` is softmax(-loss) (shifted by the minimum loss so large
    losses cannot underflow to 0/0), and ``total_score`` is their product.
    """
    norm_score = np.array(score_list) / np.sum(score_list)
    losses = np.array(loss_list)
    weights = np.exp(-(losses - losses.min()))
    norm_loss = weights / weights.sum()
    return norm_score, norm_loss, norm_score * norm_loss


def _descending_order(values):
    """Indices of ``values`` from largest to smallest; ties keep index order.

    The first entry therefore equals ``np.argmax(values)``.
    """
    return np.argsort(-np.asarray(values), kind='stable')


Target_node_list = []
Attack_edge_list = []
if args.init_mode == '':
    # Sweep: for several PageRank percentiles of the drug list, attack the 20
    # drugs centred on that percentile, adding a single edge per target.
    if args.added_edge_num != '' and args.added_edge_num != '1':
        raise ValueError("added_edge_num must be 1 when init_mode==''")
    for init_p in [0.1, 0.3, 0.5, 0.7, 0.9]:
        p = len(sorted_rank['chemical']) * init_p
        print('Init p:', init_p)
        target_node_list = []
        attack_edge_list = []
        num_max_eq = 0
        mean_rank_of_total_max = 0
        for pp in tqdm(range(int(p) - 10, int(p) + 10)):
            target = sorted_rank['chemical'][pp][0]
            target_node_list.append(target)
            candidate_list, score_list, loss_list, _ = _collect_candidates(target)
            if not candidate_list:
                attack_edge_list.append((-1, -1, -1))
                continue
            norm_score, _, total_score = _combined_scores(score_list, loss_list)
            max_index = int(_descending_order(total_score)[0])
            attack_edge_list.append(candidate_list[max_index])
            score_order = _descending_order(norm_score).tolist()
            if score_order[0] == max_index:
                num_max_eq += 1
            mean_rank_of_total_max += score_order.index(max_index) / len(norm_score)
        print('num_max_eq:', num_max_eq)
        print('mean_rank_of_total_max:', mean_rank_of_total_max / max(len(target_node_list), 1))
        Target_node_list.append(target_node_list)
        Attack_edge_list.append(attack_edge_list)
else:
    if args.init_mode not in ('random', 'single'):
        raise ValueError(f"Unknown init_mode {args.init_mode!r}; expected '', 'random' or 'single'")
    print(f'Init mode : {args.init_mode}')
    utils.seed_all(args.seed)
    # Both modes start from the same 400 random drug positions.
    index = np.random.choice(len(sorted_rank['chemical']), 400, replace=False)
    if args.init_mode == 'single':
        # Case study: keep the 10 of those targets with the largest relative
        # decrease from Init_0.7random.pkl to After_distmult_0.7random10.pkl.
        # (An earlier version hard-coded
        #  index = [5807, 6314, 5799, 5831, 3954, 5654, 5649, 5624, 2412, 2407].)
        with open('../pagerank/results/After_distmult_0.7random10.pkl', 'rb') as fl:
            edge = pkl.load(fl)
        with open('../pagerank/results/Init_0.7random.pkl', 'rb') as fl:
            init = pkl.load(fl)
        init = np.array(init)
        increase = ((init - np.array(edge)) / init).reshape(-1)
        selected_index = np.argsort(increase)[::-1][:10]
        index = [index[i] for i in selected_index]

    # Node ids whose candidate ranks are reported per target.
    # NOTE(review): the original script used `main_iid` without ever defining
    # it (NameError on the first accepted candidate). It defaults to empty
    # here -- populate it with the node ids of interest for the case study.
    main_iid = set()
    single_edge = args.added_edge_num == '' or int(args.added_edge_num) == 1
    add_num = 1 if single_edge else int(args.added_edge_num)

    target_node_list = []
    attack_edge_list = []
    num_max_eq = 0
    mean_rank_of_total_max = 0
    for pp in tqdm(index):
        target = sorted_rank['chemical'][pp][0]
        target_node_list.append(target)
        print('Target:', entity_raw_name[idtomeshid[str(target)]])
        candidate_list, score_list, loss_list, main_dict = _collect_candidates(target, main_iid)
        if not candidate_list:
            attack_edge_list.append((-1, -1, -1) if single_edge else [])
            continue
        norm_score, norm_loss, total_score = _combined_scores(score_list, loss_list)
        # One stable ordering per score keeps the max/rank statistics consistent
        # under ties (the old unstable-argsort assert could fail spuriously).
        score_order = _descending_order(norm_score).tolist()
        total_order = _descending_order(total_score).tolist()
        if main_dict:
            loss_order = _descending_order(norm_loss).tolist()
            for node_id, cand_idx in main_dict.items():
                node_name = entity_raw_name[idtomeshid[str(int(node_id))]]
                print(f'score rank of {node_name}: ', score_order.index(cand_idx))
                print(f'loss rank of {node_name}: ', loss_order.index(cand_idx))
                print(f'total rank of {node_name}: ', total_order.index(cand_idx))
        max_index = total_order[0]
        if single_edge:
            attack_edge_list.append(candidate_list[max_index])
        else:
            # Top `add_num` candidates; fewer if not enough exist (no IndexError).
            attack_edge_list.append([candidate_list[i] for i in total_order[:add_num]])
        if score_order[0] == max_index:
            num_max_eq += 1
        mean_rank_of_total_max += score_order.index(max_index) / len(norm_score)
    print('num_max_eq:', num_max_eq)
    # Average over the actual number of targets (400 for 'random', 10 for 'single').
    print('mean_rank_of_total_max:', mean_rank_of_total_max / max(len(target_node_list), 1))
    Target_node_list = target_node_list
    Attack_edge_list = attack_edge_list
# Report the shapes of the collected targets and attack edges. dtype=object
# lets NumPy build the array even when per-target attack lists are ragged
# (e.g. --added-edge-num > 1 with targets yielding no or fewer candidates);
# plain np.array raises ValueError on such input in NumPy >= 1.24. For
# homogeneous input the reported shape is unchanged.
print(np.array(Target_node_list, dtype=object).shape)
print(np.array(Attack_edge_list, dtype=object).shape)
# with open(f'processed_data/target_{args.reasonable_rate}{args.init_mode}.pkl', 'wb') as fl:
#     pkl.dump(Target_node_list, fl)
# with open(f'processed_data/attack_edge_{args.model}_{args.reasonable_rate}{args.init_mode}{args.added_edge_num}.pkl', 'wb') as fl:
#     pkl.dump(Attack_edge_list, fl)